Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c922709
Unverified
Commit
4c922709
authored
Mar 11, 2024
by
Zhuohan Li
Committed by
GitHub
Mar 11, 2024
Browse files
Add distributed model executor abstraction (#3191)
parent
657061fd
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
818 additions
and
509 deletions
+818
-509
docs/source/dev/engine/llm_engine.rst
docs/source/dev/engine/llm_engine.rst
+1
-1
format.sh
format.sh
+6
-2
tests/lora/conftest.py
tests/lora/conftest.py
+2
-1
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/config.py
vllm/config.py
+6
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+38
-68
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+44
-402
vllm/engine/ray_utils.py
vllm/engine/ray_utils.py
+26
-32
vllm/executor/__init__.py
vllm/executor/__init__.py
+0
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+75
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+163
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+442
-0
vllm/executor/utils.py
vllm/executor/utils.py
+13
-0
No files found.
docs/source/dev/engine/llm_engine.rst
View file @
4c922709
...
...
@@ -2,5 +2,5 @@ LLMEngine
=================================
.. autoclass:: vllm.engine.llm_engine.LLMEngine
:members: add_request, abort_request, step
, _init_cache
:members: add_request, abort_request, step
:show-inheritance:
\ No newline at end of file
format.sh
View file @
4c922709
...
...
@@ -95,13 +95,17 @@ echo 'vLLM yapf: Done'
# echo 'vLLM mypy:'
# mypy
CODESPELL_EXCLUDES
=(
'--skip'
'*docs/source/_build/**'
)
# check spelling of specified files
spell_check
()
{
codespell
"
$@
"
}
spell_check_all
(){
codespell
--toml
pyproject.toml
codespell
--toml
pyproject.toml
"
${
CODESPELL_EXCLUDES
[@]
}
"
}
# Spelling check of files that differ from main branch.
...
...
@@ -116,7 +120,7 @@ spell_check_changed() {
if
!
git diff
--diff-filter
=
ACM
--quiet
--exit-code
"
$MERGEBASE
"
--
'*.py'
'*.pyi'
&>/dev/null
;
then
git diff
--name-only
--diff-filter
=
ACM
"
$MERGEBASE
"
--
'*.py'
'*.pyi'
| xargs
\
codespell
codespell
"
${
CODESPELL_EXCLUDES
[@]
}
"
fi
}
...
...
tests/lora/conftest.py
View file @
4c922709
...
...
@@ -152,4 +152,5 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@
pytest
.
fixture
def
llama_2_7b_model_extra_embeddings
(
llama_2_7b_engine_extra_embeddings
)
->
nn
.
Module
:
yield
llama_2_7b_engine_extra_embeddings
.
driver_worker
.
model_runner
.
model
yield
(
llama_2_7b_engine_extra_embeddings
.
model_executor
.
driver_worker
.
model_runner
.
model
)
vllm/__init__.py
View file @
4c922709
...
...
@@ -3,7 +3,7 @@
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.ray_utils
import
initialize_cluster
from
vllm.engine.ray_utils
import
initialize_
ray_
cluster
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -19,5 +19,5 @@ __all__ = [
"EngineArgs"
,
"AsyncLLMEngine"
,
"AsyncEngineArgs"
,
"initialize_cluster"
,
"initialize_
ray_
cluster"
,
]
vllm/config.py
View file @
4c922709
from
typing
import
Optional
,
Union
,
ClassVar
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
,
ClassVar
from
dataclasses
import
dataclass
import
os
from
packaging.version
import
Version
...
...
@@ -10,6 +10,9 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.config
import
get_config
from
vllm.utils
import
get_cpu_memory
,
is_hip
,
is_neuron
,
get_nvcc_cuda_version
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
_GB
=
1
<<
30
...
...
@@ -397,6 +400,7 @@ class ParallelConfig:
max_parallel_loading_workers
:
Optional
[
int
]
=
None
,
disable_custom_all_reduce
:
bool
=
False
,
ray_workers_use_nsight
:
bool
=
False
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
)
->
None
:
self
.
pipeline_parallel_size
=
pipeline_parallel_size
if
is_neuron
():
...
...
@@ -412,6 +416,7 @@ class ParallelConfig:
self
.
max_parallel_loading_workers
=
max_parallel_loading_workers
self
.
disable_custom_all_reduce
=
disable_custom_all_reduce
self
.
ray_workers_use_nsight
=
ray_workers_use_nsight
self
.
placement_group
=
placement_group
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
# Ray worker is not supported for Neuron backend.
...
...
vllm/engine/async_llm_engine.py
View file @
4c922709
...
...
@@ -2,8 +2,8 @@ import asyncio
import
os
import
time
from
functools
import
partial
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
AsyncIterator
,
Callable
)
from
typing
import
(
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
AsyncIterator
)
from
transformers
import
PreTrainedTokenizer
...
...
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.ray_utils
import
initialize_cluster
,
ray
from
vllm.engine.ray_utils
import
initialize_
ray_
cluster
,
ray
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -208,17 +208,10 @@ class _AsyncLLMEngine(LLMEngine):
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
all_outputs
=
await
self
.
_run_workers_async
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
scheduler_outputs
.
blocks_to_swap_in
,
"blocks_to_swap_out"
:
scheduler_outputs
.
blocks_to_swap_out
,
"blocks_to_copy"
:
scheduler_outputs
.
blocks_to_copy
,
})
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
output
=
await
self
.
model_executor
.
execute_model_async
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
)
else
:
output
=
[]
...
...
@@ -268,37 +261,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request
=
lora_request
,
)
async
def
_run_workers_async
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
coros
=
[]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Run the driver worker asynchronously.
driver_executor
=
getattr
(
self
.
driver_worker
,
method
)
coros
.
append
(
asyncio
.
get_event_loop
().
run_in_executor
(
None
,
partial
(
driver_executor
,
*
driver_args
,
**
driver_kwargs
)))
# Run the ray workers asynchronously.
for
worker
in
self
.
workers
:
coros
.
append
(
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
))
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
return
all_outputs
async
def
check_health_async
(
self
):
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
async
def
check_health_async
(
self
)
->
None
:
self
.
model_executor
.
check_health
()
class
AsyncLLMEngine
:
...
...
@@ -353,6 +317,34 @@ class AsyncLLMEngine:
self
.
_request_tracker
:
Optional
[
RequestTracker
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
if
parallel_config
.
worker_use_ray
or
engine_args
.
engine_use_ray
:
initialize_ray_cluster
(
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
else
:
assert
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
# Create the async LLM engine.
engine
=
cls
(
parallel_config
.
worker_use_ray
,
engine_args
.
engine_use_ray
,
*
engine_configs
,
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
max_log_len
=
engine_args
.
max_log_len
,
start_engine_loop
=
start_engine_loop
)
return
engine
@
property
def
is_running
(
self
)
->
bool
:
return
(
self
.
background_loop
is
not
None
...
...
@@ -670,35 +662,13 @@ class AsyncLLMEngine:
else
:
return
self
.
engine
.
get_model_config
()
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
# Initialize the cluster.
placement_group
=
initialize_cluster
(
parallel_config
,
engine_args
.
engine_use_ray
)
# Create the async LLM engine.
engine
=
cls
(
parallel_config
.
worker_use_ray
,
engine_args
.
engine_use_ray
,
*
engine_configs
,
placement_group
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
max_log_len
=
engine_args
.
max_log_len
,
start_engine_loop
=
start_engine_loop
)
return
engine
async
def
do_log_stats
(
self
)
->
None
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
do_log_stats
.
remote
()
else
:
self
.
engine
.
do_log_stats
()
async
def
check_health
(
self
):
async
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
t
=
time
.
perf_counter
()
logger
.
debug
(
"Starting health check..."
)
...
...
vllm/engine/llm_engine.py
View file @
4c922709
import
copy
from
collections
import
defaultdict
import
os
import
time
import
pickle
import
importlib
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
)
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
transformers
import
PreTrainedTokenizer
...
...
@@ -15,8 +9,9 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
initialize_cluster
,
ray
from
vllm.engine.ray_utils
import
initialize_
ray_
cluster
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -24,29 +19,11 @@ from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
TokenizerGroup
)
from
vllm.utils
import
(
Counter
,
set_cuda_visible_devices
,
get_ip
,
get_open_port
,
get_distributed_init_method
)
if
ray
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP
=
{
"cuda"
:
"vllm.worker.worker"
,
"neuron"
:
"vllm.worker.neuron_worker"
,
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
))
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -71,8 +48,8 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
placement_group: Ray placement group for distributed execution.
Required for distributed
execution.
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
"""
...
...
@@ -84,7 +61,7 @@ class LLMEngine:
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
placement_group
:
Optional
[
"PlacementGroup"
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
)
->
None
:
logger
.
info
(
...
...
@@ -121,33 +98,13 @@ class LLMEngine:
self
.
_init_tokenizer
()
self
.
seq_counter
=
Counter
()
# Create the parallel GPU workers.
if
self
.
parallel_config
.
worker_use_ray
:
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
# Pass additional arguments to initialize the worker
additional_ray_args
=
{}
if
self
.
parallel_config
.
ray_workers_use_nsight
:
logger
.
info
(
"Configuring Ray workers to use nsight."
)
additional_ray_args
=
{
"runtime_env"
:
{
"nsight"
:
{
"t"
:
"cuda,cudnn,cublas"
,
"o"
:
"'worker_process_%p'"
,
"cuda-graph-trace"
:
"node"
,
}
}
}
self
.
_init_workers_ray
(
placement_group
,
**
additional_ray_args
)
else
:
self
.
_init_workers
()
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
self
.
model_executor
=
executor_class
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
)
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
# Metric Logging.
...
...
@@ -157,9 +114,29 @@ class LLMEngine:
labels
=
dict
(
model_name
=
model_config
.
model
))
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
EngineArgs
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
# Initialize the cluster and specify the executor class.
if
parallel_config
.
worker_use_ray
:
initialize_ray_cluster
(
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
executor_class
=
RayGPUExecutor
else
:
assert
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
from
vllm.executor.gpu_executor
import
GPUExecutor
executor_class
=
GPUExecutor
# Create the LLM engine.
engine
=
cls
(
*
engine_configs
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
)
return
engine
def
__reduce__
(
self
):
# This is to ensure that the LLMEngine is not referenced in
...
...
@@ -173,39 +150,6 @@ class LLMEngine:
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
_dispatch_worker
(
self
):
worker_module
=
DEVICE_TO_WORKER_MODULE_MAP
[
self
.
device_config
.
device_type
]
imported_worker
=
importlib
.
import_module
(
worker_module
)
Worker
=
imported_worker
.
Worker
return
Worker
def
_init_workers
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
assert
self
.
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
self
.
workers
:
List
[
Worker
]
=
[]
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
)
self
.
_run_workers
(
"init_model"
)
self
.
_run_workers
(
"load_model"
)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
):
init_kwargs
=
dict
(
enable_lora
=
bool
(
self
.
lora_config
),
...
...
@@ -218,126 +162,6 @@ class LLMEngine:
self
.
tokenizer
:
TokenizerGroup
=
TokenizerGroup
(
self
.
model_config
.
tokenizer
,
**
init_kwargs
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
num_gpus
=
self
.
cache_config
.
gpu_memory_utilization
else
:
num_gpus
=
1
self
.
driver_dummy_worker
:
RayWorkerVllm
=
None
self
.
workers
:
List
[
RayWorkerVllm
]
=
[]
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerVllm
).
remote
(
self
.
model_config
.
trust_remote_code
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
else
:
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
driver_node_id
,
driver_gpu_ids
=
ray
.
get
(
self
.
driver_dummy_worker
.
get_node_and_gpu_ids
.
remote
())
worker_node_and_gpu_ids
=
ray
.
get
(
[
worker
.
get_node_and_gpu_ids
.
remote
()
for
worker
in
self
.
workers
])
node_workers
=
defaultdict
(
list
)
node_gpus
=
defaultdict
(
list
)
node_workers
[
driver_node_id
].
append
(
0
)
node_gpus
[
driver_node_id
].
extend
(
driver_gpu_ids
)
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
,
start
=
1
):
node_workers
[
node_id
].
append
(
i
)
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
# Set CUDA_VISIBLE_DEVICES for the driver.
set_cuda_visible_devices
(
node_gpus
[
driver_node_id
])
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
# Initialize torch distributed process group for the workers.
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
lora_config
=
copy
.
deepcopy
(
self
.
lora_config
)
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
for
rank
,
(
worker
,
(
node_id
,
_
))
in
enumerate
(
zip
(
self
.
workers
,
worker_node_and_gpu_ids
),
start
=
1
):
local_rank
=
node_workers
[
node_id
].
index
(
rank
)
worker
.
init_worker
.
remote
(
lambda
rank
=
rank
,
local_rank
=
local_rank
:
Worker
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
local_rank
,
rank
,
distributed_init_method
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
))
driver_rank
=
0
driver_local_rank
=
node_workers
[
driver_node_id
].
index
(
driver_rank
)
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
driver_local_rank
,
driver_rank
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
True
,
)
# don't use cupy for eager mode
self
.
_run_workers
(
"init_model"
,
cupy_port
=
get_open_port
()
if
not
model_config
.
enforce_eager
else
None
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
,
)
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
@@ -346,81 +170,6 @@ class LLMEngine:
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameters.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
"profile_num_available_blocks"
,
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
# FIXME(woosuk): Change to debug log.
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
if
num_gpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_seq_len
=
self
.
cache_config
.
block_size
*
num_gpu_blocks
if
self
.
model_config
.
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
self
.
model_config
.
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine."
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
_run_workers
(
"init_cache_engine"
,
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
_run_workers
(
"warm_up_model"
)
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
EngineArgs
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs
=
engine_args
.
create_engine_configs
()
parallel_config
=
engine_configs
[
2
]
# Initialize the cluster.
placement_group
=
initialize_cluster
(
parallel_config
)
# Create the LLM engine.
engine
=
cls
(
*
engine_configs
,
placement_group
,
log_stats
=
not
engine_args
.
disable_log_stats
)
return
engine
def
encode_request
(
self
,
request_id
:
str
,
# pylint: disable=unused-argument
...
...
@@ -826,7 +575,7 @@ class LLMEngine:
- A Sequence Group (SG) refer to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the
workers
to execute the model.
- Step 2: Calls the
distributed executor
to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
...
...
@@ -862,19 +611,10 @@ class LLMEngine:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
scheduler_outputs
.
blocks_to_swap_in
,
"blocks_to_swap_out"
:
scheduler_outputs
.
blocks_to_swap_out
,
"blocks_to_copy"
:
scheduler_outputs
.
blocks_to_copy
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
output
=
self
.
model_executor
.
execute_model
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
)
else
:
output
=
[]
...
...
@@ -1043,111 +783,13 @@ class LLMEngine:
seq
.
output_text
=
seq
.
output_text
[:
-
len
(
stop_string
)]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"add_lora"
,
lora_request
=
lora_request
,
)
return
self
.
model_executor
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"remove_lora"
,
lora_id
=
lora_id
,
)
return
self
.
model_executor
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels
=
self
.
forward_dag
.
execute
(
1
)
else
:
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Start the driver worker after all the ray workers.
driver_worker_output
=
getattr
(
self
.
driver_worker
,
method
)(
*
driver_args
,
**
driver_kwargs
)
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_compiled_ray_dag
(
self
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
MultiOutputNode
,
InputNode
assert
self
.
parallel_config
.
worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
bind
(
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
return
self
.
model_executor
.
list_loras
()
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
def
_check_if_any_actor_is_dead
(
self
):
if
not
self
.
parallel_config
.
worker_use_ray
:
return
if
not
self
.
workers
:
return
dead_actors
=
[]
for
actor
in
self
.
workers
:
actor_state
=
ray
.
state
.
actors
(
actor
.
_ray_actor_id
.
hex
())
# pylint: disable=protected-access
if
actor_state
[
"State"
]
==
"DEAD"
:
dead_actors
.
append
(
actor
)
if
dead_actors
:
raise
RuntimeError
(
"At least one Worker is dead. "
f
"Dead Workers:
{
dead_actors
}
. "
)
self
.
model_executor
.
check_health
()
vllm/engine/ray_utils.py
View file @
4c922709
import
pickle
from
typing
import
Optional
,
List
,
Tuple
,
TYPE_CHECKING
from
typing
import
Optional
,
List
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
...
...
@@ -65,45 +65,38 @@ except ImportError as e:
ray
=
None
RayWorkerVllm
=
None
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
def
initialize_cluster
(
def
initialize_ray_cluster
(
parallel_config
:
ParallelConfig
,
engine_use_ray
:
bool
=
False
,
ray_address
:
Optional
[
str
]
=
None
,
)
->
Optional
[
"PlacementGroup"
]:
"""Initialize the distributed cluster probably with Ray.
):
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
An optional `PlacementGroup`. It includes the specification
of the resources for each distributed worker. None if Ray is
not used.
"""
if
parallel_config
.
worker_use_ray
or
engine_use_ray
:
if
ray
is
None
:
raise
ImportError
(
"Ray is not installed. Please install Ray to use distributed "
"serving."
)
# Connect to a ray cluster.
if
is_hip
():
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
not
parallel_config
.
worker_use_ray
:
assert
parallel_config
.
world_size
==
1
,
(
"Ray is required if parallel_config.world_size > 1."
)
return
None
if
ray
is
None
:
raise
ImportError
(
"Ray is not installed. Please install Ray to use distributed "
"serving."
)
# Connect to a ray cluster.
if
is_hip
():
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
...
...
@@ -138,4 +131,5 @@ def initialize_cluster(
# if they cannot be provisioned.
ray
.
get
(
current_placement_group
.
ready
(),
timeout
=
1800
)
return
current_placement_group
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
vllm/executor/__init__.py
0 → 100644
View file @
4c922709
vllm/executor/executor_base.py
0 → 100644
View file @
4c922709
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
class
ExecutorBase
(
ABC
):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
that can execute the model on multiple devices.
"""
@
abstractmethod
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
@
abstractmethod
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
@
abstractmethod
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
@
abstractmethod
def
list_loras
(
self
)
->
List
[
int
]:
raise
NotImplementedError
@
abstractmethod
def
check_health
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise
NotImplementedError
class
ExecutorAsyncBase
(
ExecutorBase
):
@
abstractmethod
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
@
abstractmethod
async
def
check_health_async
(
self
)
->
None
:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise
NotImplementedError
vllm/executor/gpu_executor.py
0 → 100644
View file @
4c922709
import
importlib
from
typing
import
Dict
,
List
,
Optional
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
(
get_ip
,
get_open_port
,
get_distributed_init_method
,
make_async
)
logger
=
init_logger
(
__name__
)
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP
=
{
"cuda"
:
"vllm.worker.worker"
,
"neuron"
:
"vllm.worker.neuron_worker"
,
}
class
GPUExecutor
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
# Instantiate the worker and load the model to GPU.
self
.
_init_worker
()
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
def
_dispatch_worker
(
self
):
worker_module
=
DEVICE_TO_WORKER_MODULE_MAP
[
self
.
device_config
.
device_type
]
imported_worker
=
importlib
.
import_module
(
worker_module
)
Worker
=
imported_worker
.
Worker
return
Worker
def
_init_worker
(
self
):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
assert
self
.
parallel_config
.
world_size
==
1
,
(
"GPUExecutor only supports single GPU."
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
)
self
.
driver_worker
.
init_model
()
self
.
driver_worker
.
load_model
()
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it allocates the remaining memory for KV blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_gpu_blocks
,
num_cpu_blocks
=
(
self
.
driver_worker
.
profile_num_available_blocks
(
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
))
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
check_block_size_valid
(
num_gpu_blocks
,
self
.
cache_config
.
block_size
,
self
.
model_config
.
max_model_len
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
driver_worker
.
init_cache_engine
(
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
driver_worker
.
warm_up_model
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
driver_worker
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
driver_worker
.
list_loras
()
def
check_health
(
self
)
->
None
:
# GPUExecutor will always be healthy as long as
# it's running.
return
class
GPUExecutorAsync
(
GPUExecutor
,
ExecutorAsyncBase
):
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
)
return
output
async
def
check_health_async
(
self
)
->
None
:
# GPUExecutor will always be healthy as long as
# it's running.
return
vllm/executor/ray_gpu_executor.py
0 → 100644
View file @
4c922709
import
asyncio
import
copy
from
collections
import
defaultdict
import
os
import
pickle
import
importlib
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
ray
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
(
set_cuda_visible_devices
,
get_ip
,
get_open_port
,
get_distributed_init_method
,
make_async
)
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP
=
{
"cuda"
:
"vllm.worker.worker"
,
"neuron"
:
"vllm.worker.neuron_worker"
,
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
getenv
(
"VLLM_USE_RAY_COMPILED_DAG"
,
0
))
class
RayGPUExecutor
(
ExecutorBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
)
->
None
:
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
assert
self
.
parallel_config
.
worker_use_ray
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
# Profile the memory usage and initialize the cache.
self
.
_init_cache
()
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
def
_dispatch_worker
(
self
):
worker_module
=
DEVICE_TO_WORKER_MODULE_MAP
[
self
.
device_config
.
device_type
]
imported_worker
=
importlib
.
import_module
(
worker_module
)
Worker
=
imported_worker
.
Worker
return
Worker
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
# For single GPU case, we use a ray worker with constrained memory.
num_gpus
=
self
.
cache_config
.
gpu_memory_utilization
else
:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus
=
1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self
.
driver_dummy_worker
:
RayWorkerVllm
=
None
# The remaining workers are the actual ray actors.
self
.
workers
:
List
[
RayWorkerVllm
]
=
[]
# Create the workers.
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerVllm
).
remote
(
self
.
model_config
.
trust_remote_code
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node."
)
# Get the set of GPU IDs used on each node.
driver_node_id
,
driver_gpu_ids
=
ray
.
get
(
self
.
driver_dummy_worker
.
get_node_and_gpu_ids
.
remote
())
worker_node_and_gpu_ids
=
ray
.
get
(
[
worker
.
get_node_and_gpu_ids
.
remote
()
for
worker
in
self
.
workers
])
node_workers
=
defaultdict
(
list
)
node_gpus
=
defaultdict
(
list
)
node_workers
[
driver_node_id
].
append
(
0
)
node_gpus
[
driver_node_id
].
extend
(
driver_gpu_ids
)
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
,
start
=
1
):
node_workers
[
node_id
].
append
(
i
)
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices
(
node_gpus
[
driver_node_id
])
for
worker
,
(
node_id
,
_
)
in
zip
(
self
.
workers
,
worker_node_and_gpu_ids
):
worker
.
set_cuda_visible_devices
.
remote
(
node_gpus
[
node_id
])
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker
=
self
.
_dispatch_worker
()
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
lora_config
=
copy
.
deepcopy
(
self
.
lora_config
)
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
# Initialize the actual workers with the Worker class.
for
rank
,
(
worker
,
(
node_id
,
_
))
in
enumerate
(
zip
(
self
.
workers
,
worker_node_and_gpu_ids
),
start
=
1
,
):
local_rank
=
node_workers
[
node_id
].
index
(
rank
)
worker
.
init_worker
.
remote
(
lambda
rank
=
rank
,
local_rank
=
local_rank
:
Worker
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
local_rank
,
rank
,
distributed_init_method
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
))
# Initialize the driver worker with the Worker class.
driver_rank
=
0
driver_local_rank
=
node_workers
[
driver_node_id
].
index
(
driver_rank
)
self
.
driver_worker
=
Worker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
driver_local_rank
,
driver_rank
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
True
,
)
# FIXME(woosuk): We are not properly initializing cupy NCCL when
# we have multiple nodes.
self
.
_run_workers
(
"init_model"
,
cupy_port
=
get_open_port
()
if
not
model_config
.
enforce_eager
else
None
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
,
)
def
_init_cache
(
self
)
->
None
:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
"profile_num_available_blocks"
,
block_size
=
self
.
cache_config
.
block_size
,
gpu_memory_utilization
=
self
.
cache_config
.
gpu_memory_utilization
,
cpu_swap_space
=
self
.
cache_config
.
swap_space_bytes
,
cache_dtype
=
self
.
cache_config
.
cache_dtype
,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
logger
.
info
(
f
"# GPU blocks:
{
num_gpu_blocks
}
, "
f
"# CPU blocks:
{
num_cpu_blocks
}
"
)
check_block_size_valid
(
num_gpu_blocks
,
self
.
cache_config
.
block_size
,
self
.
model_config
.
max_model_len
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
# Initialize the cache.
self
.
_run_workers
(
"init_cache_engine"
,
cache_config
=
self
.
cache_config
)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self
.
_run_workers
(
"warm_up_model"
)
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
assert
lora_request
.
lora_int_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"add_lora"
,
lora_request
=
lora_request
,
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
_run_workers
(
"remove_lora"
,
lora_id
=
lora_id
,
)
def
list_loras
(
self
)
->
List
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels
=
self
.
forward_dag
.
execute
(
1
)
else
:
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
)
for
worker
in
self
.
workers
]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Start the driver worker after all the ray workers.
driver_worker_output
=
getattr
(
self
.
driver_worker
,
method
)(
*
driver_args
,
**
driver_kwargs
)
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_compiled_ray_dag
(
self
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
MultiOutputNode
,
InputNode
assert
self
.
parallel_config
.
worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
bind
(
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
def
_check_if_any_actor_is_dead
(
self
):
if
not
self
.
workers
:
return
dead_actors
=
[]
for
actor
in
self
.
workers
:
actor_state
=
ray
.
state
.
actors
(
actor
.
_ray_actor_id
.
hex
())
# pylint: disable=protected-access
if
actor_state
[
"State"
]
==
"DEAD"
:
dead_actors
.
append
(
actor
)
if
dead_actors
:
raise
RuntimeError
(
"At least one Worker is dead. "
f
"Dead Workers:
{
dead_actors
}
. "
)
class
RayGPUExecutorAsync
(
RayGPUExecutor
,
ExecutorAsyncBase
):
async
def
_run_workers_async
(
self
,
method
:
str
,
*
args
,
driver_args
:
Optional
[
List
[
Any
]]
=
None
,
driver_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers."""
coros
=
[]
if
driver_args
is
None
:
driver_args
=
args
if
driver_kwargs
is
None
:
driver_kwargs
=
kwargs
# Run the driver worker asynchronously.
driver_executor
=
make_async
(
getattr
(
self
.
driver_worker
,
method
))
coros
.
append
(
driver_executor
(
*
driver_args
,
**
driver_kwargs
))
# Run the ray workers asynchronously.
for
worker
in
self
.
workers
:
coros
.
append
(
worker
.
execute_method
.
remote
(
method
,
*
args
,
**
kwargs
))
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
return
all_outputs
async
def
execute_model_async
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
SamplerOutput
:
all_outputs
=
await
self
.
_run_workers_async
(
"execute_model"
,
driver_kwargs
=
{
"seq_group_metadata_list"
:
seq_group_metadata_list
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
# Only the driver worker returns the sampling results.
output
=
all_outputs
[
0
]
return
output
async
def
check_health_async
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
vllm/executor/utils.py
0 → 100644
View file @
4c922709
def
check_block_size_valid
(
num_gpu_blocks
,
block_size
,
max_model_len
)
->
None
:
if
num_gpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
)
max_seq_len
=
block_size
*
num_gpu_blocks
if
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment