Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c4f76e3
Commit
7c4f76e3
authored
Apr 15, 2024
by
zhuwenwen
Browse files
merge v0.4.0
parents
2da0dd3e
51c31bc1
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1931 additions
and
1080 deletions
+1931
-1080
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+199
-459
vllm/engine/metrics.py
vllm/engine/metrics.py
+19
-13
vllm/engine/ray_utils.py
vllm/engine/ray_utils.py
+38
-36
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+24
-12
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+28
-18
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+28
-105
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+118
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+147
-47
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+148
-119
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+211
-208
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+33
-15
vllm/executor/__init__.py
vllm/executor/__init__.py
+0
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+76
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+158
-0
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+81
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+433
-0
vllm/executor/utils.py
vllm/executor/utils.py
+13
-0
vllm/logger.py
vllm/logger.py
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+173
-45
vllm/lora/lora.py
vllm/lora/lora.py
+3
-2
No files found.
vllm/engine/llm_engine.py
View file @
7c4f76e3
import
copy
from
collections
import
defaultdict
import
os
import
time
import
pickle
import
importlib
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
)
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
from
transformers
import
PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
initialize_cluster
,
ray
from
vllm.engine.ray_utils
import
initialize_ray_cluster
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.model_loader
import
get_architecture_class_name
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
TokenizerGroup
)
from
vllm.utils
import
(
Counter
,
set_cuda_visible_devices
,
get_ip
,
get_open_port
,
get_distributed_init_method
)
if
ray
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP
=
{
"cuda"
:
"vllm.worker.worker"
,
"neuron"
:
"vllm.worker.neuron_worker"
,
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
))
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -68,9 +53,10 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
placement_group: Ray placement group for distributed execution.
Required for distributed
execution.
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
"""
def
__init__
(
...
...
@@ -81,11 +67,13 @@ class LLMEngine:
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
placement_group
:
Optional
[
"PlacementGroup"
],
vision_language_config
:
Optional
[
"VisionLanguageConfig"
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine with config: "
f
"Initializing an LLM engine
(v
{
vllm
.
__version__
}
)
with config: "
f
"model=
{
model_config
.
model
!
r
}
, "
f
"tokenizer=
{
model_config
.
tokenizer
!
r
}
, "
f
"tokenizer_mode=
{
model_config
.
tokenizer_mode
}
, "
...
...
@@ -97,7 +85,8 @@ class LLMEngine:
f
"download_dir=
{
model_config
.
download_dir
!
r
}
, "
f
"load_format=
{
model_config
.
load_format
}
, "
f
"tensor_parallel_size=
{
parallel_config
.
tensor_parallel_size
}
, "
f
"disable_custom_all_reduce=
{
parallel_config
.
disable_custom_all_reduce
}
, "
f
"disable_custom_all_reduce="
f
"
{
parallel_config
.
disable_custom_all_reduce
}
, "
f
"quantization=
{
model_config
.
quantization
}
, "
f
"enforce_eager=
{
model_config
.
enforce_eager
}
, "
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
...
...
@@ -108,6 +97,7 @@ class LLMEngine:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
...
...
@@ -115,22 +105,54 @@ class LLMEngine:
self
.
_verify_args
()
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
seq_counter
=
Counter
()
# Create the parallel GPU workers.
if
self
.
parallel_config
.
worker_use_ray
:
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
self
.
_init_workers_ray
(
placement_group
)
else
:
self
.
_init_workers
()
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
self
.
model_executor
=
executor_class
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
,
vision_language_config
)
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
usage_message
.
report_usage
(
get_architecture_class_name
(
model_config
),
usage_context
,
extra_kvs
=
{
# Common configuration
"dtype"
:
str
(
model_config
.
dtype
),
"tensor_parallel_size"
:
parallel_config
.
tensor_parallel_size
,
"block_size"
:
cache_config
.
block_size
,
"gpu_memory_utilization"
:
cache_config
.
gpu_memory_utilization
,
# Quantization
"quantization"
:
model_config
.
quantization
,
"kv_cache_dtype"
:
cache_config
.
cache_dtype
,
# Feature flags
"enable_lora"
:
bool
(
lora_config
),
"enable_prefix_caching"
:
cache_config
.
enable_prefix_caching
,
"enforce_eager"
:
model_config
.
enforce_eager
,
"disable_custom_all_reduce"
:
parallel_config
.
disable_custom_all_reduce
,
})
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
tokenizer
.
ping
()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
# Metric Logging.
...
...
@@ -140,48 +162,56 @@ class LLMEngine:
labels
=
dict
(
model_name
=
model_config
.
model
))
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
):
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
EngineArgs
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
device_config
=
engine_configs
[
4
]
# Initialize the cluster and specify the executor class.
if
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
elif
parallel_config
.
worker_use_ray
:
initialize_ray_cluster
(
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
executor_class
=
RayGPUExecutor
else
:
assert
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutor
executor_class
=
GPUExecutor
def
_dispatch_worker
(
self
):
worker_module
=
DEVICE_TO_WORKER_MODULE_MAP
[
self
.
device_config
.
device_type
]
imported_worker
=
importlib
.
import_module
(
worker_module
)
Worker
=
imported_worker
.
Worker
return
Worker
def
_init_workers
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
assert
self
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
self
.
workers
:
List
[
Worker
]
=
[]
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
# Create the LLM engine.
engine
=
cls
(
*
engine_configs
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
)
self
.
_run_workers
(
"init_model"
)
self
.
_run_workers
(
"load_model"
)
return
engine
def
__reduce__
(
self
):
# This is to ensure that the LLMEngine is not referenced in
# the closure used to initialize Ray worker actors
raise
RuntimeError
(
"LLMEngine should not be pickled!"
)
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
None
)
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
):
init_kwargs
=
dict
(
tokenizer_id
=
self
.
model_config
.
tokenizer
,
enable_lora
=
bool
(
self
.
lora_config
),
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
max_input_length
=
None
,
...
...
@@ -189,126 +219,8 @@ class LLMEngine:
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
self
.
tokenizer
:
TokenizerGroup
=
TokenizerGroup
(
self
.
model_config
.
tokenizer
,
**
init_kwargs
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
num_gpus
=
self
.
cache_config
.
gpu_memory_utilization
else
:
num_gpus
=
1
self
.
driver_dummy_worker
:
RayWorkerVllm
=
None
self
.
workers
:
List
[
RayWorkerVllm
]
=
[]
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerVllm
).
remote
(
self
.
model_config
.
trust_remote_code
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
else
:
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
driver_node_id
,
driver_gpu_ids
=
ray
.
get
(
self
.
driver_dummy_worker
.
get_node_and_gpu_ids
.
remote
())
worker_node_and_gpu_ids
=
ray
.
get
(
[
worker
.
get_node_and_gpu_ids
.
remote
()
for
worker
in
self
.
workers
])
node_workers
=
defaultdict
(
list
)
node_gpus
=
defaultdict
(
list
)
node_workers
[
driver_node_id
].
append
(
0
)
node_gpus
[
driver_node_id
].
extend
(
driver_gpu_ids
)
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
,
start
=
1
):
node_workers
[
node_id
].
append
(
i
)
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
# Set CUDA_VISIBLE_DEVICES for the driver.
set_cuda_visible_devices
(
node_gpus
[
driver_node_id
])
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
# Initialize torch distributed process group for the workers.
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
for
rank
,
(
worker
,
(
node_id
,
_
))
in
enumerate
(
zip
(
self
.
workers
,
worker_node_and_gpu_ids
),
start
=
1
):
local_rank
=
node_workers
[
node_id
].
index
(
rank
)
worker
.
init_worker
.
remote
(
lambda
rank
=
rank
,
local_rank
=
local_rank
:
Worker
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
local_rank
,
rank
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
))
driver_rank
=
0
driver_local_rank
=
node_workers
[
driver_node_id
].
index
(
driver_rank
)
self
.
driver_worker
=
Worker
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
driver_local_rank
,
driver_rank
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
)
# don't use cupy for eager mode
self
.
_run_workers
(
"init_model"
,
cupy_port
=
get_open_port
()
if
not
model_config
.
enforce_eager
else
None
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
,
)
self
.
tokenizer
:
BaseTokenizerGroup
=
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
@@ -318,81 +230,6 @@ class LLMEngine:
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameters.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
"profile_num_available_blocks"
,
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
# FIXME(woosuk): Change to debug log.
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
if
num_gpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_seq_len
=
self
.
cache_config
.
block_size
*
num_gpu_blocks
if
self
.
model_config
.
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
self
.
model_config
.
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine."
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
_run_workers
(
"init_cache_engine"
,
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
_run_workers
(
"warm_up_model"
)
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
EngineArgs
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
# Initialize the cluster.
placement_group
=
initialize_cluster
(
parallel_config
)
# Create the LLM engine.
engine
=
cls
(
*
engine_configs
,
placement_group
,
log_stats
=
not
engine_args
.
disable_log_stats
)
return
engine
def
encode_request
(
self
,
request_id
:
str
,
# pylint: disable=unused-argument
...
...
@@ -415,7 +252,7 @@ class LLMEngine:
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prefix_pos
:
Optional
[
int
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -432,11 +269,7 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
...
...
@@ -465,8 +298,15 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
max_logprobs
=
self
.
get_model_config
().
max_logprobs
if
(
sampling_params
.
logprobs
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
sampling_params
.
prompt_logprobs
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
monotonic
()
arrival_time
=
time
.
time
()
prompt_token_ids
=
self
.
encode_request
(
request_id
=
request_id
,
prompt
=
prompt
,
...
...
@@ -476,21 +316,21 @@ class LLMEngine:
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
lora_request
)
# Check whether the input specifies prefix
prefix
=
self
.
scheduler
.
prefix_pool
.
add_or_get_prefix
(
prompt_token_ids
[:
prefix_pos
],
lora_request
.
lora_int_id
if
lora_request
else
0
)
if
prefix_pos
is
not
None
else
None
eos_token_id
,
lora_request
)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params
=
sampling_params
.
clone
()
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params
.
eos_token_id
=
seq
.
eos_token_id
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
arrival_time
,
lora_request
,
prefix
)
arrival_time
,
lora_request
,
multi_modal_data
)
# Add the sequence group to the scheduler.
self
.
scheduler
.
add_seq_group
(
seq_group
)
...
...
@@ -538,15 +378,13 @@ class LLMEngine:
if
early_stopping
is
True
:
return
True
current_worst_score
=
(
current_worst_seq
.
get_beam_search_score
(
current_worst_score
=
current_worst_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
current_worst_seq
).
eos_token_id
))
eos_token_id
=
current_worst_seq
.
eos_token_id
)
if
early_stopping
is
False
:
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
highest_attainable_score
=
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
))
eos_token_id
=
best_running_seq
.
eos_token_id
)
else
:
assert
early_stopping
==
"never"
if
length_penalty
>
0.0
:
...
...
@@ -560,8 +398,7 @@ class LLMEngine:
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
,
eos_token_id
=
best_running_seq
.
eos_token_id
,
seq_len
=
max_possible_length
))
else
:
# Otherwise, beam search will prefer shorter sequences. The
...
...
@@ -570,8 +407,7 @@ class LLMEngine:
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
best_running_seq
).
eos_token_id
))
eos_token_id
=
best_running_seq
.
eos_token_id
))
return
current_worst_score
>=
highest_attainable_score
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
...
@@ -580,6 +416,8 @@ class LLMEngine:
# Process prompt logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
if
prompt_logprobs
is
not
None
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
seq_group
.
prompt_logprobs
=
prompt_logprobs
# Process samples
...
...
@@ -623,7 +461,8 @@ class LLMEngine:
child_seqs
.
append
((
parent
,
parent
))
for
seq
,
_
in
child_seqs
:
self
.
_decode_sequence
(
seq
,
seq_group
.
sampling_params
)
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
seq_group
.
sampling_params
)
self
.
_check_stop
(
seq
,
seq_group
.
sampling_params
)
# Non-beam search case
...
...
@@ -662,8 +501,7 @@ class LLMEngine:
all_finished_seqs
=
existing_finished_seqs
+
new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
x
[
0
]).
eos_token_id
),
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
reverse
=
True
)
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
if
is_new
:
...
...
@@ -690,8 +528,7 @@ class LLMEngine:
if
not
seq
.
is_finished
()]
# Sort the running sequences by their scores.
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
self
.
get_tokenizer_for_seq
(
x
[
0
]).
eos_token_id
),
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
reverse
=
True
)
# Check if we can stop the beam search.
...
...
@@ -752,7 +589,11 @@ class LLMEngine:
now
=
time
.
time
()
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups
=
scheduler_outputs
.
scheduled_seq_groups
for
seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
...
...
@@ -760,7 +601,8 @@ class LLMEngine:
# Create the outputs.
request_outputs
:
List
[
RequestOutput
]
=
[]
for
seq_group
in
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_outputs
.
append
(
request_output
)
...
...
@@ -768,16 +610,9 @@ class LLMEngine:
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_outputs
.
append
(
request_output
)
# Update prefix state, now all the uncomputed prefixes are computed.
for
seq_group
in
scheduled_seq_groups
:
if
(
seq_group
.
prefix
is
not
None
and
seq_group
.
prefix
.
allocated
and
not
seq_group
.
prefix
.
computed
):
seq_group
.
prefix
.
computed
=
True
# Log stats.
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
return
request_outputs
def
step
(
self
)
->
List
[
RequestOutput
]:
...
...
@@ -798,7 +633,7 @@ class LLMEngine:
- A Sequence Group (SG) refer to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the
workers
to execute the model.
- Step 2: Calls the
distributed executor
to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
...
...
@@ -834,19 +669,10 @@ class LLMEngine:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
scheduler_outputs
.
blocks_to_swap_in
,
"blocks_to_swap_out"
:
scheduler_outputs
.
blocks_to_swap_out
,
"blocks_to_copy"
:
scheduler_outputs
.
blocks_to_copy
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
output
=
self
.
model_executor
.
execute_model
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
)
else
:
output
=
[]
...
...
@@ -860,7 +686,7 @@ class LLMEngine:
def
_get_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
])
->
Stats
:
"""Get Stats to be Logged to Prometheus."""
now
=
time
.
monotonic
()
now
=
time
.
time
()
# KV Cache Usage in %.
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
...
...
@@ -891,18 +717,22 @@ class LLMEngine:
# Number of Tokens.
if
prompt_run
:
num_prompt_tokens
=
sum
(
len
(
seq_group
.
prompt_token_ids
)
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
len
(
scheduled_seq_group
.
seq_group
.
prompt_token_ids
)
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
num_generation_tokens
=
sum
(
seq_group
.
num_seqs
()
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
scheduled_seq_group
.
seq_group
.
num_seqs
()
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
else
:
num_generation_tokens
=
scheduler_outputs
.
num_batched_tokens
# Latency Timings.
time_last_iters
=
[]
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
# Time since last token. (n.b. updates seq_group.metrics.last_token_time)
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters
.
append
(
seq_group
.
get_last_latency
(
now
))
# Time since arrival for all finished requests.
if
seq_group
.
is_finished
():
...
...
@@ -926,41 +756,9 @@ class LLMEngine:
time_e2e_requests
=
time_e2e_requests
,
)
def
_decode_sequence
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
)
->
None
:
"""Decodes the new token for a sequence."""
(
new_tokens
,
new_output_text
,
prefix_offset
,
read_offset
)
=
detokenize_incrementally
(
self
.
get_tokenizer_for_seq
(
seq
),
all_input_ids
=
seq
.
get_token_ids
(),
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
)
if
seq
.
tokens
is
None
:
seq
.
tokens
=
new_tokens
else
:
seq
.
tokens
.
extend
(
new_tokens
)
seq
.
prefix_offset
=
prefix_offset
seq
.
read_offset
=
read_offset
seq
.
output_text
+=
new_output_text
def
_check_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
"""Stop the finished sequences."""
for
stop_str
in
sampling_params
.
stop
:
if
seq
.
output_text
.
endswith
(
stop_str
):
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
if
seq
.
get_last_token_id
()
in
sampling_params
.
stop_token_ids
:
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
seq
.
get_last_token_id
())
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
scheduler_config
.
max_model_len
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
...
...
@@ -971,9 +769,29 @@ class LLMEngine:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
for
stop_str
in
sampling_params
.
stop
:
if
seq
.
output_text
.
endswith
(
stop_str
):
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
stop_str
=
self
.
get_tokenizer_for_seq
(
seq
).
convert_ids_to_tokens
(
last_token_id
)
self
.
_finalize_sequence
(
seq
,
sampling_params
,
stop_str
)
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
se
lf
.
get_
tokenizer_for_seq
(
seq
)
.
eos_token_id
):
if
((
not
sampling_params
.
ignore_eos
)
and
se
q
.
get_
last_token_id
()
==
seq
.
eos_token_id
):
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
...
...
@@ -989,91 +807,13 @@ class LLMEngine:
seq
.
output_text
=
seq
.
output_text
[:
-
len
(
stop_string
)]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"add_lora"
,
lora_request
=
lora_request
,
)
return
self
.
model_executor
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"remove_lora"
,
lora_id
=
lora_id
,
)
return
self
.
model_executor
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
_run_workers
(
"
list_loras
"
)
return
self
.
model_executor
.
list_loras
(
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels
=
self
.
forward_dag
.
execute
(
1
)
else
:
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Start the driver worker after all the ray workers.
driver_worker_output
=
getattr
(
self
.
driver_worker
,
method
)(
*
driver_args
,
**
driver_kwargs
)
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_compiled_ray_dag
(
self
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
MultiOutputNode
,
InputNode
assert
self
.
parallel_config
.
worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
bind
(
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
def
check_health
(
self
)
->
None
:
self
.
model_executor
.
check_health
()
vllm/engine/metrics.py
View file @
7c4f76e3
from
vllm.logger
import
init_logger
from
prometheus_client
import
Counter
,
Gauge
,
Histogram
,
Info
,
REGISTRY
,
disable_created_metrics
import
time
import
numpy
as
np
from
typing
import
Dict
,
List
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
import
numpy
as
np
from
prometheus_client
import
(
REGISTRY
,
Counter
,
Gauge
,
Histogram
,
Info
,
disable_created_metrics
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
...
...
@@ -23,6 +25,7 @@ class Metrics:
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
REGISTRY
.
unregister
(
collector
)
# Config Information
self
.
info_cache_config
=
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
...
...
@@ -176,10 +179,12 @@ class StatLogger:
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side.
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
# Support legacy gauge metrics that make throughput calculations on
# the vLLM side. Moving forward, we should use counters like
# counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the
# grafana/prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self
.
metrics
.
gauge_avg_prompt_throughput
.
labels
(
**
self
.
labels
).
set
(
prompt_throughput
)
self
.
metrics
.
gauge_avg_generation_throughput
.
labels
(
...
...
@@ -187,7 +192,7 @@ class StatLogger:
def
log
(
self
,
stats
:
Stats
)
->
None
:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus.
...
...
@@ -199,8 +204,8 @@ class StatLogger:
# Log locally every local_interval seconds.
if
self
.
_local_interval_elapsed
(
stats
.
now
):
#
Compute summary metrics for tracked stats (and log them
to promethus if applicable).
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput
=
self
.
_get_throughput
(
self
.
num_prompt_tokens
,
now
=
stats
.
now
)
generation_throughput
=
self
.
_get_throughput
(
...
...
@@ -212,7 +217,8 @@ class StatLogger:
# Log to stdout.
logger
.
info
(
f
"Avg prompt throughput:
{
prompt_throughput
:.
1
f
}
tokens/s, "
f
"Avg generation throughput:
{
generation_throughput
:.
1
f
}
tokens/s, "
f
"Avg generation throughput: "
f
"
{
generation_throughput
:.
1
f
}
tokens/s, "
f
"Running:
{
stats
.
num_running
}
reqs, "
f
"Swapped:
{
stats
.
num_swapped
}
reqs, "
f
"Pending:
{
stats
.
num_waiting
}
reqs, "
...
...
vllm/engine/ray_utils.py
View file @
7c4f76e3
import
pickle
from
typing
import
Optional
,
List
,
Tuple
,
TYPE_CHECKING
from
typing
import
List
,
Optional
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
,
set_cuda_visible_devices
,
get_ip
from
vllm.utils
import
get_ip
,
is_hip
,
set_cuda_visible_devices
logger
=
init_logger
(
__name__
)
...
...
@@ -33,8 +32,17 @@ try:
return
getattr
(
self
.
worker
,
name
)
def
execute_method
(
self
,
method
,
*
args
,
**
kwargs
):
executor
=
getattr
(
self
,
method
)
return
executor
(
*
args
,
**
kwargs
)
try
:
executor
=
getattr
(
self
,
method
)
return
executor
(
*
args
,
**
kwargs
)
except
Exception
as
e
:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg
=
(
f
"Error executing method
{
method
}
. "
"This might cause deadlock in distributed execution."
)
logger
.
exception
(
msg
)
raise
e
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
...
...
@@ -65,45 +73,38 @@ except ImportError as e:
ray
=
None
RayWorkerVllm
=
None
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
def
initialize_cluster
(
def
initialize_
ray_
cluster
(
parallel_config
:
ParallelConfig
,
engine_use_ray
:
bool
=
False
,
ray_address
:
Optional
[
str
]
=
None
,
)
->
Optional
[
"PlacementGroup"
]:
"""Initialize the distributed cluster probably with Ray.
):
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
An optional `PlacementGroup`. It includes the specification
of the resources for each distributed worker. None if Ray is
not used.
"""
if
parallel_config
.
worker_use_ray
or
engine_use_ray
:
if
ray
is
None
:
raise
ImportError
(
"Ray is not installed. Please install Ray to use distributed "
"serving."
)
# Connect to a ray cluster.
if
is_hip
():
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
not
parallel_config
.
worker_use_ray
:
assert
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
return
None
if
ray
is
None
:
raise
ImportError
(
"Ray is not installed. Please install Ray to use distributed "
"serving."
)
# Connect to a ray cluster.
if
is_hip
():
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
...
...
@@ -138,4 +139,5 @@ def initialize_cluster(
# if they cannot be provisioned.
ray
.
get
(
current_placement_group
.
ready
(),
timeout
=
1800
)
return
current_placement_group
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
vllm/entrypoints/api_server.py
View file @
7c4f76e3
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
It is not intended for production use. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead.
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import
argparse
import
json
import
ssl
from
typing
import
AsyncGenerator
import
uvicorn
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
import
uvicorn
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
random_uuid
TIMEOUT_KEEP_ALIVE
=
5
# seconds.
...
...
@@ -39,15 +43,11 @@ async def generate(request: Request) -> Response:
"""
request_dict
=
await
request
.
json
()
prompt
=
request_dict
.
pop
(
"prompt"
)
prefix_pos
=
request_dict
.
pop
(
"prefix_pos"
,
None
)
stream
=
request_dict
.
pop
(
"stream"
,
False
)
sampling_params
=
SamplingParams
(
**
request_dict
)
request_id
=
random_uuid
()
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
,
prefix_pos
=
prefix_pos
)
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
)
# Streaming case
async
def
stream_results
()
->
AsyncGenerator
[
bytes
,
None
]:
...
...
@@ -84,6 +84,16 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--ssl-ca-certs"
,
type
=
str
,
default
=
None
,
help
=
"The CA certificates file"
)
parser
.
add_argument
(
"--ssl-cert-reqs"
,
type
=
int
,
default
=
int
(
ssl
.
CERT_NONE
),
help
=
"Whether client certificate is required (see stdlib ssl module's)"
)
parser
.
add_argument
(
"--root-path"
,
type
=
str
,
...
...
@@ -91,9 +101,9 @@ if __name__ == "__main__":
help
=
"FastAPI root_path when app is behind a path based routing proxy"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
API_SERVER
)
app
.
root_path
=
args
.
root_path
uvicorn
.
run
(
app
,
...
...
@@ -102,4 +112,6 @@ if __name__ == "__main__":
log_level
=
"debug"
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
)
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
vllm/entrypoints/llm.py
View file @
7c4f76e3
from
typing
import
List
,
Optional
,
Union
import
torch
from
tqdm
import
tqdm
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.lora.request
import
LoRARequest
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
...
...
@@ -83,7 +86,7 @@ class LLM:
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
Fals
e
,
disable_custom_all_reduce
:
bool
=
Tru
e
,
**
kwargs
,
)
->
None
:
if
"disable_log_stats"
not
in
kwargs
:
...
...
@@ -106,7 +109,8 @@ class LLM:
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
def
get_tokenizer
(
...
...
@@ -124,9 +128,9 @@ class LLM:
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prefix_pos
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
...
...
@@ -140,13 +144,9 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
...
...
@@ -166,19 +166,27 @@ class LLM:
# Use default sampling params.
sampling_params
=
SamplingParams
()
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
num_requests
=
len
(
prompts
)
if
prompts
is
not
None
else
len
(
prompt_token_ids
)
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
prefix_pos_i
=
prefix_pos
[
i
]
if
prefix_pos
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
self
.
_add_request
(
prompt
,
sampling_params
,
token_ids
,
lora_request
=
lora_request
,
prefix_pos
=
prefix_pos_i
)
self
.
_add_request
(
prompt
,
sampling_params
,
token_ids
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
multi_modal_data
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
))
if
multi_modal_data
else
None
,
)
return
self
.
_run_engine
(
use_tqdm
)
def
_add_request
(
...
...
@@ -187,7 +195,7 @@ class LLM:
sampling_params
:
SamplingParams
,
prompt_token_ids
:
Optional
[
List
[
int
]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prefix_pos
:
Optional
[
int
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
...
...
@@ -195,13 +203,15 @@ class LLM:
sampling_params
,
prompt_token_ids
,
lora_request
=
lora_request
,
prefix_pos
=
prefix_pos
)
multi_modal_data
=
multi_modal_data
)
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
# Initialize tqdm.
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
pbar
=
tqdm
(
total
=
num_requests
,
desc
=
"Processed prompts"
)
pbar
=
tqdm
(
total
=
num_requests
,
desc
=
"Processed prompts"
,
dynamic_ncols
=
True
)
# Run the engine.
outputs
:
List
[
RequestOutput
]
=
[]
while
self
.
llm_engine
.
has_unfinished_requests
():
...
...
vllm/entrypoints/openai/api_server.py
View file @
7c4f76e3
import
argparse
import
asyncio
import
json
from
contextlib
import
asynccontextmanager
import
os
import
importlib
import
inspect
import
os
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
prometheus_client
import
make_asgi_app
import
fastapi
import
uvicorn
from
http
import
HTTPStatus
from
fastapi
import
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
,
Response
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
import
vllm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
CompletionRequest
,
ChatCompletionRequest
,
ErrorResponse
from
vllm.logger
import
init_logger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_engine
import
LoRA
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -47,95 +48,8 @@ async def lifespan(app: fastapi.FastAPI):
app
=
fastapi
.
FastAPI
(
lifespan
=
lifespan
)
class
LoRAParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
lora_list
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRA
(
name
,
path
))
setattr
(
namespace
,
self
.
dest
,
lora_list
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
,
help
=
"host name"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"port number"
)
parser
.
add_argument
(
"--allow-credentials"
,
action
=
"store_true"
,
help
=
"allow credentials"
)
parser
.
add_argument
(
"--allowed-origins"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed origins"
)
parser
.
add_argument
(
"--allowed-methods"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed methods"
)
parser
.
add_argument
(
"--allowed-headers"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed headers"
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
None
,
help
=
"If provided, the server will require this key to be presented in the header."
)
parser
.
add_argument
(
"--served-model-name"
,
type
=
str
,
default
=
None
,
help
=
"The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name."
)
parser
.
add_argument
(
"--lora-modules"
,
type
=
str
,
default
=
None
,
nargs
=
'+'
,
action
=
LoRAParserAction
,
help
=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser
.
add_argument
(
"--chat-template"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the chat template, "
"or the template in single-line form "
"for the specified model"
)
parser
.
add_argument
(
"--response-role"
,
type
=
str
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=true`."
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the SSL key file"
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the SSL cert file"
)
parser
.
add_argument
(
"--root-path"
,
type
=
str
,
default
=
None
,
help
=
"FastAPI root_path when app is behind a path based routing proxy"
)
parser
.
add_argument
(
"--middleware"
,
type
=
str
,
action
=
"append"
,
default
=
[],
help
=
"Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
make_arg_parser
()
return
parser
.
parse_args
()
...
...
@@ -153,6 +67,7 @@ async def validation_exception_handler(_, exc):
@
app
.
get
(
"/health"
)
async
def
health
()
->
Response
:
"""Health check."""
await
openai_serving_chat
.
engine
.
check_health
()
return
Response
(
status_code
=
200
)
...
...
@@ -162,6 +77,12 @@ async def show_available_models():
return
JSONResponse
(
content
=
models
.
model_dump
())
@
app
.
get
(
"/version"
)
async
def
show_version
():
ver
=
{
"version"
:
vllm
.
__version__
}
return
JSONResponse
(
content
=
ver
)
@
app
.
post
(
"/v1/chat/completions"
)
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
...
...
@@ -221,19 +142,19 @@ if __name__ == "__main__":
elif
inspect
.
iscoroutinefunction
(
imported
):
app
.
middleware
(
"http"
)(
imported
)
else
:
raise
ValueError
(
f
"Invalid middleware
{
middleware
}
. Must be a function or a class."
)
raise
ValueError
(
f
"Invalid middleware
{
middleware
}
. "
f
"Must be a function or a class."
)
logger
.
info
(
f
"vLLM API server version
{
vllm
.
__version__
}
"
)
logger
.
info
(
f
"args:
{
args
}
"
)
if
args
.
served_model_name
is
not
None
:
served_model
=
args
.
served_model_name
else
:
served_model
=
args
.
model
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
openai_serving_chat
=
OpenAIServingChat
(
engine
,
served_model
,
args
.
response_role
,
args
.
lora_modules
,
...
...
@@ -245,7 +166,9 @@ if __name__ == "__main__":
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
"info"
,
log_level
=
args
.
uvicorn_log_level
,
timeout_keep_alive
=
TIMEOUT_KEEP_ALIVE
,
ssl_keyfile
=
args
.
ssl_keyfile
,
ssl_certfile
=
args
.
ssl_certfile
)
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
)
vllm/entrypoints/openai/cli_args.py
0 → 100644
View file @
7c4f76e3
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import
argparse
import
json
import
ssl
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.serving_engine
import
LoRA
class
LoRAParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
lora_list
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRA
(
name
,
path
))
setattr
(
namespace
,
self
.
dest
,
lora_list
)
def
make_arg_parser
():
parser
=
argparse
.
ArgumentParser
(
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
None
,
help
=
"host name"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"port number"
)
parser
.
add_argument
(
"--uvicorn-log-level"
,
type
=
str
,
default
=
"info"
,
choices
=
[
'debug'
,
'info'
,
'warning'
,
'error'
,
'critical'
,
'trace'
],
help
=
"log level for uvicorn"
)
parser
.
add_argument
(
"--allow-credentials"
,
action
=
"store_true"
,
help
=
"allow credentials"
)
parser
.
add_argument
(
"--allowed-origins"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed origins"
)
parser
.
add_argument
(
"--allowed-methods"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed methods"
)
parser
.
add_argument
(
"--allowed-headers"
,
type
=
json
.
loads
,
default
=
[
"*"
],
help
=
"allowed headers"
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
None
,
help
=
"If provided, the server will require this key "
"to be presented in the header."
)
parser
.
add_argument
(
"--served-model-name"
,
type
=
str
,
default
=
None
,
help
=
"The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name."
)
parser
.
add_argument
(
"--lora-modules"
,
type
=
str
,
default
=
None
,
nargs
=
'+'
,
action
=
LoRAParserAction
,
help
=
"LoRA module configurations in the format name=path. "
"Multiple modules can be specified."
)
parser
.
add_argument
(
"--chat-template"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the chat template, "
"or the template in single-line form "
"for the specified model"
)
parser
.
add_argument
(
"--response-role"
,
type
=
str
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=true`."
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the SSL key file"
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
str
,
default
=
None
,
help
=
"The file path to the SSL cert file"
)
parser
.
add_argument
(
"--ssl-ca-certs"
,
type
=
str
,
default
=
None
,
help
=
"The CA certificates file"
)
parser
.
add_argument
(
"--ssl-cert-reqs"
,
type
=
int
,
default
=
int
(
ssl
.
CERT_NONE
),
help
=
"Whether client certificate is required (see stdlib ssl module's)"
)
parser
.
add_argument
(
"--root-path"
,
type
=
str
,
default
=
None
,
help
=
"FastAPI root_path when app is behind a path based routing proxy"
)
parser
.
add_argument
(
"--middleware"
,
type
=
str
,
action
=
"append"
,
default
=
[],
help
=
"Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). "
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
return
parser
vllm/entrypoints/openai/protocol.py
View file @
7c4f76e3
...
...
@@ -3,12 +3,11 @@
import
time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
from
pydantic
import
BaseModel
,
Field
,
model_validator
from
vllm.utils
import
random_uuid
from
vllm.sampling_params
import
SamplingParams
import
torch
from
vllm.utils
import
random_uuid
class
ErrorResponse
(
BaseModel
):
...
...
@@ -55,40 +54,87 @@ class UsageInfo(BaseModel):
completion_tokens
:
Optional
[
int
]
=
0
class
ResponseFormat
(
BaseModel
):
# type must be "json_object" or "text"
type
:
str
=
Literal
[
"text"
,
"json_object"
]
class
ChatCompletionRequest
(
BaseModel
):
model
:
str
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages
:
List
[
Dict
[
str
,
str
]]
temperature
:
Optional
[
float
]
=
0.7
top_p
:
Optional
[
float
]
=
1.0
n
:
Optional
[
int
]
=
1
model
:
str
frequency_penalty
:
Optional
[
float
]
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
bool
]
=
False
top_logprobs
:
Optional
[
int
]
=
None
max_tokens
:
Optional
[
int
]
=
None
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
logprobs
:
Optional
[
bool
]
=
False
top_logprobs
:
Optional
[
int
]
=
None
presence_penalty
:
Optional
[
float
]
=
0.0
frequency_penalty
:
Optional
[
float
]
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
temperature
:
Optional
[
float
]
=
0.7
top_p
:
Optional
[
float
]
=
1.0
user
:
Optional
[
str
]
=
None
# Additional parameters supported by vLLM
# doc: begin-chat-completion-sampling-params
best_of
:
Optional
[
int
]
=
None
top_k
:
Optional
[
int
]
=
-
1
ignore_eos
:
Optional
[
bool
]
=
False
use_beam_search
:
Optional
[
bool
]
=
False
top_k
:
Optional
[
int
]
=
-
1
min_p
:
Optional
[
float
]
=
0.0
repetition_penalty
:
Optional
[
float
]
=
1.0
length_penalty
:
Optional
[
float
]
=
1.0
early_stopping
:
Optional
[
bool
]
=
False
ignore_eos
:
Optional
[
bool
]
=
False
min_tokens
:
Optional
[
int
]
=
0
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
skip_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
add_generation_prompt
:
Optional
[
bool
]
=
True
echo
:
Optional
[
bool
]
=
False
repetition_penalty
:
Optional
[
float
]
=
1.0
min_p
:
Optional
[
float
]
=
0.0
include_stop_str_in_output
:
Optional
[
bool
]
=
False
length_penalty
:
Optional
[
float
]
=
1.0
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo
:
Optional
[
bool
]
=
Field
(
default
=
False
,
description
=
(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."
),
)
add_generation_prompt
:
Optional
[
bool
]
=
Field
(
default
=
True
,
description
=
(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
include_stop_str_in_output
:
Optional
[
bool
]
=
Field
(
default
=
False
,
description
=
(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."
),
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the JSON schema."
),
)
guided_regex
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the regex pattern."
),
)
guided_choice
:
Optional
[
List
[
str
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will be exactly one of the choices."
),
)
guided_grammar
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the context free grammar."
),
)
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
logprobs
and
not
self
.
top_logprobs
:
...
...
@@ -120,6 +166,7 @@ class ChatCompletionRequest(BaseModel):
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
max_tokens
=
self
.
max_tokens
,
min_tokens
=
self
.
min_tokens
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
best_of
=
self
.
best_of
,
...
...
@@ -150,39 +197,75 @@ class ChatCompletionRequest(BaseModel):
class
CompletionRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model
:
str
# a string, array of strings, array of tokens, or array of token arrays
prompt
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
suffix
:
Optional
[
str
]
=
None
max_tokens
:
Optional
[
int
]
=
16
temperature
:
Optional
[
float
]
=
1.0
top_p
:
Optional
[
float
]
=
1.0
n
:
Optional
[
int
]
=
1
stream
:
Optional
[
bool
]
=
False
logprobs
:
Optional
[
int
]
=
None
best_of
:
Optional
[
int
]
=
None
echo
:
Optional
[
bool
]
=
False
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
seed
:
Optional
[
int
]
=
None
presence_penalty
:
Optional
[
float
]
=
0.0
frequency_penalty
:
Optional
[
float
]
=
0.0
best_of
:
Optional
[
int
]
=
None
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
int
]
=
None
max_tokens
:
Optional
[
int
]
=
16
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
suffix
:
Optional
[
str
]
=
None
temperature
:
Optional
[
float
]
=
1.0
top_p
:
Optional
[
float
]
=
1.0
user
:
Optional
[
str
]
=
None
# Additional parameters supported by vLLM
top_k
:
Optional
[
int
]
=
-
1
ignore_eos
:
Optional
[
bool
]
=
False
# doc: begin-completion-sampling-params
use_beam_search
:
Optional
[
bool
]
=
False
top_k
:
Optional
[
int
]
=
-
1
min_p
:
Optional
[
float
]
=
0.0
repetition_penalty
:
Optional
[
float
]
=
1.0
length_penalty
:
Optional
[
float
]
=
1.0
early_stopping
:
Optional
[
bool
]
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
ignore_eos
:
Optional
[
bool
]
=
False
min_tokens
:
Optional
[
int
]
=
0
skip_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
repetition_penalty
:
Optional
[
float
]
=
1.0
min_p
:
Optional
[
float
]
=
0.0
include_stop_str_in_output
:
Optional
[
bool
]
=
False
length_penalty
:
Optional
[
float
]
=
1.0
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output
:
Optional
[
bool
]
=
Field
(
default
=
False
,
description
=
(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."
),
)
response_format
:
Optional
[
ResponseFormat
]
=
Field
(
default
=
None
,
description
=
(
"Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."
),
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the JSON schema."
),
)
guided_regex
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the regex pattern."
),
)
guided_choice
:
Optional
[
List
[
str
]]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will be exactly one of the choices."
),
)
guided_grammar
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the context free grammar."
),
)
# doc: end-completion-extra-params
def
to_sampling_params
(
self
):
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
...
...
@@ -216,6 +299,7 @@ class CompletionRequest(BaseModel):
stop_token_ids
=
self
.
stop_token_ids
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
logprobs
=
self
.
logprobs
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
...
...
@@ -246,7 +330,7 @@ class LogProbs(BaseModel):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]
=
None
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
str
,
float
]]]]
=
None
class
CompletionResponseChoice
(
BaseModel
):
...
...
@@ -254,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
default
=
None
,
description
=
(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
class
CompletionResponse
(
BaseModel
):
...
...
@@ -270,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
default
=
None
,
description
=
(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
class
CompletionStreamResponse
(
BaseModel
):
...
...
@@ -291,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
message
:
ChatMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
class
ChatCompletionResponse
(
BaseModel
):
...
...
@@ -312,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
delta
:
DeltaMessage
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
class
ChatCompletionStreamResponse
(
BaseModel
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
7c4f76e3
import
time
import
codecs
import
time
from
typing
import
AsyncGenerator
,
AsyncIterator
,
List
,
Optional
,
Union
from
fastapi
import
Request
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Optional
,
List
,
Union
from
vllm.logger
import
init_logger
from
vllm.utils
import
random_uuid
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
LoRA
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
LoRA
from
vllm.model_executor.guided_decoding
import
get_guided_decoding_logits_processor
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -37,8 +40,9 @@ class OpenAIServingChat(OpenAIServing):
ChatCompletionResponse
]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
...
...
@@ -65,7 +69,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request
=
self
.
_maybe_get_lora
(
request
)
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
request
,
self
.
engine
.
get_tokenizer
()))
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
...
...
@@ -82,8 +86,12 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
)
else
:
return
await
self
.
chat_completion_full_generator
(
request
,
raw_request
,
result_generator
,
request_id
)
try
:
return
await
self
.
chat_completion_full_generator
(
request
,
raw_request
,
result_generator
,
request_id
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
if
request
.
add_generation_prompt
:
...
...
@@ -97,119 +105,139 @@ class OpenAIServingChat(OpenAIServing):
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
]]:
model_name
=
request
.
model
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
time
())
chunk_object_type
=
"chat.completion.chunk"
# Send first response for each request.n (index) with the role
role
=
self
.
get_chat_request_role
(
request
)
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
logprobs
=
None
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# Send response to echo the input portion of the last message
if
request
.
echo
:
last_msg_content
=
""
if
request
.
messages
and
isinstance
(
request
.
messages
,
list
)
and
request
.
messages
[
-
1
].
get
(
"content"
)
and
request
.
messages
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
request
.
messages
[
-
1
][
"content"
]
if
last_msg_content
:
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
last_msg_content
),
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
logprobs
=
None
,
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
first_iteration
=
True
# Send response for each token for each request.n (index)
previous_texts
=
[
""
]
*
request
.
n
previous_num_tokens
=
[
0
]
*
request
.
n
finish_reason_sent
=
[
False
]
*
request
.
n
async
for
res
in
result_generator
:
res
:
RequestOutput
for
output
in
res
.
outputs
:
i
=
output
.
index
if
finish_reason_sent
[
i
]:
continue
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
if
output
.
finish_reason
is
None
:
# Send token-by-token response for each request.n
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
else
:
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
previous_num_tokens
[
i
],
total_tokens
=
prompt_tokens
+
previous_num_tokens
[
i
],
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
if
final_usage
is
not
None
:
chunk
.
usage
=
final_usage
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
,
exclude_none
=
True
)
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
try
:
async
for
res
in
result_generator
:
res
:
RequestOutput
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if
first_iteration
:
# Send first response for each request.n (index) with
# the role
role
=
self
.
get_chat_request_role
(
request
)
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
logprobs
=
None
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# Send response to echo the input portion of the
# last message
if
request
.
echo
:
last_msg_content
=
""
if
request
.
messages
and
isinstance
(
request
.
messages
,
list
)
and
request
.
messages
[
-
1
].
get
(
"content"
)
and
request
.
messages
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
request
.
messages
[
-
1
][
"content"
]
if
last_msg_content
:
for
i
in
range
(
request
.
n
):
choice_data
=
(
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
last_msg_content
),
finish_reason
=
None
))
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
logprobs
=
None
,
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
first_iteration
=
False
for
output
in
res
.
outputs
:
i
=
output
.
index
if
finish_reason_sent
[
i
]:
continue
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
:
logprobs
=
self
.
_create_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
if
output
.
finish_reason
is
None
:
# Send token-by-token response for each request.n
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
else
:
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
previous_num_tokens
[
i
],
total_tokens
=
prompt_tokens
+
previous_num_tokens
[
i
],
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
if
final_usage
is
not
None
:
chunk
.
usage
=
final_usage
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
,
exclude_none
=
True
)
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
# Send the final done message after all response.n are finished
yield
"data: [DONE]
\n\n
"
...
...
@@ -219,7 +247,7 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
)
->
Union
[
ErrorResponse
,
ChatCompletionResponse
]:
model_name
=
request
.
model
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
time
())
final_res
:
RequestOutput
=
None
async
for
res
in
result_generator
:
...
...
@@ -251,6 +279,7 @@ class OpenAIServingChat(OpenAIServing):
message
=
ChatMessage
(
role
=
role
,
content
=
output
.
text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
)
choices
.
append
(
choice_data
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
7c4f76e3
import
asyncio
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
)
from
fastapi
import
Request
from
typing
import
AsyncGenerator
,
AsyncIterator
,
Callable
,
List
,
Optional
,
Dict
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.utils
import
random_uuid
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
LogProbs
,
UsageInfo
,
)
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
LogProbs
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
LoRA
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
LoRA
from
vllm.model_executor.guided_decoding
import
get_guided_decoding_logits_processor
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -26,107 +27,6 @@ TypeCreateLogProbsFn = Callable[
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
LogProbs
]
async
def
completion_stream_generator
(
request
:
CompletionRequest
,
raw_request
:
Request
,
on_abort
,
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
create_logprobs_fn
:
TypeCreateLogProbsFn
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
num_prompts
:
int
,
)
->
AsyncGenerator
[
str
,
None
]:
previous_texts
=
[
""
]
*
request
.
n
*
num_prompts
previous_num_tokens
=
[
0
]
*
request
.
n
*
num_prompts
has_echoed
=
[
False
]
*
request
.
n
*
num_prompts
async
for
prompt_idx
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
on_abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
request
.
n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
if
request
.
echo
and
request
.
max_tokens
==
0
:
# only return the prompt
delta_text
=
res
.
prompt
delta_token_ids
=
res
.
prompt_token_ids
top_logprobs
=
res
.
prompt_logprobs
has_echoed
[
i
]
=
True
elif
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]:
# echo the prompt and first token
delta_text
=
res
.
prompt
+
output
.
text
delta_token_ids
=
res
.
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
res
.
prompt_logprobs
+
(
output
.
logprobs
or
[])
has_echoed
[
i
]
=
True
else
:
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
assert
top_logprobs
is
not
None
,
"top_logprobs must be provided when logprobs is requested"
logprobs
=
create_logprobs_fn
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
delta_text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
,
)
]).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
if
output
.
finish_reason
is
not
None
:
# return final usage
logprobs
=
LogProbs
()
if
request
.
logprobs
is
not
None
else
None
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
""
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
],
usage
=
final_usage
,
).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
yield
"data: [DONE]
\n\n
"
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
...
...
@@ -145,79 +45,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
prompt_is_tokens
=
True
prompts
=
prompt
# case 4: array of token arrays
else
:
raise
ValueError
(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
raise
ValueError
(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
return
prompt_is_tokens
,
prompts
def
request_output_to_completion_response
(
final_res_batch
:
List
[
RequestOutput
],
request
:
CompletionRequest
,
create_logprobs_fn
:
TypeCreateLogProbsFn
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
CompletionResponse
:
choices
=
[]
num_prompt_tokens
=
0
num_generated_tokens
=
0
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
for
output
in
final_res
.
outputs
:
if
request
.
echo
and
request
.
max_tokens
==
0
:
token_ids
=
prompt_token_ids
top_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
prompt_logprobs
+
output
.
logprobs
output_text
=
prompt_text
+
output
.
text
else
:
token_ids
=
output
.
token_ids
top_logprobs
=
output
.
logprobs
output_text
=
output
.
text
if
request
.
logprobs
is
not
None
:
logprobs
=
create_logprobs_fn
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
)
else
:
logprobs
=
None
choice_data
=
CompletionResponseChoice
(
index
=
len
(
choices
),
text
=
output_text
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
choices
.
append
(
choice_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
return
CompletionResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
choices
,
usage
=
usage
,
)
def
merge_async_iterators
(
*
iterators
):
"""Merge multiple asynchronous iterators into a single iterator.
...
...
@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators):
finished
=
[
False
]
*
len
(
iterators
)
async
def
producer
(
i
,
iterator
):
async
for
item
in
iterator
:
await
queue
.
put
((
i
,
item
))
try
:
async
for
item
in
iterator
:
await
queue
.
put
((
i
,
item
))
except
Exception
as
e
:
await
queue
.
put
(
e
)
finished
[
i
]
=
True
_tasks
=
[
...
...
@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators):
async
def
consumer
():
while
not
all
(
finished
)
or
not
queue
.
empty
():
item
=
await
queue
.
get
()
if
isinstance
(
item
,
Exception
):
raise
item
yield
item
await
asyncio
.
gather
(
*
_tasks
)
...
...
@@ -280,7 +117,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name
=
request
.
model
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
time
())
# Schedule the request and get the result generator.
generators
=
[]
...
...
@@ -289,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
self
.
_maybe_get_lora
(
request
)
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
request
,
self
.
engine
.
get_tokenizer
()))
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
...
...
@@ -312,40 +149,43 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids
=
input_ids
,
lora_request
=
lora_request
))
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]]
=
merge_async_iterators
(
*
generators
)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
# results. In addition, we do not stream the results when use
# beam search.
stream
=
(
request
.
stream
and
(
request
.
best_of
is
None
or
request
.
n
==
request
.
best_of
)
and
not
request
.
use_beam_search
)
# Streaming response
if
stream
:
return
completion_stream_generator
(
request
,
raw_request
,
self
.
engine
.
abort
,
result_generator
,
self
.
_create_logprobs
,
request_id
,
created_time
,
model_name
,
num_prompts
=
len
(
prompts
))
return
self
.
completion_stream_generator
(
request
,
raw_request
,
result_generator
,
request_id
,
created_time
,
model_name
,
num_prompts
=
len
(
prompts
))
# Non-streaming response
final_res_batch
:
RequestOutput
=
[
None
]
*
len
(
prompts
)
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
request_output_to_completion_response
(
final_res_batch
,
request
,
self
.
_create_logprobs
,
request_id
,
created_time
,
model_name
)
try
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
self
.
request_output_to_completion_response
(
final_res_batch
,
request
,
request_id
,
created_time
,
model_name
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
...
...
@@ -359,3 +199,166 @@ class OpenAIServingCompletion(OpenAIServing):
return
fake_stream_generator
()
return
response
async
def
completion_stream_generator
(
self
,
request
:
CompletionRequest
,
raw_request
:
Request
,
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
num_prompts
:
int
,
)
->
AsyncGenerator
[
str
,
None
]:
previous_texts
=
[
""
]
*
request
.
n
*
num_prompts
previous_num_tokens
=
[
0
]
*
request
.
n
*
num_prompts
has_echoed
=
[
False
]
*
request
.
n
*
num_prompts
try
:
async
for
prompt_idx
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
request
.
n
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
if
request
.
echo
and
request
.
max_tokens
==
0
:
# only return the prompt
delta_text
=
res
.
prompt
delta_token_ids
=
res
.
prompt_token_ids
top_logprobs
=
res
.
prompt_logprobs
has_echoed
[
i
]
=
True
elif
(
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]):
# echo the prompt and first token
delta_text
=
res
.
prompt
+
output
.
text
delta_token_ids
=
(
res
.
prompt_token_ids
+
output
.
token_ids
)
top_logprobs
=
res
.
prompt_logprobs
+
(
output
.
logprobs
or
[])
has_echoed
[
i
]
=
True
else
:
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
logprobs
=
self
.
_create_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
stop_reason
=
output
.
stop_reason
if
output
.
finish_reason
is
not
None
:
# return final usage
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
else
:
final_usage
=
None
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
delta_text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
,
stop_reason
=
stop_reason
,
)
],
usage
=
final_usage
,
).
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
response_json
}
\n\n
"
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
yield
"data: [DONE]
\n\n
"
def
request_output_to_completion_response
(
self
,
final_res_batch
:
List
[
RequestOutput
],
request
:
CompletionRequest
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
CompletionResponse
:
choices
=
[]
num_prompt_tokens
=
0
num_generated_tokens
=
0
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
for
output
in
final_res
.
outputs
:
if
request
.
echo
and
request
.
max_tokens
==
0
:
token_ids
=
prompt_token_ids
top_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
prompt_logprobs
+
output
.
logprobs
output_text
=
prompt_text
+
output
.
text
else
:
token_ids
=
output
.
token_ids
top_logprobs
=
output
.
logprobs
output_text
=
output
.
text
if
request
.
logprobs
is
not
None
:
logprobs
=
self
.
_create_logprobs
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
)
else
:
logprobs
=
None
choice_data
=
CompletionResponseChoice
(
index
=
len
(
choices
),
text
=
output_text
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
stop_reason
=
output
.
stop_reason
,
)
choices
.
append
(
choice_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
return
CompletionResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
choices
,
usage
=
usage
,
)
vllm/entrypoints/openai/serving_engine.py
View file @
7c4f76e3
import
asyncio
import
json
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
,
Optional
,
Union
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
ChatCompletionRequest
,
ErrorResponse
,
LogProbs
,
ModelCard
,
ModelList
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
,
LogProbs
,
ModelCard
,
ModelList
,
ModelPermission
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -48,10 +50,12 @@ class OpenAIServing:
except
RuntimeError
:
event_loop
=
None
if
event_loop
is
not
None
and
event_loop
.
is_running
(
):
# If the current is instanced by Ray Serve, there is already a running event loop
if
event_loop
is
not
None
and
event_loop
.
is_running
():
# If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop
.
create_task
(
self
.
_post_init
())
else
:
# When using single vLLM without engine_use_ray
else
:
# When using single vLLM without engine_use_ray
asyncio
.
run
(
self
.
_post_init
())
async
def
_post_init
(
self
):
...
...
@@ -83,7 +87,7 @@ class OpenAIServing:
def
_create_logprobs
(
self
,
token_ids
:
List
[
int
],
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]
=
None
,
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
,
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
initial_text_offset
:
int
=
0
,
)
->
LogProbs
:
...
...
@@ -95,10 +99,10 @@ class OpenAIServing:
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
not
None
:
token_logprob
=
step_top_logprobs
[
token_id
]
token_logprob
=
step_top_logprobs
[
token_id
]
.
logprob
else
:
token_logprob
=
None
token
=
s
elf
.
tokenizer
.
convert_ids_to_tokens
(
token_id
)
token
=
s
tep_top_logprobs
[
token_id
].
decoded_token
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
if
len
(
logprobs
.
text_offset
)
==
0
:
...
...
@@ -110,7 +114,7 @@ class OpenAIServing:
if
num_output_top_logprobs
:
logprobs
.
top_logprobs
.
append
({
self
.
tokenizer
.
convert_ids_to_tokens
(
i
):
p
p
.
decoded_token
:
p
.
logprob
for
i
,
p
in
step_top_logprobs
.
items
()
}
if
step_top_logprobs
else
None
)
return
logprobs
...
...
@@ -124,6 +128,19 @@ class OpenAIServing:
type
=
err_type
,
code
=
status_code
.
value
)
def
create_streaming_error_response
(
self
,
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
)
->
str
:
json_str
=
json
.
dumps
({
"error"
:
self
.
create_error_response
(
message
=
message
,
err_type
=
err_type
,
status_code
=
status_code
).
model_dump
()
})
return
json_str
async
def
_check_model
(
self
,
request
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
==
self
.
served_model
:
return
...
...
@@ -163,8 +180,9 @@ class OpenAIServing:
if
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
raise
ValueError
(
f
"This model's maximum context length is
{
self
.
max_model_len
}
tokens. "
f
"However, you requested
{
request
.
max_tokens
+
token_num
}
tokens "
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
request
.
max_tokens
+
token_num
}
tokens "
f
"(
{
token_num
}
in the messages, "
f
"
{
request
.
max_tokens
}
in the completion). "
f
"Please reduce the length of the messages or completion."
,
)
...
...
vllm/executor/__init__.py
0 → 100644
View file @
7c4f76e3
vllm/executor/executor_base.py
0 → 100644
View file @
7c4f76e3
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
class
ExecutorBase
(
ABC
):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
that can execute the model on multiple devices.
"""
@
abstractmethod
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
@
abstractmethod
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
@
abstractmethod
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
@
abstractmethod
def
list_loras
(
self
)
->
List
[
int
]:
raise
NotImplementedError
@
abstractmethod
def
check_health
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise
NotImplementedError
class
ExecutorAsyncBase
(
ExecutorBase
):
@
abstractmethod
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
@
abstractmethod
async
def
check_health_async
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise
NotImplementedError
vllm/executor/gpu_executor.py
0 → 100644
View file @
7c4f76e3
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
logger
=
init_logger
(
__name__
)
class
GPUExecutor
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
# Instantiate the worker and load the model to GPU.
self
.
_init_worker
()
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
def
_init_worker
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from
vllm.worker.worker
import
Worker
assert
self
.
parallel_config
.
world_size
==
1
,
(
"GPUExecutor only supports single GPU."
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it allocates the remaining memory for KV blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_gpu_blocks
,
num_cpu_blocks
=
(
self
.
driver_worker
.
profile_num_available_blocks
(
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
))
if
self
.
cache_config
.
forced_num_gpu_blocks
is
not
None
:
forced_num_gpu_blocks
=
self
.
cache_config
.
forced_num_gpu_blocks
logger
.
info
(
f
"Replacing profiled
{
num_gpu_blocks
=
}
with "
f
"
{
forced_num_gpu_blocks
=
}
"
)
num_gpu_blocks
=
forced_num_gpu_blocks
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
check_block_size_valid
(
num_gpu_blocks
,
self
.
cache_config
.
block_size
,
self
.
model_config
.
max_model_len
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
driver_worker
.
init_cache_engine
(
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
driver_worker
.
warm_up_model
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
driver_worker
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
driver_worker
.
list_loras
()
def
check_health
(
self
)
->
None
:
# GPUExecutor will always be healthy as long as
# it's running.
return
class
GPUExecutorAsync
(
GPUExecutor
,
ExecutorAsyncBase
):
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
)
return
output
async
def
check_health_async
(
self
)
->
None
:
# GPUExecutor will always be healthy as long as
# it's running.
return
vllm/executor/neuron_executor.py
0 → 100644
View file @
7c4f76e3
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
logger
=
init_logger
(
__name__
)
class
NeuronExecutor
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
assert
lora_config
is
None
,
"LoRA is not supported for Neuron backend."
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
self
.
cache_config
.
num_gpu_blocks
=
self
.
scheduler_config
.
max_num_seqs
self
.
cache_config
.
num_cpu_blocks
=
0
# Instantiate the worker and load the model to the device.
self
.
_init_worker
()
def
_init_worker
(
self
):
from
vllm.worker.neuron_worker
import
NeuronWorker
self
.
driver_worker
=
NeuronWorker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
and
blocks_to_copy
==
{}),
(
"Cache operations are not supported for Neuron backend."
)
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
(
"LoRA is not implemented for neuron backend."
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"LoRA is not implemented for neuron backend."
)
def
list_loras
(
self
)
->
List
[
int
]:
raise
NotImplementedError
(
"LoRA is not implemented for neuron backend."
)
def
check_health
(
self
)
->
None
:
# NeuronExecutor will always be healthy as long as
# it's running.
return
vllm/executor/ray_gpu_executor.py
0 → 100644
View file @
7c4f76e3
import
asyncio
import
copy
import
os
import
pickle
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
ray
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
,
set_cuda_visible_devices
)
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
))
class
RayGPUExecutor
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
assert
self
.
parallel_config
.
worker_use_ray
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus
=
self
.
cache_config
.
gpu_memory_utilization
else
:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus
=
1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self
.
driver_dummy_worker
:
RayWorkerVllm
=
None
# The remaining workers are the actual ray actors.
self
.
workers
:
List
[
RayWorkerVllm
]
=
[]
# Create the workers.
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerVllm
).
remote
(
self
.
model_config
.
trust_remote_code
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
# Get the set of GPU IDs used on each node.
driver_node_id
,
driver_gpu_ids
=
ray
.
get
(
self
.
driver_dummy_worker
.
get_node_and_gpu_ids
.
remote
())
worker_node_and_gpu_ids
=
ray
.
get
(
[
worker
.
get_node_and_gpu_ids
.
remote
()
for
worker
in
self
.
workers
])
node_workers
=
defaultdict
(
list
)
node_gpus
=
defaultdict
(
list
)
node_workers
[
driver_node_id
].
append
(
0
)
node_gpus
[
driver_node_id
].
extend
(
driver_gpu_ids
)
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
,
start
=
1
):
node_workers
[
node_id
].
append
(
i
)
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices
(
node_gpus
[
driver_node_id
])
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from
vllm.worker.worker
import
Worker
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
lora_config
=
copy
.
deepcopy
(
self
.
lora_config
)
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
# Initialize the actual workers with the Worker class.
for
rank
,
(
worker
,
(
node_id
,
_
))
in
enumerate
(
zip
(
self
.
workers
,
worker_node_and_gpu_ids
),
start
=
1
,
):
local_rank
=
node_workers
[
node_id
].
index
(
rank
)
worker
.
init_worker
.
remote
(
lambda
rank
=
rank
,
local_rank
=
local_rank
:
Worker
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
local_rank
,
rank
,
distributed_init_method
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
))
# Initialize the driver worker with the Worker class.
driver_rank
=
0
driver_local_rank
=
node_workers
[
driver_node_id
].
index
(
driver_rank
)
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
driver_local_rank
,
driver_rank
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
True
,
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
,
)
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
"profile_num_available_blocks"
,
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
if
self
.
cache_config
.
forced_num_gpu_blocks
is
not
None
:
forced_num_gpu_blocks
=
self
.
cache_config
.
forced_num_gpu_blocks
logger
.
info
(
f
"Replacing profiled
{
num_gpu_blocks
=
}
with "
f
"
{
forced_num_gpu_blocks
=
}
"
)
num_gpu_blocks
=
forced_num_gpu_blocks
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
check_block_size_valid
(
num_gpu_blocks
,
self
.
cache_config
.
block_size
,
self
.
model_config
.
max_model_len
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
_run_workers
(
"init_cache_engine"
,
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
_run_workers
(
"warm_up_model"
)
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"add_lora"
,
lora_request
=
lora_request
,
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"remove_lora"
,
lora_id
=
lora_id
,
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels
=
self
.
forward_dag
.
execute
(
1
)
else
:
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Start the driver worker after all the ray workers.
driver_worker_output
=
getattr
(
self
.
driver_worker
,
method
)(
*
driver_args
,
**
driver_kwargs
)
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_compiled_ray_dag
(
self
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
bind
(
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
def
_check_if_any_actor_is_dead
(
self
):
if
not
self
.
workers
:
return
dead_actors
=
[]
for
actor
in
self
.
workers
:
actor_state
=
ray
.
state
.
actors
(
actor
.
_ray_actor_id
.
hex
())
# pylint: disable=protected-access
if
actor_state
[
"State"
]
==
"DEAD"
:
dead_actors
.
append
(
actor
)
if
dead_actors
:
raise
RuntimeError
(
"At least one Worker is dead. "
f
"Dead Workers:
{
dead_actors
}
. "
)
class
RayGPUExecutorAsync
(
RayGPUExecutor
,
ExecutorAsyncBase
):
async
def
_run_workers_async
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
coros
=
[]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Run the driver worker asynchronously.
driver_executor
=
make_async
(
getattr
(
self
.
driver_worker
,
method
))
coros
.
append
(
driver_executor
(
*
driver_args
,
**
driver_kwargs
))
# Run the ray workers asynchronously.
for
worker
in
self
.
workers
:
coros
.
append
(
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
))
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
return
all_outputs
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
all_outputs
=
await
self
.
_run_workers_async
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
})
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
return
output
async
def
check_health_async
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
vllm/executor/utils.py
0 → 100644
View file @
7c4f76e3
def
check_block_size_valid
(
num_gpu_blocks
,
block_size
,
max_model_len
)
->
None
:
if
num_gpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_seq_len
=
block_size
*
num_gpu_blocks
if
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine."
)
vllm/logger.py
View file @
7c4f76e3
...
...
@@ -2,8 +2,8 @@
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import
logging
import
sys
import
os
import
sys
VLLM_CONFIGURE_LOGGING
=
int
(
os
.
getenv
(
"VLLM_CONFIGURE_LOGGING"
,
"1"
))
...
...
vllm/lora/layers.py
View file @
7c4f76e3
# pylint: disable=unused-argument
import
inspect
import
math
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tupl
e
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Typ
e
import
torch
import
torch.nn
as
nn
...
...
@@ -10,20 +11,20 @@ from transformers import PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
,
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
Row
ParallelLinear
,
MergedColumn
ParallelLinear
,
QKVParallelLinear
,
MergedColumnParallelLinear
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
,
ParallelLMHead
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.parallel_utils.utils
import
split_tensor_along_last_dim
from
vllm.model_executor.parallel_utils.utils
import
(
split_tensor_along_last_dim
)
if
TYPE_CHECKING
:
pass
...
...
@@ -84,7 +85,8 @@ def _apply_lora_packed_nslice(
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...), where n is number of slices
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output
=
output
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
...
...
@@ -113,8 +115,11 @@ class LoRAMapping:
class
BaseLayerWithLoRA
(
nn
.
Module
):
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
PretrainedConfig
)
->
None
:
def
create_lora_weights
(
self
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
"""Initializes lora matrices."""
...
...
...
@@ -143,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
"""Sets the mapping indices."""
...
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise
NotImplementedError
class
VocabParallelEmbeddingWithLoRA
(
BaseLayerWithLoRA
):
...
...
@@ -277,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
indices
[:
self
.
indices_len
[
0
]],
0
,
1.0
)
return
full_output
.
view_as
(
full_output_org
)
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
VocabParallelEmbedding
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
ColumnParallelLinear
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def
create_lora_weights
(
self
,
...
...
@@ -308,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
1
]
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
2
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
...
...
@@ -322,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
output_dim
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
...
...
@@ -382,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def
linear_weights
(
self
):
return
self
.
base_layer
.
linear_weights
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
ColumnParallelLinear
or
(
type
(
source_layer
)
is
MergedColumnParallelLinear
and
len
(
packed_modules_list
)
==
1
)
class
MergedColumnParallelLinearWithLoRA
(
ColumnParallelLinearWithLoRA
):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
...
...
@@ -484,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
return
output
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
MergedColumnParallelLinear
and
len
(
packed_modules_list
)
==
2
class
QKVParallelLinearWithLora
(
ColumnParallelLinearWithLoRA
):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def
__init__
(
self
,
base_layer
:
QKVParallelLinear
)
->
None
:
super
().
__init__
(
base_layer
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
q_proj_total_size
=
(
self
.
base_layer
.
total_num_heads
*
self
.
base_layer
.
head_size
)
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
base_layer
.
head_size
)
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
self
.
base_layer
.
head_size
)
self
.
kv_proj_total_size
=
(
self
.
base_layer
.
total_num_kv_heads
*
self
.
base_layer
.
head_size
)
def
set_lora
(
self
,
index
:
int
,
lora_a
:
torch
.
Tensor
,
lora_b
:
torch
.
Tensor
,
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
q_shard_id
=
tp_rank
self
.
kv_shard_id
=
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
lora_b_q
=
lora_b
[:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
k_offset
=
self
.
q_proj_total_size
lora_b_k
=
lora_b
[:,
k_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
k_offset
+
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
v_offset
=
k_offset
+
self
.
kv_proj_total_size
lora_b_v
=
lora_b
[:,
v_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
v_offset
+
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
lora_b
=
torch
.
cat
([
lora_b_q
,
lora_b_k
,
lora_b_v
],
dim
=
1
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
1
class
MergedQKVParallelLinearWithLora
(
ColumnParallelLinearWithLoRA
):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
...
...
@@ -653,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
)
return
output
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
QKVParallelLinear
and
len
(
packed_modules_list
)
==
3
class
RowParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
...
...
@@ -779,12 +890,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
weight
(
self
):
return
self
.
base_layer
.
weight
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
return
type
(
source_layer
)
is
RowParallelLinear
class
SamplerWithLoRA
(
BaseLayerWithLoRA
):
class
LogitsProcessorWithLoRA
(
BaseLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
Sample
r
,
base_layer
:
LogitsProcesso
r
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
...
...
@@ -796,13 +913,17 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
self
.
device
=
device
@
property
def
logits_as_
hidden_states
(
self
):
return
self
.
base_layer
.
logits_as_
hidden_states
def
logits_as_
input
(
self
):
return
self
.
base_layer
.
logits_as_
input
@
property
def
vocab_size
(
self
):
return
self
.
base_layer
.
vocab_size
@
property
def
scale
(
self
):
return
self
.
base_layer
.
scale
@
property
def
org_vocab_size
(
self
):
return
self
.
base_layer
.
org_vocab_size
...
...
@@ -819,9 +940,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
)
->
None
:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if
32000
<
self
.
base_layer
.
vocab_size
>
33024
:
raise
ValueError
(
"When using LoRA, vocab size must be 32000 >= vocab_size <= 33024"
)
raise
ValueError
(
"When using LoRA, vocab size must be "
"32000 >= vocab_size <= 33024"
)
self
.
lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
...
...
@@ -896,7 +1016,7 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
hidden_states
:
torch
.
Tensor
,
embedding
:
torch
.
Tensor
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Optional
[
torch
.
Tensor
]
:
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
if
embedding_bias
is
not
None
:
...
...
@@ -945,35 +1065,43 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
def
forward
(
self
,
*
args
,
**
kwargs
):
return
type
(
self
.
base_layer
).
forward
(
self
,
*
args
,
**
kwargs
)
def
from_layer
(
layer
:
nn
.
Module
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
BaseLayerWithLoRA
:
supported_layer_types
=
{
VocabParallelEmbedding
:
VocabParallelEmbeddingWithLoRA
,
ColumnParallelLinear
:
ColumnParallelLinearWithLoRA
,
QKVParallelLinear
:
QKVParallelLinearWithLora
,
MergedColumnParallelLinear
:
MergedColumnParallelLinearWithLoRA
,
RowParallelLinear
:
RowParallelLinearWithLoRA
,
}
for
src_layer_type
,
lora_layer_type
in
supported_layer_types
.
items
():
if
type
(
layer
)
is
src_layer_type
:
# pylint: disable=unidiomatic-typecheck
ret
=
lora_layer_type
(
layer
)
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
# Special handling for the LogitsProcessor.
return
False
_all_lora_classes
:
Set
[
Type
[
BaseLayerWithLoRA
]]
=
{
cls
for
cls
in
globals
().
values
()
if
inspect
.
isclass
(
cls
)
and
issubclass
(
cls
,
BaseLayerWithLoRA
)
and
cls
is
not
BaseLayerWithLoRA
}
def
from_layer
(
layer
:
nn
.
Module
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
nn
.
Module
:
for
lora_cls
in
_all_lora_classes
:
if
lora_cls
.
can_replace_layer
(
layer
,
lora_config
,
packed_modules_list
,
model_config
):
ret
=
lora_cls
(
layer
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
return
layer
def
from_layer_
sample
r
(
layer
:
Sample
r
,
def
from_layer_
logits_processo
r
(
layer
:
LogitsProcesso
r
,
lm_head
:
ParallelLMHead
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
Sample
rWithLoRA
:
ret
=
Sample
rWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
)
->
LogitsProcesso
rWithLoRA
:
ret
=
LogitsProcesso
rWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
vllm/lora/lora.py
View file @
7c4f76e3
from
typing
import
List
,
Optional
import
torch
from
vllm.utils
import
in_wsl
from
vllm.utils
import
is_pin_memory_available
class
LoRALayerWeights
:
...
...
@@ -64,7 +65,7 @@ class LoRALayerWeights:
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
embeddings_tensor_dim
:
Optional
[
int
]
=
None
)
->
"LoRALayerWeights"
:
pin_memory
=
str
(
device
)
==
"cpu"
and
not
in_wsl
()
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
lora_a
=
torch
.
zeros
([
input_dim
,
rank
],
dtype
=
dtype
,
device
=
device
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment