Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1882 additions
and
606 deletions
+1882
-606
vllm/distributed/utils.py
vllm/distributed/utils.py
+1
-89
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+76
-21
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+203
-100
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+276
-84
vllm/engine/metrics.py
vllm/engine/metrics.py
+8
-1
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+3
-1
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+6
-2
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+22
-4
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+6
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+380
-70
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+41
-6
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+157
-13
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+141
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+147
-73
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+83
-15
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+144
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+59
-96
vllm/envs.py
vllm/envs.py
+13
-1
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+108
-26
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+8
-0
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
vllm/distributed/utils.py
View file @
b9e12416
...
@@ -2,19 +2,9 @@
...
@@ -2,19 +2,9 @@
# Adapted from
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
json
from
typing
import
Sequence
import
os
from
typing
import
Dict
,
Optional
,
Sequence
import
torch
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
.parallel_state
import
get_cpu_world_group
,
get_local_rank
logger
=
init_logger
(
__name__
)
def
ensure_divisibility
(
numerator
,
denominator
):
def
ensure_divisibility
(
numerator
,
denominator
):
...
@@ -56,81 +46,3 @@ def split_tensor_along_last_dim(
...
@@ -56,81 +46,3 @@ def split_tensor_along_last_dim(
return
tuple
(
chunk
.
contiguous
()
for
chunk
in
tensor_list
)
return
tuple
(
chunk
.
contiguous
()
for
chunk
in
tensor_list
)
return
tensor_list
return
tensor_list
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def
_can_actually_p2p
(
idx_a
,
idx_b
):
dev_i
=
f
"cuda:
{
idx_a
}
"
dev_j
=
f
"cuda:
{
idx_b
}
"
a
=
torch
.
randn
(
5
,
device
=
dev_i
)
+
123.0
b
=
a
.
to
(
dev_j
)
c
=
b
.
to
(
dev_i
)
return
torch
.
all
(
a
==
c
).
cpu
().
item
()
# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
# P2P access to all other GPUs. Unfortunately, the test might cost many
# (world_size * world_size) cuda context, and reduce the memory available
# for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
# process and broadcasted to all other processes. This still requires
# #world_size of cuda context, belonging to the master process, on each GPU.
# 3. we can have a cache file, that records the p2p access status. The first
# time the master process checks the p2p access, it will generate the cache
# file, at the cost of #world_size of cuda context. Later on, all processes
# can read the cache file to check the p2p access status without any cost of
# additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache
:
Optional
[
Dict
[
str
,
bool
]]
=
None
def
gpu_p2p_access_check
(
i
:
int
,
j
:
int
)
->
bool
:
"""Check if GPU i can access GPU j."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global
_gpu_p2p_access_cache
if
_gpu_p2p_access_cache
is
not
None
:
return
_gpu_p2p_access_cache
[
f
"
{
i
}
->
{
j
}
"
]
is_distributed
=
dist
.
is_initialized
()
num_dev
=
torch
.
cuda
.
device_count
()
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
is
None
:
cuda_visible_devices
=
","
.
join
(
str
(
i
)
for
i
in
range
(
num_dev
))
VLLM_CONFIG_ROOT
=
envs
.
VLLM_CONFIG_ROOT
path
=
os
.
path
.
expanduser
(
f
"
{
VLLM_CONFIG_ROOT
}
/vllm/gpu_p2p_access_cache_for_
{
cuda_visible_devices
}
.json"
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
if
(
not
is_distributed
or
get_local_rank
()
==
0
)
\
and
(
not
os
.
path
.
exists
(
path
)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger
.
info
(
"generating GPU P2P access cache for in %s"
,
path
)
cache
=
{}
for
_i
in
range
(
num_dev
):
for
_j
in
range
(
num_dev
):
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
torch
.
cuda
.
can_device_access_peer
(
_i
,
_j
)
and
_can_actually_p2p
(
_i
,
_j
)
with
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
cache
,
f
,
indent
=
4
)
if
is_distributed
:
cpu_world_group
=
get_cpu_world_group
()
dist
.
barrier
(
cpu_world_group
)
logger
.
info
(
"reading GPU P2P access cache from %s"
,
path
)
with
open
(
path
,
"r"
)
as
f
:
cache
=
json
.
load
(
f
)
_gpu_p2p_access_cache
=
cache
return
_gpu_p2p_access_cache
[
f
"
{
i
}
->
{
j
}
"
]
vllm/engine/arg_utils.py
View file @
b9e12416
import
argparse
import
argparse
import
dataclasses
import
dataclasses
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -34,11 +35,13 @@ class EngineArgs:
...
@@ -34,11 +35,13 @@ class EngineArgs:
seed
:
int
=
0
seed
:
int
=
0
max_model_len
:
Optional
[
int
]
=
None
max_model_len
:
Optional
[
int
]
=
None
worker_use_ray
:
bool
=
False
worker_use_ray
:
bool
=
False
distributed_executor_backend
:
Optional
[
str
]
=
None
pipeline_parallel_size
:
int
=
1
pipeline_parallel_size
:
int
=
1
tensor_parallel_size
:
int
=
1
tensor_parallel_size
:
int
=
1
max_parallel_loading_workers
:
Optional
[
int
]
=
None
max_parallel_loading_workers
:
Optional
[
int
]
=
None
block_size
:
int
=
16
block_size
:
int
=
16
enable_prefix_caching
:
bool
=
False
enable_prefix_caching
:
bool
=
False
disable_sliding_window
:
bool
=
False
use_v2_block_manager
:
bool
=
False
use_v2_block_manager
:
bool
=
False
swap_space
:
int
=
4
# GiB
swap_space
:
int
=
4
# GiB
gpu_memory_utilization
:
float
=
0.90
gpu_memory_utilization
:
float
=
0.90
...
@@ -48,6 +51,7 @@ class EngineArgs:
...
@@ -48,6 +51,7 @@ class EngineArgs:
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
None
revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
rope_scaling
:
Optional
[
dict
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
enforce_eager
:
bool
=
False
enforce_eager
:
bool
=
False
...
@@ -62,6 +66,7 @@ class EngineArgs:
...
@@ -62,6 +66,7 @@ class EngineArgs:
max_lora_rank
:
int
=
16
max_lora_rank
:
int
=
16
fully_sharded_loras
:
bool
=
False
fully_sharded_loras
:
bool
=
False
lora_extra_vocab_size
:
int
=
256
lora_extra_vocab_size
:
int
=
256
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
lora_dtype
=
'auto'
lora_dtype
=
'auto'
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
...
@@ -83,6 +88,7 @@ class EngineArgs:
...
@@ -83,6 +88,7 @@ class EngineArgs:
speculative_model
:
Optional
[
str
]
=
None
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
...
@@ -166,8 +172,8 @@ class EngineArgs:
...
@@ -166,8 +172,8 @@ class EngineArgs:
'* "dummy" will initialize the weights with random values, '
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.
\n
'
'which is mainly for profiling.
\n
'
'* "tensorizer" will load the weights using tensorizer from '
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave
which assumes tensorizer_uri is set to the location of
'
'CoreWeave
. See the Tensorize vLLM Model script in the Examples
'
'
the serialized weights.
'
)
'
section for more information.
\n
'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--dtype'
,
'--dtype'
,
type
=
str
,
type
=
str
,
...
@@ -186,12 +192,11 @@ class EngineArgs:
...
@@ -186,12 +192,11 @@ class EngineArgs:
parser
.
add_argument
(
parser
.
add_argument
(
'--kv-cache-dtype'
,
'--kv-cache-dtype'
,
type
=
str
,
type
=
str
,
choices
=
[
'auto'
,
'fp8'
],
choices
=
[
'auto'
,
'fp8'
,
'fp8_e5m2'
,
'fp8_e4m3'
],
default
=
EngineArgs
.
kv_cache_dtype
,
default
=
EngineArgs
.
kv_cache_dtype
,
help
=
'Data type for kv cache storage. If "auto", will use model '
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)'
)
'supported for common inference criteria.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--quantization-param-path'
,
'--quantization-param-path'
,
type
=
nullable_str
,
type
=
nullable_str
,
...
@@ -220,10 +225,17 @@ class EngineArgs:
...
@@ -220,10 +225,17 @@ class EngineArgs:
' Can be overridden per request via guided_decoding_backend'
' Can be overridden per request via guided_decoding_backend'
' parameter.'
)
' parameter.'
)
# Parallel arguments
# Parallel arguments
parser
.
add_argument
(
'--worker-use-ray'
,
parser
.
add_argument
(
action
=
'store_true'
,
'--distributed-executor-backend'
,
help
=
'Use Ray for distributed serving, will be '
choices
=
[
'ray'
,
'mp'
],
'automatically set when using more than 1 GPU.'
)
default
=
EngineArgs
.
distributed_executor_backend
,
help
=
'Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.'
)
parser
.
add_argument
(
'--worker-use-ray'
,
action
=
'store_true'
,
help
=
'Deprecated, use --distributed-executor-backend=ray.'
)
parser
.
add_argument
(
'--pipeline-parallel-size'
,
parser
.
add_argument
(
'--pipeline-parallel-size'
,
'-pp'
,
'-pp'
,
type
=
int
,
type
=
int
,
...
@@ -256,6 +268,10 @@ class EngineArgs:
...
@@ -256,6 +268,10 @@ class EngineArgs:
parser
.
add_argument
(
'--enable-prefix-caching'
,
parser
.
add_argument
(
'--enable-prefix-caching'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Enables automatic prefix caching.'
)
help
=
'Enables automatic prefix caching.'
)
parser
.
add_argument
(
'--disable-sliding-window'
,
action
=
'store_true'
,
help
=
'Disables sliding window, '
'capping to sliding window size'
)
parser
.
add_argument
(
'--use-v2-block-manager'
,
parser
.
add_argument
(
'--use-v2-block-manager'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Use BlockSpaceMangerV2.'
)
help
=
'Use BlockSpaceMangerV2.'
)
...
@@ -320,6 +336,11 @@ class EngineArgs:
...
@@ -320,6 +336,11 @@ class EngineArgs:
'None, we assume the model weights are not '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'quantized and use `dtype` to determine the data '
'type of the weights.'
)
'type of the weights.'
)
parser
.
add_argument
(
'--rope-scaling'
,
default
=
None
,
type
=
json
.
loads
,
help
=
'RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}'
)
parser
.
add_argument
(
'--enforce-eager'
,
parser
.
add_argument
(
'--enforce-eager'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Always use eager-mode PyTorch. If False, '
help
=
'Always use eager-mode PyTorch. If False, '
...
@@ -331,9 +352,9 @@ class EngineArgs:
...
@@ -331,9 +352,9 @@ class EngineArgs:
help
=
'Maximum context length covered by CUDA '
help
=
'Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq
_
len-to-capture instead'
'(DEPRECATED. Use --max-seq
-
len-to-capture instead'
')'
)
')'
)
parser
.
add_argument
(
'--max-seq
_
len-to-capture'
,
parser
.
add_argument
(
'--max-seq
-
len-to-capture'
,
type
=
int
,
type
=
int
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
help
=
'Maximum sequence length covered by CUDA '
...
@@ -388,6 +409,17 @@ class EngineArgs:
...
@@ -388,6 +409,17 @@ class EngineArgs:
choices
=
[
'auto'
,
'float16'
,
'bfloat16'
,
'float32'
],
choices
=
[
'auto'
,
'float16'
,
'bfloat16'
,
'float32'
],
help
=
(
'Data type for LoRA. If auto, will default to '
help
=
(
'Data type for LoRA. If auto, will default to '
'base model dtype.'
))
'base model dtype.'
))
parser
.
add_argument
(
'--long-lora-scaling-factors'
,
type
=
nullable_str
,
default
=
EngineArgs
.
long_lora_scaling_factors
,
help
=
(
'Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--max-cpu-loras'
,
'--max-cpu-loras'
,
type
=
int
,
type
=
int
,
...
@@ -467,6 +499,13 @@ class EngineArgs:
...
@@ -467,6 +499,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'draft model. Sequences over this length will skip '
'speculation.'
)
'speculation.'
)
parser
.
add_argument
(
'--speculative-disable-by-batch-size'
,
type
=
int
,
default
=
EngineArgs
.
speculative_disable_by_batch_size
,
help
=
'Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--ngram-prompt-lookup-max'
,
'--ngram-prompt-lookup-max'
,
type
=
int
,
type
=
int
,
...
@@ -508,7 +547,7 @@ class EngineArgs:
...
@@ -508,7 +547,7 @@ class EngineArgs:
return
parser
return
parser
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
)
->
'EngineArgs'
:
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
# Get the list of attributes of this dataclass.
# Get the list of attributes of this dataclass.
attrs
=
[
attr
.
name
for
attr
in
dataclasses
.
fields
(
cls
)]
attrs
=
[
attr
.
name
for
attr
in
dataclasses
.
fields
(
cls
)]
# Set the attributes from the parsed arguments.
# Set the attributes from the parsed arguments.
...
@@ -520,10 +559,11 @@ class EngineArgs:
...
@@ -520,10 +559,11 @@ class EngineArgs:
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
trust_remote_code
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
trust_remote_code
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
code_revision
,
self
.
rope_scaling
,
self
.
tokenizer_revision
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
max_model_len
,
self
.
quantization
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
max_context_len_to_capture
,
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
disable_sliding_window
,
self
.
skip_tokenizer_init
,
self
.
served_model_name
)
self
.
skip_tokenizer_init
,
self
.
served_model_name
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
...
@@ -532,14 +572,18 @@ class EngineArgs:
...
@@ -532,14 +572,18 @@ class EngineArgs:
model_config
.
get_sliding_window
(),
model_config
.
get_sliding_window
(),
self
.
enable_prefix_caching
)
self
.
enable_prefix_caching
)
parallel_config
=
ParallelConfig
(
parallel_config
=
ParallelConfig
(
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
,
self
.
pipeline_parallel_size
,
self
.
worker_use_ray
,
self
.
max_parallel_loading_workers
,
self
.
tensor_parallel_size
,
self
.
worker_use_ray
,
self
.
max_parallel_loading_workers
,
self
.
disable_custom_all_reduce
,
self
.
disable_custom_all_reduce
,
TokenizerPoolConfig
.
create_config
(
TokenizerPoolConfig
.
create_config
(
self
.
tokenizer_pool_size
,
self
.
tokenizer_pool_size
,
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_extra_config
,
self
.
tokenizer_pool_extra_config
,
),
self
.
ray_workers_use_nsight
)
),
self
.
ray_workers_use_nsight
,
distributed_executor_backend
=
self
.
distributed_executor_backend
)
speculative_config
=
SpeculativeConfig
.
maybe_create_spec_config
(
speculative_config
=
SpeculativeConfig
.
maybe_create_spec_config
(
target_model_config
=
model_config
,
target_model_config
=
model_config
,
...
@@ -547,6 +591,8 @@ class EngineArgs:
...
@@ -547,6 +591,8 @@ class EngineArgs:
target_dtype
=
self
.
dtype
,
target_dtype
=
self
.
dtype
,
speculative_model
=
self
.
speculative_model
,
speculative_model
=
self
.
speculative_model
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
speculative_disable_by_batch_size
=
self
.
speculative_disable_by_batch_size
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
...
@@ -564,12 +610,14 @@ class EngineArgs:
...
@@ -564,12 +610,14 @@ class EngineArgs:
speculative_config
.
num_lookahead_slots
),
speculative_config
.
num_lookahead_slots
),
delay_factor
=
self
.
scheduler_delay_factor
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
)
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
max_loras
=
self
.
max_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
long_lora_scaling_factors
=
self
.
long_lora_scaling_factors
,
lora_dtype
=
self
.
lora_dtype
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
...
@@ -599,6 +647,13 @@ class EngineArgs:
...
@@ -599,6 +647,13 @@ class EngineArgs:
decoding_config
=
DecodingConfig
(
decoding_config
=
DecodingConfig
(
guided_decoding_backend
=
self
.
guided_decoding_backend
)
guided_decoding_backend
=
self
.
guided_decoding_backend
)
if
(
model_config
.
get_sliding_window
()
is
not
None
and
scheduler_config
.
chunked_prefill_enabled
and
not
scheduler_config
.
use_v2_block_manager
):
raise
ValueError
(
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window."
)
return
EngineConfig
(
model_config
=
model_config
,
return
EngineConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
...
...
vllm/engine/async_llm_engine.py
View file @
b9e12416
import
asyncio
import
asyncio
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
from
typing
import
(
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -12,11 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
...
@@ -12,11 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -47,15 +49,16 @@ def _raise_exception_on_finish(
...
@@ -47,15 +49,16 @@ def _raise_exception_on_finish(
class
AsyncStream
:
class
AsyncStream
:
"""A stream of RequestOutputs
f
or
a request that can be
"""A stream of RequestOutputs or
EmbeddingRequestOutputs for a request
iterated over asynchronously."""
that can be
iterated over asynchronously."""
def
__init__
(
self
,
request_id
:
str
)
->
None
:
def
__init__
(
self
,
request_id
:
str
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Exception
])
->
None
:
def
put
(
self
,
item
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
,
Exception
])
->
None
:
if
self
.
_finished
:
if
self
.
_finished
:
return
return
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
...
@@ -71,7 +74,7 @@ class AsyncStream:
...
@@ -71,7 +74,7 @@ class AsyncStream:
def
__aiter__
(
self
):
def
__aiter__
(
self
):
return
self
return
self
async
def
__anext__
(
self
)
->
RequestOutput
:
async
def
__anext__
(
self
)
->
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
:
result
=
await
self
.
_queue
.
get
()
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
isinstance
(
result
,
Exception
):
raise
result
raise
result
...
@@ -108,7 +111,8 @@ class RequestTracker:
...
@@ -108,7 +111,8 @@ class RequestTracker:
self
.
abort_request
(
rid
)
self
.
abort_request
(
rid
)
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
RequestOutput
,
request_output
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
*
,
*
,
verbose
:
bool
=
False
)
->
None
:
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
"""Process a request output from the engine."""
...
@@ -196,7 +200,8 @@ class RequestTracker:
...
@@ -196,7 +200,8 @@ class RequestTracker:
class
_AsyncLLMEngine
(
LLMEngine
):
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
"""Extension of LLMEngine to add async methods."""
async
def
step_async
(
self
)
->
List
[
RequestOutput
]:
async
def
step_async
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
The workers are ran asynchronously if possible.
...
@@ -230,66 +235,77 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -230,66 +235,77 @@ class _AsyncLLMEngine(LLMEngine):
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
if
not
request_outputs
:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
return
request_outputs
return
request_outputs
async
def
encode_request
_async
(
async
def
process_model_inputs
_async
(
self
,
self
,
request_id
:
str
,
# pylint: disable=unused-argument
request_id
:
str
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
):
)
->
LLMInputs
:
if
prompt_token_ids
is
None
:
if
isinstance
(
inputs
,
str
):
assert
prompt
is
not
None
inputs
=
{
"prompt"
:
inputs
}
prompt_token_ids
=
await
self
.
tokenizer
.
encode_async
(
if
"prompt_token_ids"
not
in
inputs
:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
prompt_token_ids
=
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
inputs
[
"
prompt
"
]
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
return
prompt_token_ids
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
inputs
.
get
(
"prompt"
),
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
prompt_token_ids
=
await
self
.
encode_request_async
(
processed_inputs
=
await
self
.
process_model_inputs_async
(
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
)
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
processed_inputs
=
processed_inputs
,
prompt_token_ids
=
prompt_token_ids
,
params
=
params
,
lora_request
=
lora_request
)
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
return
self
.
add_request
(
request_id
,
)
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
self
.
model_executor
.
check_health
()
self
.
model_executor
.
check_health
()
class
AsyncLLMEngine
:
class
AsyncLLMEngine
:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
worker_use_ray: Whether to use Ray for model workers. Required for
...
@@ -303,8 +319,8 @@ class AsyncLLMEngine:
...
@@ -303,8 +319,8 @@ class AsyncLLMEngine:
being printed in log.
being printed in log.
start_engine_loop: If True, the background task to run the engine
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
*args: Arguments for
:class:`
LLMEngine
`
.
*kwargs: Arguments for LLMEngine.
*
*kwargs: Arguments for
:class:`
LLMEngine
`
.
"""
"""
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
...
@@ -327,7 +343,7 @@ class AsyncLLMEngine:
...
@@ -327,7 +343,7 @@ class AsyncLLMEngine:
# We need to keep a reference to unshielded
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# task as well to prevent it from being garbage
# collected
# collected
self
.
_background_loop_unshielded
:
Optional
[
asyncio
.
Task
[
Any
]
]
=
None
self
.
_background_loop_unshielded
:
Optional
[
asyncio
.
Task
]
=
None
self
.
start_engine_loop
=
start_engine_loop
self
.
start_engine_loop
=
start_engine_loop
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
...
@@ -344,27 +360,31 @@ class AsyncLLMEngine:
...
@@ -344,27 +360,31 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments."""
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
engine_config
=
engine_args
.
create_engine_config
()
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
assert
not
engine_config
.
parallel_config
.
worker_use_ray
,
(
assert
distributed_executor_backend
is
None
,
(
"
Ray
is not supported with the CPU backend."
)
"
Distributed execution
is not supported with the CPU backend."
)
from
vllm.executor.cpu_executor
import
CPUExecutorAsync
from
vllm.executor.cpu_executor
import
CPUExecutorAsync
executor_class
=
CPUExecutorAsync
executor_class
=
CPUExecutorAsync
elif
engine_config
.
parallel_config
.
worker_use_
ray
:
elif
distributed_executor_backend
==
"
ray
"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
from
vllm.executor.multiproc_gpu_executor
import
(
MultiprocessingGPUExecutorAsync
)
executor_class
=
MultiprocessingGPUExecutorAsync
else
:
else
:
assert
engine_config
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
# Create the async LLM engine.
# Create the async LLM engine.
engine
=
cls
(
engine
=
cls
(
engine_config
.
parallel_config
.
worker_use_
ray
,
distributed_executor_backend
==
"
ray
"
,
engine_args
.
engine_use_ray
,
engine_args
.
engine_use_ray
,
**
engine_config
.
to_dict
(),
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
executor_class
=
executor_class
,
...
@@ -510,27 +530,31 @@ class AsyncLLMEngine:
...
@@ -510,27 +530,31 @@ class AsyncLLMEngine:
async
def
add_request
(
async
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
AsyncStream
:
)
->
AsyncStream
:
if
self
.
log_requests
:
if
self
.
log_requests
:
shortened_prompt
=
prompt
if
isinstance
(
inputs
,
str
):
shortened_token_ids
=
prompt_token_ids
shortened_prompt
=
inputs
if
self
.
max_log_len
is
not
None
:
shortened_token_ids
=
None
else
:
shortened_prompt
=
inputs
.
get
(
"prompt"
)
shortened_token_ids
=
inputs
.
get
(
"prompt_token_ids"
)
max_log_len
=
self
.
max_log_len
if
max_log_len
is
not
None
:
if
shortened_prompt
is
not
None
:
if
shortened_prompt
is
not
None
:
shortened_prompt
=
shortened_prompt
[:
self
.
max_log_len
]
shortened_prompt
=
shortened_prompt
[:
max_log_len
]
if
shortened_token_ids
is
not
None
:
if
shortened_token_ids
is
not
None
:
shortened_token_ids
=
shortened_token_ids
[:
self
.
shortened_token_ids
=
shortened_token_ids
[:
max_log_len
]
max_log_len
]
logger
.
info
(
logger
.
info
(
"Received request %s: prompt: %r, "
"Received request %s: prompt: %r, "
"
sampling_
params: %s, prompt_token_ids: %s, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s."
,
request_id
,
shortened_prompt
,
"lora_request: %s."
,
request_id
,
shortened_prompt
,
params
,
sampling_params
,
shortened_token_ids
,
lora_request
)
shortened_token_ids
,
lora_request
)
if
not
self
.
is_running
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
if
self
.
start_engine_loop
:
...
@@ -546,39 +570,33 @@ class AsyncLLMEngine:
...
@@ -546,39 +570,33 @@ class AsyncLLMEngine:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
pro
mpt_token_ids
=
await
(
pro
cessed_inputs
=
await
self
.
engine
.
process_model_inputs_async
\
self
.
engine
.
encode_request_async
.
remote
(
# type: ignore
.
remote
(
# type: ignore
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
inputs
=
inputs
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
))
else
:
else
:
pro
mpt_token_id
s
=
await
self
.
engine
.
encode_request
_async
(
pro
cessed_input
s
=
await
self
.
engine
.
process_model_inputs
_async
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
inputs
=
inputs
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
stream
=
self
.
_request_tracker
.
add_request
(
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
request_id
,
prompt
=
prompt
,
inputs
=
processed_inputs
,
sampling_params
=
sampling_params
,
params
=
params
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
)
)
return
stream
return
stream
async
def
generate
(
async
def
generate
(
self
,
self
,
prompt
:
Optional
[
str
]
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
)
->
AsyncIterator
[
RequestOutput
]:
)
->
AsyncIterator
[
RequestOutput
]:
"""Generate outputs for a request.
"""Generate outputs for a request.
...
@@ -587,18 +605,16 @@ class AsyncLLMEngine:
...
@@ -587,18 +605,16 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
from the LLMEngine to the caller.
Args:
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
inputs: The inputs to the LLM. See
provided.
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
Yields:
The output `RequestOutput` objects from the LLMEngine
for the
The output `RequestOutput` objects from the LLMEngine
request.
for the
request.
Details:
Details:
- If the engine is not running, start the background loop,
- If the engine is not running, start the background loop,
...
@@ -643,25 +659,112 @@ class AsyncLLMEngine:
...
@@ -643,25 +659,112 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> # Process and return the final output
>>> ...
>>> ...
"""
"""
# Preprocess the request.
async
for
output
in
self
.
_process_request
(
arrival_time
=
time
.
time
()
try
:
stream
=
await
self
.
add_request
(
request_id
,
request_id
,
prompt
,
inputs
,
sampling_params
,
sampling_params
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
):
)
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
async
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AsyncIterator
[
EmbeddingRequestOutput
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async
for
output
in
self
.
_process_request
(
request_id
,
inputs
,
pooling_params
,
lora_request
=
lora_request
,
):
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
async
def
_process_request
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
*
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AsyncIterator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time
=
time
.
time
()
stream
=
await
self
.
add_request
(
request_id
,
inputs
,
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
)
try
:
async
for
request_output
in
stream
:
async
for
request_output
in
stream
:
yield
request_output
yield
request_output
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
# If there is an exception or coroutine is cancelled, abort the
# request.
self
.
_abort
(
request_id
)
self
.
_abort
(
request_id
)
raise
e
raise
e
...
...
vllm/engine/llm_engine.py
View file @
b9e12416
import
time
import
time
from
typing
import
Iterable
,
List
,
Optional
,
Type
,
Union
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Iterable
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Type
,
TypeVar
,
Union
from
transformers
import
GenerationConfig
,
PreTrainedTokenizer
from
transformers
import
GenerationConfig
,
PreTrainedTokenizer
...
@@ -18,12 +21,16 @@ from vllm.engine.output_processor.stop_checker import StopChecker
...
@@ -18,12 +21,16 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
PoolerOutput
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
...
@@ -47,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
...
@@ -47,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return
{}
return
{}
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
EmbeddingRequestOutput
)
class
LLMEngine
:
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
"""An LLM engine that receives requests and generates texts.
...
@@ -57,11 +67,11 @@ class LLMEngine:
...
@@ -57,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
serving throughput.
The
`
LLM` class wraps this class for offline batched inference
and the
The
:class:`~vllm.
LLM` class wraps this class for offline batched inference
`AsyncLLMEngine` class wraps this class for online serving.
and the :class:
`AsyncLLMEngine` class wraps this class for online serving.
NOTE:
The config arguments are derived from
the `EngineArgs` class. For th
e
The config arguments are derived from
:class:`~vllm.EngineArgs`. (Se
e
comprehensive list of arguments, see `E
ngine
A
rgs`
.
:ref:`e
ngine
_a
rgs`
)
Args:
Args:
model_config: The configuration related to the LLM model.
model_config: The configuration related to the LLM model.
...
@@ -78,9 +88,60 @@ class LLMEngine:
...
@@ -78,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed
executor_class: The model executor class for managing distributed
execution.
execution.
log_stats: Whether to log statistics.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
usage_context: Specified entry point, used for usage info collection
.
"""
"""
DO_VALIDATE_OUTPUT
:
ClassVar
[
bool
]
=
False
"""A flag to toggle whether to validate the type of request output."""
@
classmethod
@
contextmanager
def
enable_output_validation
(
cls
):
cls
.
DO_VALIDATE_OUTPUT
=
True
yield
cls
.
DO_VALIDATE_OUTPUT
=
False
@
classmethod
def
validate_output
(
cls
,
output
:
object
,
output_type
:
Type
[
_O
],
)
->
_O
:
do_validate
=
cls
.
DO_VALIDATE_OUTPUT
if
((
TYPE_CHECKING
or
do_validate
)
and
not
isinstance
(
output
,
output_type
)):
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
return
output
@
classmethod
def
validate_outputs
(
cls
,
outputs
:
GenericSequence
[
object
],
output_type
:
Type
[
_O
],
)
->
List
[
_O
]:
do_validate
=
cls
.
DO_VALIDATE_OUTPUT
outputs_
:
List
[
_O
]
if
TYPE_CHECKING
or
do_validate
:
outputs_
=
[]
for
output
in
outputs
:
if
not
isinstance
(
output
,
output_type
):
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
outputs_
.
append
(
output
)
else
:
outputs_
=
outputs
return
outputs_
tokenizer
:
Optional
[
BaseTokenizerGroup
]
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -101,10 +162,11 @@ class LLMEngine:
...
@@ -101,10 +162,11 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"rope_scaling=%r, tokenizer_revision=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
vllm
.
__version__
,
vllm
.
__version__
,
...
@@ -114,6 +176,7 @@ class LLMEngine:
...
@@ -114,6 +176,7 @@ class LLMEngine:
model_config
.
skip_tokenizer_init
,
model_config
.
skip_tokenizer_init
,
model_config
.
tokenizer_mode
,
model_config
.
tokenizer_mode
,
model_config
.
revision
,
model_config
.
revision
,
model_config
.
rope_scaling
,
model_config
.
tokenizer_revision
,
model_config
.
tokenizer_revision
,
model_config
.
trust_remote_code
,
model_config
.
trust_remote_code
,
model_config
.
dtype
,
model_config
.
dtype
,
...
@@ -146,12 +209,11 @@ class LLMEngine:
...
@@ -146,12 +209,11 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
:
BaseTokenizerGroup
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
else
:
else
:
self
.
detokenizer
=
None
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
detokenizer
=
None
self
.
seq_counter
=
Counter
()
self
.
seq_counter
=
Counter
()
self
.
generation_config_fields
=
_load_generation_config_dict
(
self
.
generation_config_fields
=
_load_generation_config_dict
(
...
@@ -169,7 +231,8 @@ class LLMEngine:
...
@@ -169,7 +231,8 @@ class LLMEngine:
load_config
=
load_config
,
load_config
=
load_config
,
)
)
self
.
_initialize_kv_caches
()
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
if
is_usage_stats_enabled
():
...
@@ -270,6 +333,8 @@ class LLMEngine:
...
@@ -270,6 +333,8 @@ class LLMEngine:
"""Creates an LLM engine from the engine arguments."""
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
engine_config
=
engine_args
.
create_engine_config
()
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
# Initialize the cluster and specify the executor class.
# Initialize the cluster and specify the executor class.
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
...
@@ -278,13 +343,15 @@ class LLMEngine:
...
@@ -278,13 +343,15 @@ class LLMEngine:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
executor_class
=
CPUExecutor
elif
engine_config
.
parallel_config
.
worker_use_
ray
:
elif
distributed_executor_backend
==
"
ray
"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
executor_class
=
RayGPUExecutor
executor_class
=
RayGPUExecutor
elif
distributed_executor_backend
==
"mp"
:
from
vllm.executor.multiproc_gpu_executor
import
(
MultiprocessingGPUExecutor
)
executor_class
=
MultiprocessingGPUExecutor
else
:
else
:
assert
engine_config
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
executor_class
=
GPUExecutor
executor_class
=
GPUExecutor
...
@@ -308,14 +375,26 @@ class LLMEngine:
...
@@ -308,14 +375,26 @@ class LLMEngine:
if
model_executor
:
=
getattr
(
self
,
"model_executor"
,
None
):
if
model_executor
:
=
getattr
(
self
,
"model_executor"
,
None
):
model_executor
.
shutdown
()
model_executor
.
shutdown
()
MISSING_TOKENIZER_GROUP_MSG
=
(
"Unable to get tokenizer because "
"skip_tokenizer_init is True"
)
def
get_tokenizer_group
(
self
,
fail_msg
:
str
=
MISSING_TOKENIZER_GROUP_MSG
)
->
BaseTokenizerGroup
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
fail_msg
)
return
self
.
tokenizer
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
None
)
return
self
.
get_
tokenizer
_group
()
.
get_lora_tokenizer
(
None
)
def
get_tokenizer_for_seq
(
self
,
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
sequence
.
lora_request
)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
):
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
init_kwargs
=
dict
(
init_kwargs
=
dict
(
tokenizer_id
=
self
.
model_config
.
tokenizer
,
tokenizer_id
=
self
.
model_config
.
tokenizer
,
enable_lora
=
bool
(
self
.
lora_config
),
enable_lora
=
bool
(
self
.
lora_config
),
...
@@ -325,8 +404,9 @@ class LLMEngine:
...
@@ -325,8 +404,9 @@ class LLMEngine:
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
self
.
tokenizer
=
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
return
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
@@ -336,29 +416,85 @@ class LLMEngine:
...
@@ -336,29 +416,85 @@ class LLMEngine:
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
self
.
scheduler_config
)
def
encode_request
(
def
_get_eos_token_id
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
Optional
[
int
]:
if
self
.
tokenizer
is
None
:
logger
.
warning
(
"Using None for EOS token id because tokenizer "
"is not initialized"
)
return
None
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
_add_processed_request
(
self
,
request_id
:
str
,
processed_inputs
:
LLMInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
)
->
None
:
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
self
.
_get_eos_token_id
(
lora_request
)
seq
=
Sequence
(
seq_id
,
processed_inputs
,
block_size
,
eos_token_id
,
lora_request
)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if
isinstance
(
params
,
SamplingParams
):
seq_group
=
self
.
_create_sequence_group_with_sampling
(
request_id
,
seq
,
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
)
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
seq
,
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
)
else
:
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
# Add the sequence group to the scheduler.
self
.
scheduler
.
add_seq_group
(
seq_group
)
def
process_model_inputs
(
self
,
self
,
request_id
:
str
,
# pylint: disable=unused-argument
request_id
:
str
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
):
)
->
LLMInputs
:
if
prompt_token_ids
is
None
:
if
isinstance
(
inputs
,
str
):
assert
prompt
is
not
None
inputs
=
{
"prompt"
:
inputs
}
prompt_token_ids
=
self
.
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
if
"prompt_token_ids"
not
in
inputs
:
lora_request
=
lora_request
)
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
return
prompt_token_ids
"skip_tokenizer_init is True"
)
prompt_token_ids
=
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
inputs
[
"prompt"
],
lora_request
=
lora_request
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
inputs
.
get
(
"prompt"
),
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
def
add_request
(
def
add_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -368,14 +504,14 @@ class LLMEngine:
...
@@ -368,14 +504,14 @@ class LLMEngine:
Args:
Args:
request_id: The unique ID of the request.
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
inputs: The inputs to the LLM. See
provided.
:class:`~vllm.inputs.PromptInputs`
sampling_params: The sampling parameters for text generation.
for more details about the format of each input.
prompt_token_ids: The token IDs of the prompt. If None, we
params: Parameters for sampling or pooling.
use the tokenizer to convert the prompts to token IDs.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
Details:
- Set arrival_time to the current time if it is None.
- Set arrival_time to the current time if it is None.
...
@@ -404,6 +540,30 @@ class LLMEngine:
...
@@ -404,6 +540,30 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
processed_inputs
=
self
.
process_model_inputs
(
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
)
self
.
_add_processed_request
(
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
params
=
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
)
def
_create_sequence_group_with_sampling
(
self
,
request_id
:
str
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
,
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
max_logprobs
=
self
.
get_model_config
().
max_logprobs
if
(
sampling_params
.
logprobs
if
(
sampling_params
.
logprobs
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
...
@@ -411,26 +571,6 @@ class LLMEngine:
...
@@ -411,26 +571,6 @@ class LLMEngine:
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
raise
ValueError
(
f
"Cannot request more than "
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
f
"
{
max_logprobs
}
logprobs."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
prompt_token_ids
=
self
.
encode_request
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
None
if
self
.
tokenizer
:
eos_token_id
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
else
:
logger
.
warning
(
"Use None for EOS token id because tokenizer is "
"not initialized"
)
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
eos_token_id
,
lora_request
)
# Defensive copy of SamplingParams, which are used by the sampler,
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
# this doesn't deep-copy LogitsProcessor objects
...
@@ -443,11 +583,32 @@ class LLMEngine:
...
@@ -443,11 +583,32 @@ class LLMEngine:
self
.
generation_config_fields
)
self
.
generation_config_fields
)
# Create the sequence group.
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
arrival_time
,
lora_request
,
multi_modal_data
)
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
sampling_params
=
sampling_params
,
lora_request
=
lora_request
)
# Add the sequence group to the scheduler.
return
seq_group
self
.
scheduler
.
add_seq_group
(
seq_group
)
def
_create_sequence_group_with_pooling
(
self
,
request_id
:
str
,
seq
:
Sequence
,
pooling_params
:
PoolingParams
,
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params
=
pooling_params
.
clone
()
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
)
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a request(s) with the given ID.
"""Aborts a request(s) with the given ID.
...
@@ -484,13 +645,25 @@ class LLMEngine:
...
@@ -484,13 +645,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests."""
"""Returns True if there are unfinished requests."""
return
self
.
scheduler
.
has_unfinished_seqs
()
return
self
.
scheduler
.
has_unfinished_seqs
()
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
)
->
None
:
seq_group
.
embeddings
=
outputs
[
0
].
embeddings
for
seq
in
seq_group
.
get_seqs
():
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
def
_process_model_outputs
(
def
_process_model_outputs
(
self
,
self
,
output
:
List
[
SamplerOutput
],
output
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
RequestOutput
]:
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
...
@@ -501,7 +674,7 @@ class LLMEngine:
...
@@ -501,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
# [step][sequence group].
output_by_sequence_group
=
create_output_by_sequence_group
(
output_by_sequence_group
=
create_output_by_sequence_group
(
sampler_outputs
=
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
# Update the scheduled sequence groups with the model outputs.
# Update the scheduled sequence groups with the model outputs.
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
...
@@ -510,6 +683,9 @@ class LLMEngine:
...
@@ -510,6 +683,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
continue
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
if
seq_group_meta
.
do_sample
:
if
seq_group_meta
.
do_sample
:
...
@@ -519,18 +695,19 @@ class LLMEngine:
...
@@ -519,18 +695,19 @@ class LLMEngine:
self
.
scheduler
.
free_finished_seq_groups
()
self
.
scheduler
.
free_finished_seq_groups
()
# Create the outputs.
# Create the outputs.
request_outputs
:
List
[
RequestOutput
]
=
[]
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
for
scheduled_seq_group
in
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
Factory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
for
seq_group
in
ignored_seq_groups
:
for
seq_group
in
ignored_seq_groups
:
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
Factory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
return
request_outputs
return
request_outputs
def
step
(
self
)
->
List
[
RequestOutput
]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
.. figure:: https://i.imgur.com/sv2HssD.png
...
@@ -570,7 +747,7 @@ class LLMEngine:
...
@@ -570,7 +747,7 @@ class LLMEngine:
>>> while True:
>>> while True:
>>> if example_inputs:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id),
prompt,
sampling_params)
>>> engine.add_request(str(req_id),prompt,sampling_params)
>>>
>>>
>>> # continue the request processing
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> request_outputs = engine.step()
...
@@ -604,6 +781,14 @@ class LLMEngine:
...
@@ -604,6 +781,14 @@ class LLMEngine:
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
if
not
request_outputs
:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
request_outputs
return
request_outputs
def
do_log_stats
(
def
do_log_stats
(
...
@@ -637,12 +822,15 @@ class LLMEngine:
...
@@ -637,12 +822,15 @@ class LLMEngine:
# KV Cache Usage in %
# KV Cache Usage in %
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
()
gpu_cache_usage_sys
=
0.
gpu_cache_usage_sys
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
if
num_total_gpu
is
not
None
:
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
(
)
gpu_cache_usage_sys
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
cpu_cache_usage_sys
=
0.
cpu_cache_usage_sys
=
0.
if
num_total_cpu
>
0
:
if
num_total_cpu
is
not
None
and
num_total_cpu
>
0
:
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
)
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
...
@@ -652,6 +840,8 @@ class LLMEngine:
...
@@ -652,6 +840,8 @@ class LLMEngine:
num_generation_tokens_iter
=
0
num_generation_tokens_iter
=
0
time_to_first_tokens_iter
:
List
[
float
]
=
[]
time_to_first_tokens_iter
:
List
[
float
]
=
[]
time_per_output_tokens_iter
:
List
[
float
]
=
[]
time_per_output_tokens_iter
:
List
[
float
]
=
[]
num_preemption_iter
=
(
0
if
scheduler_outputs
is
None
else
scheduler_outputs
.
preempted
)
# Request stats
# Request stats
# Latency
# Latency
...
@@ -716,8 +906,10 @@ class LLMEngine:
...
@@ -716,8 +906,10 @@ class LLMEngine:
seq
.
get_output_len
()
seq
.
get_output_len
()
for
seq
in
seq_group
.
get_finished_seqs
()
for
seq
in
seq_group
.
get_finished_seqs
()
])
])
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
if
seq_group
.
sampling_params
is
not
None
:
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
finished_reason_requests
.
extend
([
finished_reason_requests
.
extend
([
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
for
seq
in
seq_group
.
get_finished_seqs
()
for
seq
in
seq_group
.
get_finished_seqs
()
...
@@ -743,7 +935,6 @@ class LLMEngine:
...
@@ -743,7 +935,6 @@ class LLMEngine:
return
Stats
(
return
Stats
(
now
=
now
,
now
=
now
,
# System stats
# System stats
# Scheduler State
# Scheduler State
num_running_sys
=
num_running_sys
,
num_running_sys
=
num_running_sys
,
...
@@ -759,6 +950,7 @@ class LLMEngine:
...
@@ -759,6 +950,7 @@ class LLMEngine:
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
spec_decode_metrics
=
spec_decode_metrics
,
spec_decode_metrics
=
spec_decode_metrics
,
num_preemption_iter
=
num_preemption_iter
,
# Request stats
# Request stats
# Latency
# Latency
...
...
vllm/engine/metrics.py
View file @
b9e12416
...
@@ -61,6 +61,10 @@ class Metrics:
...
@@ -61,6 +61,10 @@ class Metrics:
labelnames
=
labelnames
)
labelnames
=
labelnames
)
# Iteration stats
# Iteration stats
self
.
counter_num_preemption
=
Counter
(
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
)
self
.
counter_prompt_tokens
=
Counter
(
self
.
counter_prompt_tokens
=
Counter
(
name
=
"vllm:prompt_tokens_total"
,
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
...
@@ -181,6 +185,7 @@ class Stats:
...
@@ -181,6 +185,7 @@ class Stats:
num_generation_tokens_iter
:
int
num_generation_tokens_iter
:
int
time_to_first_tokens_iter
:
List
[
float
]
time_to_first_tokens_iter
:
List
[
float
]
time_per_output_tokens_iter
:
List
[
float
]
time_per_output_tokens_iter
:
List
[
float
]
num_preemption_iter
:
int
# Request stats (should have _requests suffix)
# Request stats (should have _requests suffix)
# Latency
# Latency
...
@@ -244,6 +249,8 @@ class StatLogger:
...
@@ -244,6 +249,8 @@ class StatLogger:
stats
.
cpu_cache_usage_sys
)
stats
.
cpu_cache_usage_sys
)
# Iteration level data
# Iteration level data
self
.
_log_counter
(
self
.
metrics
.
counter_num_preemption
,
stats
.
num_preemption_iter
)
self
.
_log_counter
(
self
.
metrics
.
counter_prompt_tokens
,
self
.
_log_counter
(
self
.
metrics
.
counter_prompt_tokens
,
stats
.
num_prompt_tokens_iter
)
stats
.
num_prompt_tokens_iter
)
self
.
_log_counter
(
self
.
metrics
.
counter_generation_tokens
,
self
.
_log_counter
(
self
.
metrics
.
counter_generation_tokens
,
...
@@ -336,7 +343,7 @@ class StatLogger:
...
@@ -336,7 +343,7 @@ class StatLogger:
"Avg generation throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%"
,
"CPU KV cache usage: %.1f%%
.
"
,
prompt_throughput
,
prompt_throughput
,
generation_throughput
,
generation_throughput
,
stats
.
num_running_sys
,
stats
.
num_running_sys
,
...
...
vllm/engine/output_processor/multi_step.py
View file @
b9e12416
...
@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
seq
,
new_char_count
=
new_char_count
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
,
)
if
seq
.
is_finished
():
if
seq
.
is_finished
():
break
break
...
...
vllm/engine/output_processor/single_step.py
View file @
b9e12416
...
@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq
,
seq_group
.
sampling_params
)
seq
,
seq_group
.
sampling_params
)
else
:
else
:
new_char_count
=
0
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
,
self
.
stop_checker
.
maybe_stop_sequence
(
seq_group
.
sampling_params
)
seq
,
new_char_count
,
seq_group
.
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
)
# Non-beam search case
# Non-beam search case
if
not
seq_group
.
sampling_params
.
use_beam_search
:
if
not
seq_group
.
sampling_params
.
use_beam_search
:
...
...
vllm/engine/output_processor/stop_checker.py
View file @
b9e12416
...
@@ -2,6 +2,7 @@ from typing import Callable, Optional
...
@@ -2,6 +2,7 @@ from typing import Callable, Optional
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceStatus
...
@@ -16,11 +17,23 @@ class StopChecker:
...
@@ -16,11 +17,23 @@ class StopChecker:
def
__init__
(
self
,
max_model_len
:
int
,
def
__init__
(
self
,
max_model_len
:
int
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
PreTrainedTokenizer
]):
PreTrainedTokenizer
]):
self
.
max_model_len
=
max_model_len
# Do not use it directly, but use `self._get_max_model_len`.
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
def
maybe_stop_sequence
(
self
,
seq
:
Sequence
,
new_char_count
:
int
,
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
sampling_params
:
SamplingParams
)
->
None
:
if
lora_req
and
lora_req
.
long_lora_max_len
:
return
lora_req
.
long_lora_max_len
else
:
return
self
.
_max_model_len
def
maybe_stop_sequence
(
self
,
seq
:
Sequence
,
new_char_count
:
int
,
sampling_params
:
SamplingParams
,
lora_req
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
"""Stop the finished sequences.
"""Stop the finished sequences.
new_char_count is the number of chars added to the
new_char_count is the number of chars added to the
...
@@ -35,6 +48,11 @@ class StopChecker:
...
@@ -35,6 +48,11 @@ class StopChecker:
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
return
...
@@ -59,7 +77,7 @@ class StopChecker:
...
@@ -59,7 +77,7 @@ class StopChecker:
return
return
# Check if the sequence has reached max_model_len.
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
max_model_len
:
if
seq
.
get_len
()
>
self
.
_get_
max_model_len
(
lora_req
)
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
...
...
vllm/engine/output_processor/util.py
View file @
b9e12416
from
typing
import
List
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupOutput
from
vllm.sequence
import
PoolerOutput
,
SamplerOutput
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
def
create_output_by_sequence_group
(
sampler_
outputs
:
List
[
SamplerOutput
],
outputs
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
[step][sequence group] into [sequence group][step].
"""
"""
output_by_sequence_group
:
List
[
List
[
S
ampler
Output
]]
=
[
output_by_sequence_group
:
List
[
List
[
S
equenceGroup
Output
]]
=
[
[]
for
_
in
range
(
num_seq_groups
)
[]
for
_
in
range
(
num_seq_groups
)
]
]
for
step
in
sampler_
outputs
:
for
step
in
outputs
:
for
i
,
sequence_group_output
in
enumerate
(
step
):
for
i
,
sequence_group_output
in
enumerate
(
step
):
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
...
...
vllm/entrypoints/llm.py
View file @
b9e12416
from
typing
import
List
,
Optional
,
Union
from
contextlib
import
contextmanager
from
typing
import
ClassVar
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
import
torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.inputs
import
(
PromptInputs
,
PromptStrictInputs
,
TextPrompt
,
TextTokensPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
,
deprecate_kwargs
logger
=
init_logger
(
__name__
)
class
LLM
:
class
LLM
:
...
@@ -23,10 +30,6 @@ class LLM:
...
@@ -23,10 +30,6 @@ class LLM:
this class generates texts from the model, using an intelligent batching
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
Args:
model: The name or path of a HuggingFace Transformers model.
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
...
@@ -75,8 +78,26 @@ class LLM:
...
@@ -75,8 +78,26 @@ class LLM:
When a sequence has context length larger than this, we fall back
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode.
disable_custom_all_reduce: See ParallelConfig
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
"""
DEPRECATE_LEGACY
:
ClassVar
[
bool
]
=
False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@
classmethod
@
contextmanager
def
deprecate_legacy_api
(
cls
):
cls
.
DEPRECATE_LEGACY
=
True
yield
cls
.
DEPRECATE_LEGACY
=
False
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
...
@@ -134,126 +155,415 @@ class LLM:
...
@@ -134,126 +155,415 @@ class LLM:
)
->
None
:
)
->
None
:
self
.
llm_engine
.
tokenizer
.
tokenizer
=
tokenizer
self
.
llm_engine
.
tokenizer
.
tokenizer
=
tokenizer
@
overload
# LEGACY: single (prompt + optional token ids)
def
generate
(
def
generate
(
self
,
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
prompts
:
str
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
overload
# LEGACY: multi (prompt + optional token ids)
def
generate
(
self
,
prompts
:
List
[
str
],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
overload
# LEGACY: single (token ids + optional prompt)
def
generate
(
self
,
prompts
:
Optional
[
str
]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
*
,
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
overload
# LEGACY: multi (token ids + optional prompt)
def
generate
(
self
,
prompts
:
Optional
[
List
[
str
]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
*
,
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
overload
# LEGACY: single or multi token ids [pos-only]
def
generate
(
self
,
prompts
:
None
,
sampling_params
:
None
,
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
overload
def
generate
(
self
,
inputs
:
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
List
[
RequestOutput
]:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
"multi_modal_data"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the 'inputs' parameter "
"instead."
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
NOTE:
This class automatically batches the given prompts, considering
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
into a single list and pass it to this method.
Args:
Args:
promp
ts: A list of
promp
ts to generate completions for.
inpu
ts: A list of
inpu
ts to generate completions for.
sampling_params: The sampling parameters for text generation. If
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
Returns:
A list of `RequestOutput` objects containing the generated
A list of `RequestOutput` objects containing the
completions in the same order as the input prompts.
generated completions in the same order as the input prompts.
"""
if
prompts
is
None
and
prompt_token_ids
is
None
:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
if
self
.
llm_engine
.
model_config
.
skip_tokenizer_init
\
and
prompts
is
not
None
:
raise
ValueError
(
"prompts must be None if skip_tokenizer_init "
"is True"
)
if
isinstance
(
prompts
,
str
):
# Convert a single prompt to a list.
prompts
=
[
prompts
]
if
(
prompts
is
not
None
and
prompt_token_ids
is
not
None
and
len
(
prompts
)
!=
len
(
prompt_token_ids
)):
raise
ValueError
(
"The lengths of prompts and prompt_token_ids "
"must be the same."
)
if
prompts
is
not
None
:
Note:
num_requests
=
len
(
prompts
)
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if
prompt_token_ids
is
not
None
or
multi_modal_data
is
not
None
:
inputs
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
)
else
:
else
:
assert
prompt_token_ids
is
not
None
inputs
=
cast
(
num_requests
=
len
(
prompt_token_ids
)
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
prompts
)
if
sampling_params
is
None
:
if
sampling_params
is
None
:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
self
.
_validate_and_add_requests
(
list
)
and
len
(
sampling_params
)
!=
num_requests
:
inputs
=
inputs
,
raise
ValueError
(
"The lengths of prompts and sampling_params "
params
=
sampling_params
,
lora_request
=
lora_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
@
overload
# LEGACY: single (prompt + optional token ids)
def
encode
(
self
,
prompts
:
str
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
overload
# LEGACY: multi (prompt + optional token ids)
def
encode
(
self
,
prompts
:
List
[
str
],
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
overload
# LEGACY: single (token ids + optional prompt)
def
encode
(
self
,
prompts
:
Optional
[
str
]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
*
,
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
overload
# LEGACY: multi (token ids + optional prompt)
def
encode
(
self
,
prompts
:
Optional
[
List
[
str
]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
*
,
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
overload
# LEGACY: single or multi token ids [pos-only]
def
encode
(
self
,
prompts
:
None
,
pooling_params
:
None
,
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
overload
def
encode
(
self
,
inputs
:
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
"multi_modal_data"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the 'inputs' parameter "
"instead."
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if
prompt_token_ids
is
not
None
or
multi_modal_data
is
not
None
:
inputs
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
)
else
:
inputs
=
cast
(
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
prompts
)
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
inputs
=
inputs
,
params
=
pooling_params
,
lora_request
=
lora_request
,
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
# LEGACY
def
_convert_v1_inputs
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]],
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]],
multi_modal_data
:
Optional
[
MultiModalData
],
):
# skip_tokenizer_init is now checked in engine
if
prompts
is
not
None
:
prompts
=
[
p
[
"content"
]
for
p
in
parse_and_batch_prompt
(
prompts
)]
if
prompt_token_ids
is
not
None
:
prompt_token_ids
=
[
p
[
"content"
]
for
p
in
parse_and_batch_prompt
(
prompt_token_ids
)
]
num_requests
=
None
if
prompts
is
not
None
:
num_requests
=
len
(
prompts
)
if
prompt_token_ids
is
not
None
:
if
(
num_requests
is
not
None
and
num_requests
!=
len
(
prompt_token_ids
)):
raise
ValueError
(
"The lengths of prompts and prompt_token_ids "
"must be the same."
)
num_requests
=
len
(
prompt_token_ids
)
if
num_requests
is
None
:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
inputs
:
List
[
PromptInputs
]
=
[]
for
i
in
range
(
num_requests
):
if
prompts
is
not
None
:
if
prompt_token_ids
is
not
None
:
item
=
TextTokensPrompt
(
prompt
=
prompts
[
i
],
prompt_token_ids
=
prompt_token_ids
[
i
])
else
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
else
:
if
prompt_token_ids
is
not
None
:
item
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
[
i
])
else
:
raise
AssertionError
if
multi_modal_data
is
not
None
:
item
[
"multi_modal_data"
]
=
multi_modal_data
inputs
.
append
(
item
)
return
inputs
def
_validate_and_add_requests
(
self
,
inputs
:
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
LoRARequest
],
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
inputs
=
[
inputs
]
num_requests
=
len
(
inputs
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
"must be the same."
)
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
# Add requests to the engine.
for
i
in
range
(
num_requests
):
for
i
,
request_inputs
in
enumerate
(
inputs
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
self
.
_add_request
(
self
.
_add_request
(
prompt
,
request_inputs
,
sampling_params
[
i
]
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
if
isinstance
(
sampling_params
,
list
)
else
sampling_params
,
token_ids
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
multi_modal_data
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
))
if
multi_modal_data
else
None
,
)
)
return
self
.
_run_engine
(
use_tqdm
)
def
_add_request
(
def
_add_request
(
self
,
self
,
prompt
:
Optional
[
str
],
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
prompt_token_ids
:
Optional
[
List
[
int
]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
self
.
llm_engine
.
add_request
(
request_id
,
prompt
,
inputs
,
sampling_params
,
params
,
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
# Initialize tqdm.
# Initialize tqdm.
if
use_tqdm
:
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
pbar
=
tqdm
(
total
=
num_requests
,
pbar
=
tqdm
(
desc
=
"Processed prompts"
,
total
=
num_requests
,
dynamic_ncols
=
True
)
desc
=
"Processed prompts"
,
dynamic_ncols
=
True
,
postfix
=
f
"Generation Speed:
{
0
:.
2
f
}
toks/s"
,
)
# Run the engine.
# Run the engine.
outputs
:
List
[
RequestOutput
]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
step_outputs
=
self
.
llm_engine
.
step
()
for
output
in
step_outputs
:
for
output
in
step_outputs
:
if
output
.
finished
:
if
output
.
finished
:
outputs
.
append
(
output
)
outputs
.
append
(
output
)
if
use_tqdm
:
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
total_toks
+=
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
)
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
pbar
.
update
(
1
)
pbar
.
update
(
1
)
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# This is necessary because some requests may be finished earlier than
# its previous requests.
# its previous requests.
outputs
=
sorted
(
outputs
,
key
=
lambda
x
:
int
(
x
.
request_id
))
return
sorted
(
outputs
,
key
=
lambda
x
:
int
(
x
.
request_id
))
return
outputs
\ No newline at end of file
vllm/entrypoints/openai/api_server.py
View file @
b9e12416
...
@@ -4,7 +4,7 @@ import inspect
...
@@ -4,7 +4,7 @@ import inspect
import
re
import
re
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Any
,
Set
from
typing
import
Optional
,
Set
import
fastapi
import
fastapi
import
uvicorn
import
uvicorn
...
@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponse
,
CompletionRequest
,
ErrorResponse
)
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -32,9 +34,11 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
...
@@ -32,9 +34,11 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat
:
OpenAIServingChat
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_running_tasks
:
Set
[
asyncio
.
Task
[
Any
]
]
=
set
()
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
@
asynccontextmanager
@
asynccontextmanager
...
@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
JSONResponse
(
content
=
generator
.
model_dump
())
@
app
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
return
JSONResponse
(
content
=
generator
.
model_dump
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
args
=
parse_args
()
...
@@ -139,6 +154,8 @@ if __name__ == "__main__":
...
@@ -139,6 +154,8 @@ if __name__ == "__main__":
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
authentication
(
request
:
Request
,
call_next
):
async
def
authentication
(
request
:
Request
,
call_next
):
root_path
=
""
if
args
.
root_path
is
None
else
args
.
root_path
root_path
=
""
if
args
.
root_path
is
None
else
args
.
root_path
if
request
.
method
==
"OPTIONS"
:
return
await
call_next
(
request
)
if
not
request
.
url
.
path
.
startswith
(
f
"
{
root_path
}
/v1"
):
if
not
request
.
url
.
path
.
startswith
(
f
"
{
root_path
}
/v1"
):
return
await
call_next
(
request
)
return
await
call_next
(
request
)
if
request
.
headers
.
get
(
"Authorization"
)
!=
"Bearer "
+
token
:
if
request
.
headers
.
get
(
"Authorization"
)
!=
"Bearer "
+
token
:
...
@@ -164,16 +181,34 @@ if __name__ == "__main__":
...
@@ -164,16 +181,34 @@ if __name__ == "__main__":
served_model_names
=
args
.
served_model_name
served_model_names
=
args
.
served_model_name
else
:
else
:
served_model_names
=
[
args
.
model
]
served_model_names
=
[
args
.
model
]
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
openai_serving_chat
=
OpenAIServingChat
(
engine
,
served_model_names
,
event_loop
:
Optional
[
asyncio
.
AbstractEventLoop
]
try
:
event_loop
=
asyncio
.
get_running_loop
()
except
RuntimeError
:
event_loop
=
None
if
event_loop
is
not
None
and
event_loop
.
is_running
():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config
=
event_loop
.
run_until_complete
(
engine
.
get_model_config
())
else
:
# When using single vLLM without engine_use_ray
model_config
=
asyncio
.
run
(
engine
.
get_model_config
())
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
args
.
response_role
,
args
.
response_role
,
args
.
lora_modules
,
args
.
lora_modules
,
args
.
chat_template
)
args
.
chat_template
)
openai_serving_completion
=
OpenAIServingCompletion
(
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
served_model_names
,
args
.
lora_modules
)
engine
,
model_config
,
served_model_names
,
args
.
lora_modules
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
uvicorn
.
run
(
app
,
uvicorn
.
run
(
app
,
host
=
args
.
host
,
host
=
args
.
host
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
b9e12416
# Adapted from
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
import
time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
openai.types.chat
import
torch
import
torch
from
openai.types.chat
import
ChatCompletionMessageParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
# pydantic needs the TypedDict from typing_extensions
from
typing_extensions
import
Annotated
,
Required
,
TypedDict
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
class
CustomChatCompletionContentPartParam
(
TypedDict
,
total
=
False
):
__pydantic_config__
=
ConfigDict
(
extra
=
"allow"
)
# type: ignore
type
:
Required
[
str
]
"""The type of the content part."""
ChatCompletionContentPartParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionContentPartParam
,
CustomChatCompletionContentPartParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
List
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionMessageParam
,
CustomChatCompletionMessageParam
]
class
OpenAIBaseModel
(
BaseModel
):
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
# OpenAI API does not allow extra fields
model_config
=
ConfigDict
(
extra
=
"forbid"
)
model_config
=
ConfigDict
(
extra
=
"forbid"
)
...
@@ -74,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -74,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
frequency_penalty
:
Optional
[
float
]
=
0.0
frequency_penalty
:
Optional
[
float
]
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
bool
]
=
False
logprobs
:
Optional
[
bool
]
=
False
top_logprobs
:
Optional
[
int
]
=
None
top_logprobs
:
Optional
[
int
]
=
0
max_tokens
:
Optional
[
int
]
=
None
max_tokens
:
Optional
[
int
]
=
None
n
:
Optional
[
int
]
=
1
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
...
@@ -157,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -157,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
)
->
SamplingParams
:
def
to_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
logprobs
and
not
self
.
top_logprobs
:
# We now allow logprobs being true without top_logrobs.
raise
ValueError
(
"Top logprobs must be set when logprobs is."
)
logits_processors
=
None
logits_processors
=
None
if
self
.
logit_bias
:
if
self
.
logit_bias
:
...
@@ -216,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -216,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice')."
)
"('guided_json', 'guided_regex' or 'guided_choice')."
)
return
data
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_logprobs
(
cls
,
data
):
if
"top_logprobs"
in
data
and
data
[
"top_logprobs"
]
is
not
None
:
if
"logprobs"
not
in
data
or
data
[
"logprobs"
]
is
False
:
raise
ValueError
(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif
not
0
<=
data
[
"top_logprobs"
]
<=
20
:
raise
ValueError
(
"`top_logprobs` must be a value in the interval [0, 20]."
)
return
data
class
CompletionRequest
(
OpenAIBaseModel
):
class
CompletionRequest
(
OpenAIBaseModel
):
# Ordered by official OpenAI API documentation
# Ordered by official OpenAI API documentation
...
@@ -362,8 +409,35 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -362,8 +409,35 @@ class CompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice')."
)
"('guided_json', 'guided_regex' or 'guided_choice')."
)
return
data
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_logprobs
(
cls
,
data
):
if
"logprobs"
in
data
and
data
[
"logprobs"
]
is
not
None
and
not
0
<=
data
[
"logprobs"
]
<=
5
:
raise
ValueError
((
"if passed, `logprobs` must be a value"
,
" in the interval [0, 5]."
))
return
data
class
LogProbs
(
OpenAIBaseModel
):
class
EmbeddingRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model
:
str
input
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
encoding_format
:
Optional
[
str
]
=
Field
(
'float'
,
pattern
=
'^(float|base64)$'
)
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
# doc: begin-embedding-pooling-params
additional_data
:
Optional
[
Any
]
=
None
# doc: end-embedding-pooling-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
CompletionLogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
...
@@ -373,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
...
@@ -373,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class
CompletionResponseChoice
(
OpenAIBaseModel
):
class
CompletionResponseChoice
(
OpenAIBaseModel
):
index
:
int
index
:
int
text
:
str
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
Completion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
default
=
None
,
default
=
None
,
...
@@ -396,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
...
@@ -396,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class
CompletionResponseStreamChoice
(
OpenAIBaseModel
):
class
CompletionResponseStreamChoice
(
OpenAIBaseModel
):
index
:
int
index
:
int
text
:
str
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
Completion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
Field
(
default
=
None
,
default
=
None
,
...
@@ -416,16 +490,45 @@ class CompletionStreamResponse(OpenAIBaseModel):
...
@@ -416,16 +490,45 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
EmbeddingResponseData
(
BaseModel
):
index
:
int
object
:
str
=
"embedding"
embedding
:
List
[
float
]
class
EmbeddingResponse
(
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"list"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
data
:
List
[
EmbeddingResponseData
]
usage
:
UsageInfo
class
ChatMessage
(
OpenAIBaseModel
):
class
ChatMessage
(
OpenAIBaseModel
):
role
:
str
role
:
str
content
:
str
content
:
str
class
ChatCompletionLogProb
(
OpenAIBaseModel
):
token
:
str
logprob
:
float
=
-
9999.0
bytes
:
Optional
[
List
[
int
]]
=
None
class
ChatCompletionLogProbsContent
(
ChatCompletionLogProb
):
top_logprobs
:
List
[
ChatCompletionLogProb
]
=
Field
(
default_factory
=
list
)
class
ChatCompletionLogProbs
(
OpenAIBaseModel
):
content
:
Optional
[
List
[
ChatCompletionLogProbsContent
]]
=
None
class
ChatCompletionResponseChoice
(
OpenAIBaseModel
):
class
ChatCompletionResponseChoice
(
OpenAIBaseModel
):
index
:
int
index
:
int
message
:
ChatMessage
message
:
ChatMessage
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
ChatCompletion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"tool_calls"
]
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
...
@@ -446,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
...
@@ -446,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class
ChatCompletionResponseStreamChoice
(
OpenAIBaseModel
):
class
ChatCompletionResponseStreamChoice
(
OpenAIBaseModel
):
index
:
int
index
:
int
delta
:
DeltaMessage
delta
:
DeltaMessage
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
ChatCompletion
LogProbs
]
=
None
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"tool_calls"
]
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
...
@@ -458,3 +561,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
...
@@ -458,3 +561,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model
:
str
model
:
str
choices
:
List
[
ChatCompletionResponseStreamChoice
]
choices
:
List
[
ChatCompletionResponseStreamChoice
]
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
BatchRequestInput
(
OpenAIBaseModel
):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id
:
str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method
:
str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url
:
str
# The parameteters of the request.
body
:
Union
[
ChatCompletionRequest
,
]
class
BatchRequestOutput
(
OpenAIBaseModel
):
"""
The per-line object of the batch output and error files
"""
id
:
str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id
:
str
response
:
Optional
[
ChatCompletionResponse
]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error
:
Optional
[
Any
]
vllm/entrypoints/openai/run_batch.py
0 → 100644
View file @
b9e12416
import
argparse
import
asyncio
import
sys
from
io
import
StringIO
import
aiohttp
import
vllm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
BatchRequestOutput
,
ChatCompletionResponse
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"vLLM OpenAI-Compatible batch runner."
)
parser
.
add_argument
(
"-i"
,
"--input-file"
,
required
=
True
,
type
=
str
,
help
=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET."
)
parser
.
add_argument
(
"-o"
,
"--output-file"
,
required
=
True
,
type
=
str
,
help
=
"The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT."
)
parser
.
add_argument
(
"--response-role"
,
type
=
nullable_str
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=true`."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
return
parser
.
parse_args
()
async
def
read_file
(
path_or_url
:
str
)
->
str
:
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
async
with
aiohttp
.
ClientSession
()
as
session
,
\
session
.
get
(
path_or_url
)
as
resp
:
return
await
resp
.
text
()
else
:
with
open
(
path_or_url
,
"r"
)
as
f
:
return
f
.
read
()
async
def
write_file
(
path_or_url
:
str
,
data
:
str
)
->
None
:
if
path_or_url
.
startswith
(
"http://"
)
or
path_or_url
.
startswith
(
"https://"
):
async
with
aiohttp
.
ClientSession
()
as
session
,
\
session
.
put
(
path_or_url
,
data
=
data
.
encode
(
"utf-8"
)):
pass
else
:
# We should make this async, but as long as this is always run as a
# standalone program, blocking the event loop won't effect performance
# in this particular case.
with
open
(
path_or_url
,
"w"
)
as
f
:
f
.
write
(
data
)
async
def
run_request
(
chat_serving
:
OpenAIServingChat
,
request
:
BatchRequestInput
)
->
BatchRequestOutput
:
chat_request
=
request
.
body
chat_response
=
await
chat_serving
.
create_chat_completion
(
chat_request
)
if
isinstance
(
chat_response
,
ChatCompletionResponse
):
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
response
=
chat_response
,
error
=
None
,
)
else
:
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
response
=
None
,
error
=
chat_response
,
)
return
batch_output
async
def
main
(
args
):
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
else
:
served_model_names
=
[
args
.
model
]
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
)
# When using single vLLM without engine_use_ray
model_config
=
await
engine
.
get_model_config
()
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
args
.
response_role
,
)
# Submit all requests in the file to the engine "concurrently".
response_futures
=
[]
for
request_json
in
(
await
read_file
(
args
.
input_file
)).
strip
().
split
(
"
\n
"
):
request
=
BatchRequestInput
.
model_validate_json
(
request_json
)
response_futures
.
append
(
run_request
(
openai_serving_chat
,
request
))
responses
=
await
asyncio
.
gather
(
*
response_futures
)
output_buffer
=
StringIO
()
for
response
in
responses
:
print
(
response
.
model_dump_json
(),
file
=
output_buffer
)
output_buffer
.
seek
(
0
)
await
write_file
(
args
.
output_file
,
output_buffer
.
read
().
strip
())
# Temporary workaround for https://github.com/vllm-project/vllm/issues/4789
sys
.
exit
(
0
)
if
__name__
==
"__main__"
:
args
=
parse_args
()
logger
.
info
(
"vLLM API server version %s"
,
vllm
.
__version__
)
logger
.
info
(
"args: %s"
,
args
)
asyncio
.
run
(
main
(
args
))
vllm/entrypoints/openai/serving_chat.py
View file @
b9e12416
import
asyncio
import
codecs
import
codecs
import
time
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Iterable
,
List
,
from
dataclasses
import
dataclass
Optional
,
Tuple
,
TypedDict
,
Union
,
final
)
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Dict
,
Iterable
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypedDict
,
Union
,
cast
,
final
from
fastapi
import
Request
from
fastapi
import
Request
from
openai.types.chat
import
(
ChatCompletionContentPartParam
,
from
openai.types.chat
import
ChatCompletionContentPartTextParam
ChatCompletionRole
)
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionContentPartParam
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionMessageParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
UsageInfo
)
UsageInfo
)
...
@@ -20,6 +24,7 @@ from vllm.logger import init_logger
...
@@ -20,6 +24,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -31,46 +36,96 @@ class ConversationMessage(TypedDict):
...
@@ -31,46 +36,96 @@ class ConversationMessage(TypedDict):
content
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
class
OpenAIServingChat
(
OpenAIServing
):
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
response_role
:
str
,
response_role
:
str
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
=
None
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
=
None
,
chat_template
:
Optional
[
str
]
=
None
):
chat_template
:
Optional
[
str
]
=
None
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
)
await_post_init
=
self
.
_load_chat_template
(
chat_template
=
chat_template
))
self
.
response_role
=
response_role
self
.
response_role
=
response_role
self
.
_load_chat_template
(
chat_template
)
def
_parse_chat_message_content
(
def
_load_chat_template
(
self
,
chat_template
:
Optional
[
str
]):
self
,
tokenizer
=
self
.
tokenizer
role
:
ChatCompletionRole
,
content
:
Optional
[
Union
[
str
,
if
chat_template
is
not
None
:
Iterable
[
ChatCompletionContentPartParam
]]],
try
:
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
object
]]]:
with
open
(
chat_template
,
"r"
)
as
f
:
if
content
is
None
:
tokenizer
.
chat_template
=
f
.
read
()
return
[],
[]
except
OSError
as
e
:
if
isinstance
(
content
,
str
):
JINJA_CHARS
=
"{}
\n
"
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)],
[]
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer
.
chat_template
=
codecs
.
decode
(
chat_template
,
"unicode_escape"
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
tokenizer
.
chat_template
)
elif
tokenizer
.
chat_template
is
not
None
:
logger
.
info
(
"Using default chat template:
\n
%s"
,
tokenizer
.
chat_template
)
else
:
logger
.
warning
(
"No chat template provided. Chat API will not work."
)
def
_parse_chat_message_content_parts
(
self
,
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
texts
:
List
[
str
]
=
[]
for
_
,
part
in
enumerate
(
content
):
if
part
[
"type"
]
==
"text"
:
for
_
,
part
in
enumerate
(
parts
):
text
=
part
[
"text"
]
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
texts
.
append
(
text
)
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part
[
'type'
]
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
"
\n
"
.
join
(
texts
))]
return
[
ConversationMessage
(
role
=
role
,
content
=
"
\n
"
.
join
(
texts
))],
[]
return
ChatMessageParseResult
(
messages
=
messages
)
def
_parse_chat_message_content
(
self
,
message
:
ChatCompletionMessageParam
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
)
return
self
.
_parse_chat_message_content_parts
(
role
,
content
)
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Optional
[
Request
]
=
None
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
],
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
],
ChatCompletionResponse
]:
ChatCompletionResponse
]:
"""Completion API similar to OpenAI's API.
"""Completion API similar to OpenAI's API.
...
@@ -89,11 +144,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -89,11 +144,10 @@ class OpenAIServingChat(OpenAIServing):
try
:
try
:
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
:
List
[
ConversationMessage
]
=
[]
for
m
in
request
.
messages
:
for
msg
in
request
.
messages
:
messages
,
_
=
self
.
_parse_chat_message_content
(
parsed_msg
=
self
.
_parse_chat_message_content
(
msg
)
m
[
"role"
],
m
[
"content"
])
conversation
.
extend
(
messages
)
conversation
.
extend
(
parsed_msg
.
messages
)
prompt
=
self
.
tokenizer
.
apply_chat_template
(
prompt
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
conversation
=
conversation
,
...
@@ -108,7 +162,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -108,7 +162,7 @@ class OpenAIServingChat(OpenAIServing):
try
:
try
:
# Tokenize/detokenize depending on prompt format (string/token list)
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids
,
prompt_text
=
self
.
_validate_prompt_and_tokenize
(
prompt_ids
,
prompt_text
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
request
,
prompt
=
prompt
,
add_special_tokens
=
False
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
...
@@ -126,9 +180,15 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -126,9 +180,15 @@ class OpenAIServingChat(OpenAIServing):
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
self
.
engine
.
generate
(
prompt_text
,
sampling_params
,
result_generator
=
self
.
engine
.
generate
(
request_id
,
prompt_ids
,
{
lora_request
)
"prompt"
:
prompt_text
,
"prompt_token_ids"
:
prompt_ids
},
sampling_params
,
request_id
,
lora_request
,
)
# Streaming response
# Streaming response
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
...
@@ -227,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -227,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
:
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
chat_
logprobs
(
token_ids
=
delta_token_ids
,
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
num_output_top_logprobs
=
request
.
top_logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
)
else
:
else
:
logprobs
=
None
logprobs
=
None
...
@@ -289,7 +348,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -289,7 +348,7 @@ class OpenAIServingChat(OpenAIServing):
yield
"data: [DONE]
\n\n
"
yield
"data: [DONE]
\n\n
"
async
def
chat_completion_full_generator
(
async
def
chat_completion_full_generator
(
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
,
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Optional
[
Request
]
,
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
,
conversation
:
List
[
ConversationMessage
]
conversation
:
List
[
ConversationMessage
]
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
...
@@ -299,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -299,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
final_res
:
Optional
[
RequestOutput
]
=
None
final_res
:
Optional
[
RequestOutput
]
=
None
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
raw_request
is
not
None
and
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
request_id
)
await
self
.
engine
.
abort
(
request_id
)
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
...
@@ -314,10 +373,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -314,10 +373,10 @@ class OpenAIServingChat(OpenAIServing):
top_logprobs
=
output
.
logprobs
top_logprobs
=
output
.
logprobs
if
request
.
logprobs
:
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
chat_
logprobs
(
token_ids
=
token_ids
,
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
num_output_top_logprobs
=
request
.
top_
logprobs
,
)
)
else
:
else
:
logprobs
=
None
logprobs
=
None
...
@@ -327,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -327,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
),
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
),
logprobs
=
logprobs
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
stop_reason
=
output
.
stop_reason
)
)
choices
.
append
(
choice_data
)
choices
.
append
(
choice_data
)
if
request
.
echo
:
if
request
.
echo
:
...
@@ -359,34 +417,50 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -359,34 +417,50 @@ class OpenAIServingChat(OpenAIServing):
return
response
return
response
async
def
_load_chat_template
(
self
,
chat_template
:
Optional
[
str
]):
def
_get_top_logprobs
(
while
self
.
tokenizer
is
None
:
self
,
logprobs
:
Dict
[
int
,
Logprob
],
# Give the parent class time to load the tokenizer
top_logprobs
:
Optional
[
int
])
->
List
[
ChatCompletionLogProb
]:
await
asyncio
.
sleep
(
0.1
)
return
[
tokenizer
=
self
.
tokenizer
ChatCompletionLogProb
(
token
=
self
.
_get_decoded_token
(
p
[
1
],
p
[
0
]),
if
chat_template
is
not
None
:
logprob
=
max
(
p
[
1
].
logprob
,
-
9999.0
),
try
:
bytes
=
list
(
with
open
(
chat_template
,
"r"
)
as
f
:
self
.
_get_decoded_token
(
p
[
1
],
tokenizer
.
chat_template
=
f
.
read
()
p
[
0
]).
encode
(
"utf-8"
,
except
OSError
as
e
:
errors
=
"replace"
)))
JINJA_CHARS
=
"{}
\n
"
for
i
,
p
in
enumerate
(
logprobs
.
items
())
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
if
top_logprobs
and
i
<
top_logprobs
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
]
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
def
_create_chat_logprobs
(
raise
ValueError
(
msg
)
from
e
self
,
token_ids
:
GenericSequence
[
int
],
# If opening a file fails, set chat template to be args to
top_logprobs
:
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]],
# ensure we decode so our escape are interpreted correctly
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
tokenizer
.
chat_template
=
codecs
.
decode
(
)
->
ChatCompletionLogProbs
:
chat_template
,
"unicode_escape"
)
"""Create OpenAI-style logprobs."""
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
logprobs_content
=
[]
tokenizer
.
chat_template
)
elif
tokenizer
.
chat_template
is
not
None
:
for
i
,
token_id
in
enumerate
(
token_ids
):
logger
.
info
(
"Using default chat template:
\n
%s"
,
step_top_logprobs
=
top_logprobs
[
i
]
tokenizer
.
chat_template
)
if
step_top_logprobs
is
None
:
else
:
logprobs_content
.
append
(
logger
.
warning
(
ChatCompletionLogProbsContent
(
"No chat template provided. Chat API will not work."
)
token
=
self
.
tokenizer
.
decode
(
token_id
),
bytes
=
list
(
self
.
tokenizer
.
decode
(
token_id
).
encode
(
"utf-8"
,
errors
=
"replace"
))))
else
:
logprobs_content
.
append
(
ChatCompletionLogProbsContent
(
token
=
step_top_logprobs
[
token_id
].
decoded_token
,
logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
),
bytes
=
list
(
step_top_logprobs
[
token_id
].
decoded_token
.
encode
(
"utf-8"
,
errors
=
"replace"
)),
top_logprobs
=
self
.
_get_top_logprobs
(
step_top_logprobs
,
num_output_top_logprobs
)))
return
ChatCompletionLogProbs
(
content
=
logprobs_content
)
vllm/entrypoints/openai/serving_completion.py
View file @
b9e12416
import
time
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
)
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
CompletionLogProbs
,
CompletionRequest
,
CompletionResponse
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
CompletionStreamResponse
,
LogProbs
,
UsageInfo
)
UsageInfo
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.utils
import
merge_async_iterators
,
random_uuid
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -24,7 +31,7 @@ logger = init_logger(__name__)
...
@@ -24,7 +31,7 @@ logger = init_logger(__name__)
TypeTokenIDs
=
List
[
int
]
TypeTokenIDs
=
List
[
int
]
TypeTopLogProbs
=
List
[
Optional
[
Dict
[
int
,
float
]]]
TypeTopLogProbs
=
List
[
Optional
[
Dict
[
int
,
float
]]]
TypeCreateLogProbsFn
=
Callable
[
TypeCreateLogProbsFn
=
Callable
[
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
LogProbs
]
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
Completion
LogProbs
]
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
...
@@ -52,11 +59,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
...
@@ -52,11 +59,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
class
OpenAIServingCompletion
(
OpenAIServing
):
class
OpenAIServingCompletion
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
=
None
):
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
)
...
@@ -118,12 +125,17 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -118,12 +125,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens
)
truncate_prompt_tokens
)
prompt_ids
,
prompt_text
=
prompt_formats
prompt_ids
,
prompt_text
=
prompt_formats
generators
.
append
(
generator
=
self
.
engine
.
generate
(
self
.
engine
.
generate
(
prompt_text
,
{
sampling_params
,
"prompt"
:
prompt_text
,
f
"
{
request_id
}
-
{
i
}
"
,
"prompt_token_ids"
:
prompt_ids
prompt_token_ids
=
prompt_ids
,
},
lora_request
=
lora_request
))
sampling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
lora_request
=
lora_request
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -229,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -229,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
i
]:]
if
output
.
logprobs
else
None
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
if
request
.
logprobs
is
not
None
:
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
completion_
logprobs
(
token_ids
=
delta_token_ids
,
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
...
@@ -311,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -311,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert
top_logprobs
is
not
None
,
(
assert
top_logprobs
is
not
None
,
(
"top_logprobs must be provided when logprobs "
"top_logprobs must be provided when logprobs "
"is requested"
)
"is requested"
)
logprobs
=
self
.
_create_logprobs
(
logprobs
=
self
.
_create_
completion_
logprobs
(
token_ids
=
token_ids
,
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
...
@@ -345,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -345,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
choices
=
choices
,
choices
=
choices
,
usage
=
usage
,
usage
=
usage
,
)
)
def
_create_completion_logprobs
(
self
,
token_ids
:
GenericSequence
[
int
],
top_logprobs
:
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]],
num_output_top_logprobs
:
int
,
initial_text_offset
:
int
=
0
,
)
->
CompletionLogProbs
:
"""Create logprobs for OpenAI Completion API."""
out_text_offset
:
List
[
int
]
=
[]
out_token_logprobs
:
List
[
Optional
[
float
]]
=
[]
out_tokens
:
List
[
str
]
=
[]
out_top_logprobs
:
List
[
Optional
[
Dict
[
str
,
float
]]]
=
[]
last_token_len
=
0
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
self
.
tokenizer
.
decode
(
token_id
)
out_tokens
.
append
(
token
)
out_token_logprobs
.
append
(
None
)
out_top_logprobs
.
append
(
None
)
else
:
token
=
self
.
_get_decoded_token
(
step_top_logprobs
[
token_id
],
token_id
)
token_logprob
=
max
(
step_top_logprobs
[
token_id
].
logprob
,
-
9999.0
)
out_tokens
.
append
(
token
)
out_token_logprobs
.
append
(
token_logprob
)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs
.
append
({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self
.
_get_decoded_token
(
top_lp
[
1
],
top_lp
[
0
]):
max
(
top_lp
[
1
].
logprob
,
-
9999.0
)
for
i
,
top_lp
in
enumerate
(
step_top_logprobs
.
items
())
if
num_output_top_logprobs
>=
i
})
if
len
(
out_text_offset
)
==
0
:
out_text_offset
.
append
(
initial_text_offset
)
else
:
out_text_offset
.
append
(
out_text_offset
[
-
1
]
+
last_token_len
)
last_token_len
=
len
(
token
)
return
CompletionLogProbs
(
text_offset
=
out_text_offset
,
token_logprobs
=
out_token_logprobs
,
tokens
=
out_tokens
,
top_logprobs
=
out_top_logprobs
,
)
vllm/entrypoints/openai/serving_embedding.py
0 → 100644
View file @
b9e12416
import
time
from
typing
import
AsyncIterator
,
List
,
Optional
,
Tuple
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_completion
import
parse_prompt_format
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
TypeTokenIDs
=
List
[
int
]
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
EmbeddingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
EmbeddingResponse
:
data
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
final_res
.
outputs
.
embedding
)
data
.
append
(
embedding_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
data
,
usage
=
usage
,
)
class
OpenAIServingEmbedding
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
)
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
raw_request
:
Request
):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
# Return error for unsupported features.
if
request
.
encoding_format
==
"base64"
:
return
self
.
create_error_response
(
"base64 encoding is not currently supported"
)
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
request
.
model
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
# Schedule the request and get the result generator.
generators
=
[]
try
:
prompt_is_tokens
,
prompts
=
parse_prompt_format
(
request
.
input
)
pooling_params
=
request
.
to_pooling_params
()
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
)
else
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
prompt_ids
,
prompt_text
=
prompt_formats
generator
=
self
.
engine
.
encode
(
{
"prompt"
:
prompt_text
,
"prompt_token_ids"
:
prompt_ids
},
pooling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
int
,
EmbeddingRequestOutput
]]
=
merge_async_iterators
(
*
generators
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
EmbeddingRequestOutput
]]
final_res_batch
=
[
None
]
*
len
(
prompts
)
try
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
request_output_to_embedding_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
response
def
_check_embedding_mode
(
self
,
embedding_mode
:
bool
):
if
not
embedding_mode
:
logger
.
warning
(
"embedding_mode is False. Embedding API will not work."
)
else
:
logger
.
info
(
"Activating the server engine with embedding enabled."
)
vllm/entrypoints/openai/serving_engine.py
View file @
b9e12416
import
asyncio
import
json
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
pydantic
import
Field
from
pydantic
import
Field
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
,
CompletionRequest
,
LogProbs
,
ModelCard
,
ModelList
,
EmbeddingRequest
,
ErrorResponse
,
ModelCard
,
ModelList
,
ModelPermission
)
ModelPermission
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -29,13 +29,24 @@ class LoRAModulePath:
...
@@ -29,13 +29,24 @@ class LoRAModulePath:
class
OpenAIServing
:
class
OpenAIServing
:
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]):
await_post_init
:
Optional
[
Awaitable
[
Any
]]
=
None
):
super
().
__init__
()
self
.
engine
=
engine
self
.
engine
=
engine
self
.
max_model_len
=
model_config
.
max_model_len
# A separate tokenizer to map token IDs to strings.
self
.
tokenizer
=
get_tokenizer
(
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
tokenizer_revision
=
model_config
.
tokenizer_revision
,
trust_remote_code
=
model_config
.
trust_remote_code
,
truncation_side
=
"left"
)
self
.
served_model_names
=
served_model_names
self
.
served_model_names
=
served_model_names
if
lora_modules
is
None
:
if
lora_modules
is
None
:
self
.
lora_requests
=
[]
self
.
lora_requests
=
[]
else
:
else
:
...
@@ -47,38 +58,6 @@ class OpenAIServing:
...
@@ -47,38 +58,6 @@ class OpenAIServing:
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
]
]
self
.
max_model_len
=
0
# Lazy initialized
self
.
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
try
:
event_loop
=
asyncio
.
get_running_loop
()
except
RuntimeError
:
event_loop
=
None
if
event_loop
is
not
None
and
event_loop
.
is_running
():
# If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop
.
create_task
(
self
.
_post_init
(
await_post_init
))
else
:
# When using single vLLM without engine_use_ray
asyncio
.
run
(
self
.
_post_init
(
await_post_init
))
async
def
_post_init
(
self
,
await_post_init
):
engine_model_config
=
await
self
.
engine
.
get_model_config
()
self
.
max_model_len
=
engine_model_config
.
max_model_len
# A separate tokenizer to map token IDs to strings.
self
.
tokenizer
=
get_tokenizer
(
engine_model_config
.
tokenizer
,
tokenizer_mode
=
engine_model_config
.
tokenizer_mode
,
tokenizer_revision
=
engine_model_config
.
tokenizer_revision
,
trust_remote_code
=
engine_model_config
.
trust_remote_code
,
truncation_side
=
"left"
)
if
await_post_init
is
not
None
:
await
await_post_init
async
def
show_available_models
(
self
)
->
ModelList
:
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
"""Show available models. Right now we only have one model."""
model_cards
=
[
model_cards
=
[
...
@@ -96,50 +75,6 @@ class OpenAIServing:
...
@@ -96,50 +75,6 @@ class OpenAIServing:
model_cards
.
extend
(
lora_cards
)
model_cards
.
extend
(
lora_cards
)
return
ModelList
(
data
=
model_cards
)
return
ModelList
(
data
=
model_cards
)
def
_create_logprobs
(
self
,
token_ids
:
List
[
int
],
top_logprobs
:
List
[
Optional
[
Dict
[
int
,
Logprob
]]],
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
initial_text_offset
:
int
=
0
,
)
->
LogProbs
:
"""Create OpenAI-style logprobs."""
logprobs
=
LogProbs
()
last_token_len
=
0
if
num_output_top_logprobs
:
logprobs
.
top_logprobs
=
[]
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
None
:
token
=
self
.
tokenizer
.
decode
(
token_id
)
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
None
)
assert
logprobs
.
top_logprobs
is
not
None
logprobs
.
top_logprobs
.
append
(
None
)
else
:
token_logprob
=
step_top_logprobs
[
token_id
].
logprob
token
=
step_top_logprobs
[
token_id
].
decoded_token
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
if
num_output_top_logprobs
:
assert
logprobs
.
top_logprobs
is
not
None
logprobs
.
top_logprobs
.
append
({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p
.
decoded_token
:
max
(
p
.
logprob
,
-
9999.0
)
for
i
,
p
in
step_top_logprobs
.
items
()
}
if
step_top_logprobs
else
None
)
if
len
(
logprobs
.
text_offset
)
==
0
:
logprobs
.
text_offset
.
append
(
initial_text_offset
)
else
:
logprobs
.
text_offset
.
append
(
logprobs
.
text_offset
[
-
1
]
+
last_token_len
)
last_token_len
=
len
(
token
)
return
logprobs
def
create_error_response
(
def
create_error_response
(
self
,
self
,
message
:
str
,
message
:
str
,
...
@@ -163,7 +98,8 @@ class OpenAIServing:
...
@@ -163,7 +98,8 @@ class OpenAIServing:
return
json_str
return
json_str
async
def
_check_model
(
async
def
_check_model
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
EmbeddingRequest
]
)
->
Optional
[
ErrorResponse
]:
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
if
request
.
model
in
self
.
served_model_names
:
return
None
return
None
...
@@ -175,7 +111,8 @@ class OpenAIServing:
...
@@ -175,7 +111,8 @@ class OpenAIServing:
status_code
=
HTTPStatus
.
NOT_FOUND
)
status_code
=
HTTPStatus
.
NOT_FOUND
)
def
_maybe_get_lora
(
def
_maybe_get_lora
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
EmbeddingRequest
]
)
->
Optional
[
LoRARequest
]:
)
->
Optional
[
LoRARequest
]:
if
request
.
model
in
self
.
served_model_names
:
if
request
.
model
in
self
.
served_model_names
:
return
None
return
None
...
@@ -186,12 +123,14 @@ class OpenAIServing:
...
@@ -186,12 +123,14 @@ class OpenAIServing:
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
def
_validate_prompt_and_tokenize
(
def
_validate_prompt_and_tokenize
(
self
,
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
prompt
:
Optional
[
str
]
=
None
,
EmbeddingRequest
],
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
Tuple
[
List
[
int
],
str
]:
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
)
->
Tuple
[
List
[
int
],
str
]:
if
not
(
prompt
or
prompt_ids
):
if
not
(
prompt
or
prompt_ids
):
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
if
(
prompt
and
prompt_ids
):
if
(
prompt
and
prompt_ids
):
...
@@ -199,10 +138,19 @@ class OpenAIServing:
...
@@ -199,10 +138,19 @@ class OpenAIServing:
"Only one of prompt or prompt_ids should be provided."
)
"Only one of prompt or prompt_ids should be provided."
)
if
prompt_ids
is
None
:
if
prompt_ids
is
None
:
tokenizer_kwargs
=
{}
if
truncate_prompt_tokens
is
None
else
{
# When using OpenAIServingChat for chat completions, the
"truncation"
:
True
,
# special tokens (e.g., BOS) have already been added by the
"max_length"
:
truncate_prompt_tokens
,
# chat template. Therefore, we do not need to add them again.
# Set add_special_tokens to False to avoid adding the BOS tokens
# again.
tokenizer_kwargs
:
Dict
[
str
,
Any
]
=
{
"add_special_tokens"
:
add_special_tokens
}
}
if
truncate_prompt_tokens
is
not
None
:
tokenizer_kwargs
.
update
({
"truncation"
:
True
,
"max_length"
:
truncate_prompt_tokens
,
})
input_ids
=
self
.
tokenizer
(
prompt
,
**
tokenizer_kwargs
).
input_ids
input_ids
=
self
.
tokenizer
(
prompt
,
**
tokenizer_kwargs
).
input_ids
elif
truncate_prompt_tokens
is
not
None
:
elif
truncate_prompt_tokens
is
not
None
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
...
@@ -213,6 +161,16 @@ class OpenAIServing:
...
@@ -213,6 +161,16 @@ class OpenAIServing:
prompt_ids
)
prompt_ids
)
token_num
=
len
(
input_ids
)
token_num
=
len
(
input_ids
)
# Note: EmbeddingRequest doesn't have max_tokens
if
isinstance
(
request
,
EmbeddingRequest
):
if
token_num
>
self
.
max_model_len
:
raise
ValueError
(
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the input for embedding "
f
"generation. Please reduce the length of the input."
,
)
return
input_ids
,
input_text
if
request
.
max_tokens
is
None
:
if
request
.
max_tokens
is
None
:
if
token_num
>=
self
.
max_model_len
:
if
token_num
>=
self
.
max_model_len
:
raise
ValueError
(
raise
ValueError
(
...
@@ -232,3 +190,8 @@ class OpenAIServing:
...
@@ -232,3 +190,8 @@ class OpenAIServing:
f
"Please reduce the length of the messages or completion."
,
)
f
"Please reduce the length of the messages or completion."
,
)
else
:
else
:
return
input_ids
,
input_text
return
input_ids
,
input_text
def
_get_decoded_token
(
self
,
logprob
:
Logprob
,
token_id
:
int
)
->
str
:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
self
.
tokenizer
.
decode
(
token_id
)
vllm/envs.py
View file @
b9e12416
...
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
...
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
VLLM_HOST_IP
:
str
=
""
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
...
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
...
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
VLLM_DO_NOT_TRACK
:
bool
=
False
VLLM_DO_NOT_TRACK
:
bool
=
False
VLLM_USAGE_SOURCE
:
str
=
""
VLLM_USAGE_SOURCE
:
str
=
""
VLLM_CONFIGURE_LOGGING
:
int
=
1
VLLM_CONFIGURE_LOGGING
:
int
=
1
VLLM_LOGGING_LEVEL
:
str
=
"INFO"
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
...
@@ -96,6 +98,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -96,6 +98,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
'VLLM_HOST_IP'
:
'VLLM_HOST_IP'
:
lambda
:
os
.
getenv
(
'VLLM_HOST_IP'
,
""
)
or
os
.
getenv
(
"HOST_IP"
,
""
),
lambda
:
os
.
getenv
(
'VLLM_HOST_IP'
,
""
)
or
os
.
getenv
(
"HOST_IP"
,
""
),
# used in distributed environment to manually set the communication port
# '0' is used to make mypy happy
'VLLM_PORT'
:
lambda
:
int
(
os
.
getenv
(
'VLLM_PORT'
,
'0'
))
if
'VLLM_PORT'
in
os
.
environ
else
None
,
# If true, will load models from ModelScope instead of Hugging Face Hub.
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE"
:
"VLLM_USE_MODELSCOPE"
:
...
@@ -145,7 +153,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -145,7 +153,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# S3 access information, used for tensorizer to load model from S3
# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID"
:
"S3_ACCESS_KEY_ID"
:
lambda
:
os
.
environ
.
get
(
"S3_ACCESS_KEY"
,
None
),
lambda
:
os
.
environ
.
get
(
"S3_ACCESS_KEY
_ID
"
,
None
),
"S3_SECRET_ACCESS_KEY"
:
"S3_SECRET_ACCESS_KEY"
:
lambda
:
os
.
environ
.
get
(
"S3_SECRET_ACCESS_KEY"
,
None
),
lambda
:
os
.
environ
.
get
(
"S3_SECRET_ACCESS_KEY"
,
None
),
"S3_ENDPOINT_URL"
:
"S3_ENDPOINT_URL"
:
...
@@ -171,6 +179,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -171,6 +179,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_CONFIG_PATH"
:
"VLLM_LOGGING_CONFIG_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_CONFIG_PATH"
),
lambda
:
os
.
getenv
(
"VLLM_LOGGING_CONFIG_PATH"
),
# this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL"
:
lambda
:
os
.
getenv
(
"VLLM_LOGGING_LEVEL"
,
"INFO"
),
# Trace function calls
# Trace function calls
# If set to 1, vllm will trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
# Useful for debugging
...
...
vllm/executor/distributed_gpu_executor.py
View file @
b9e12416
import
asyncio
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -13,6 +14,16 @@ logger = init_logger(__name__)
...
@@ -13,6 +14,16 @@ logger = init_logger(__name__)
class
DistributedGPUExecutor
(
GPUExecutor
):
class
DistributedGPUExecutor
(
GPUExecutor
):
"""Abstract superclass of multi-GPU executor implementations."""
"""Abstract superclass of multi-GPU executor implementations."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self
.
parallel_worker_tasks
:
Optional
[
Union
[
Any
,
Awaitable
[
Any
]]]
=
None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self
.
extra_execute_model_run_workers_kwargs
:
Dict
[
str
,
Any
]
=
{}
super
().
__init__
(
*
args
,
**
kwargs
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks.
"""Determine the number of available KV blocks.
...
@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
num_gpu_blocks
=
num_gpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
num_cpu_blocks
=
num_cpu_blocks
)
def
execute_model
(
self
,
*
args
,
**
kwargs
)
->
List
[
SamplerOutput
]:
def
execute_model
(
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
self
,
driver_args
=
args
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
driver_kwargs
=
kwargs
)
if
self
.
parallel_worker_tasks
is
None
:
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
"start_worker_execution_loop"
,
async_run_remote_workers_only
=
True
,
**
self
.
extra_execute_model_run_workers_kwargs
)
# Only the driver worker returns the sampling results.
# Only the driver worker returns the sampling results.
return
all_outputs
[
0
]
return
self
.
_driver_execute_model
(
execute_model_req
)
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
if
self
.
parallel_worker_tasks
is
None
:
return
self
.
_driver_execute_model
()
parallel_worker_tasks
=
self
.
parallel_worker_tasks
self
.
parallel_worker_tasks
=
None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self
.
_wait_for_tasks_completion
(
parallel_worker_tasks
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
...
@@ -77,39 +103,95 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -77,39 +103,95 @@ class DistributedGPUExecutor(GPUExecutor):
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
return
self
.
_run_workers
(
"list_loras"
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
_run_workers
(
"save_sharded_state"
,
path
=
path
,
pattern
=
pattern
,
max_size
=
max_size
)
@
abstractmethod
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
str
,
*
args
,
*
args
,
driver_args
:
Optional
[
Tuple
[
Any
,
...]]
=
None
,
async_run_remote_workers_only
:
bool
=
False
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Any
:
)
->
Any
:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
raise
NotImplementedError
@
abstractmethod
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise
NotImplementedError
raise
NotImplementedError
class
DistributedGPUExecutorAsync
(
DistributedGPUExecutor
,
ExecutorAsyncBase
):
class
DistributedGPUExecutorAsync
(
DistributedGPUExecutor
,
ExecutorAsyncBase
):
async
def
execute_model_async
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
# Start model execution loop running in the parallel workers
self
.
parallel_worker_tasks
=
asyncio
.
create_task
(
self
.
_start_worker_execution_loop
())
# Only the driver worker returns the sampling results.
return
await
self
.
_driver_execute_model_async
(
execute_model_req
)
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
if
self
.
parallel_worker_tasks
is
None
:
return
await
self
.
_driver_execute_model_async
()
parallel_worker_tasks
=
self
.
parallel_worker_tasks
self
.
parallel_worker_tasks
=
None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await
parallel_worker_tasks
@
abstractmethod
@
abstractmethod
async
def
_
run_workers
_async
(
async
def
_
driver_execute_model
_async
(
self
,
self
,
method
:
str
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
*
args
,
)
->
List
[
SamplerOutput
]:
driver_args
:
Optional
[
Tuple
[
Any
,
...]]
=
None
,
"""Execute the model asynchronously in the driver worker.
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
raise
NotImplementedError
async
def
execute_model_async
(
self
,
*
args
,
Passing None will cause the driver to stop the model execution
**
kwargs
)
->
List
[
SamplerOutput
]:
loop running in each of the remote workers.
all_outputs
=
await
self
.
_run_workers_async
(
"execute_model"
,
"""
driver_args
=
args
,
raise
NotImplementedError
driver_kwargs
=
kwargs
)
# Only the driver worker returns the sampling results.
@
abstractmethod
return
all_outputs
[
0
]
async
def
_start_worker_execution_loop
(
self
):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise
NotImplementedError
vllm/executor/executor_base.py
View file @
b9e12416
...
@@ -74,6 +74,10 @@ class ExecutorBase(ABC):
...
@@ -74,6 +74,10 @@ class ExecutorBase(ABC):
"""Executes at least one model step on the given sequences."""
"""Executes at least one model step on the given sequences."""
raise
NotImplementedError
raise
NotImplementedError
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
"""Releases parallel workers from model loop."""
return
@
abstractmethod
@
abstractmethod
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
...
@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences."""
"""Executes one model step on the given sequences."""
raise
NotImplementedError
raise
NotImplementedError
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Releases parallel workers from model loop."""
return
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
"""Checks if the executor is healthy. If not, it should raise an
exception."""
exception."""
...
...
Prev
1
…
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment