Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9605c125
"tests/vscode:/vscode.git/clone" did not exist on "5fe6bf29d657518eb4251981ada9f8c4f34dbbde"
Unverified
Commit
9605c125
authored
Feb 13, 2025
by
Rui Qiao
Committed by
GitHub
Feb 13, 2025
Browse files
[V1][core] Implement pipeline parallel on Ray (#12996)
parent
0ccd8769
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
110 additions
and
45 deletions
+110
-45
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+39
-12
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+9
-2
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+27
-14
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+12
-7
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+5
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+15
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-2
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
9605c125
...
@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
...
@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
@
dataclass
@
dataclass
class
PPTestSettings
:
class
PPTestSettings
:
parallel_setups
:
List
[
ParallelSetup
]
parallel_setups
:
List
[
ParallelSetup
]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends
:
List
[
str
]
distributed_backends
:
List
[
str
]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions
:
List
[
str
]
task
:
TaskOption
task
:
TaskOption
test_options
:
PPTestOptions
test_options
:
PPTestOptions
def
__post_init__
(
self
):
if
len
(
self
.
distributed_backends
)
!=
len
(
self
.
vllm_major_versions
):
raise
ValueError
(
f
"Length mismatch: distributed_backends "
f
"(
{
len
(
self
.
distributed_backends
)
}
) != "
f
"vllm_major_versions (
{
len
(
self
.
vllm_major_versions
)
}
)"
)
@
staticmethod
@
staticmethod
def
detailed
(
def
detailed
(
*
,
*
,
...
@@ -79,7 +92,9 @@ class PPTestSettings:
...
@@ -79,7 +92,9 @@ class PPTestSettings:
eager_mode
=
True
,
eager_mode
=
True
,
chunked_prefill
=
False
),
chunked_prefill
=
False
),
],
],
distributed_backends
=
[
"mp"
,
"ray"
],
# only ray is supported for V1
distributed_backends
=
[
"mp"
,
"ray"
,
"ray"
],
vllm_major_versions
=
[
"0"
,
"0"
,
"1"
],
task
=
task
,
task
=
task
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
...
@@ -108,6 +123,7 @@ class PPTestSettings:
...
@@ -108,6 +123,7 @@ class PPTestSettings:
chunked_prefill
=
False
),
chunked_prefill
=
False
),
],
],
distributed_backends
=
[
"mp"
],
distributed_backends
=
[
"mp"
],
vllm_major_versions
=
[
"0"
],
task
=
task
,
task
=
task
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
test_options
=
PPTestOptions
(
multi_node_only
=
multi_node_only
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
...
@@ -120,8 +136,9 @@ class PPTestSettings:
...
@@ -120,8 +136,9 @@ class PPTestSettings:
opts
=
self
.
test_options
opts
=
self
.
test_options
for
parallel_setup
in
self
.
parallel_setups
:
for
parallel_setup
in
self
.
parallel_setups
:
for
distributed_backend
in
self
.
distributed_backends
:
for
backend
,
vllm_major_version
in
zip
(
self
.
distributed_backends
,
yield
(
model_name
,
parallel_setup
,
distributed_backend
,
self
.
vllm_major_versions
):
yield
(
model_name
,
parallel_setup
,
backend
,
vllm_major_version
,
self
.
task
,
opts
)
self
.
task
,
opts
)
...
@@ -244,6 +261,7 @@ def _compare_tp(
...
@@ -244,6 +261,7 @@ def _compare_tp(
model_name
:
str
,
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
vllm_major_version
:
str
,
task
:
TaskOption
,
task
:
TaskOption
,
test_options
:
PPTestOptions
,
test_options
:
PPTestOptions
,
num_gpus_available
:
int
,
num_gpus_available
:
int
,
...
@@ -296,10 +314,13 @@ def _compare_tp(
...
@@ -296,10 +314,13 @@ def _compare_tp(
if
hf_overrides
:
if
hf_overrides
:
common_args
.
extend
([
"--hf-overrides"
,
hf_overrides
])
common_args
.
extend
([
"--hf-overrides"
,
hf_overrides
])
if
(
distributed_backend
==
"ray"
and
tp_size
==
2
and
pp_size
==
2
specific_case
=
tp_size
==
2
and
pp_size
==
2
and
chunked_prefill
and
chunked_prefill
):
if
distributed_backend
==
"ray"
and
(
vllm_major_version
==
"1"
# Test Ray ADAG for a subset of the tests
or
specific_case
):
# For V1, test Ray ADAG for all the tests
# For V0, test Ray ADAG for a subset of the tests
pp_env
=
{
pp_env
=
{
"VLLM_USE_V1"
:
vllm_major_version
,
"VLLM_USE_RAY_COMPILED_DAG"
:
"1"
,
"VLLM_USE_RAY_COMPILED_DAG"
:
"1"
,
"VLLM_USE_RAY_SPMD_WORKER"
:
"1"
,
"VLLM_USE_RAY_SPMD_WORKER"
:
"1"
,
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
:
"1"
,
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"
:
"1"
,
...
@@ -348,8 +369,8 @@ def _compare_tp(
...
@@ -348,8 +369,8 @@ def _compare_tp(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"task"
,
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"test_options"
),
"vllm_major_version"
,
"task"
,
"test_options"
),
[
[
params
for
model_name
,
settings
in
TEXT_GENERATION_MODELS
.
items
()
params
for
model_name
,
settings
in
TEXT_GENERATION_MODELS
.
items
()
for
params
in
settings
.
iter_params
(
model_name
)
for
params
in
settings
.
iter_params
(
model_name
)
...
@@ -361,6 +382,7 @@ def test_tp_language_generation(
...
@@ -361,6 +382,7 @@ def test_tp_language_generation(
model_name
:
str
,
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
vllm_major_version
:
str
,
task
:
TaskOption
,
task
:
TaskOption
,
test_options
:
PPTestOptions
,
test_options
:
PPTestOptions
,
num_gpus_available
,
num_gpus_available
,
...
@@ -368,6 +390,7 @@ def test_tp_language_generation(
...
@@ -368,6 +390,7 @@ def test_tp_language_generation(
_compare_tp
(
model_name
,
_compare_tp
(
model_name
,
parallel_setup
,
parallel_setup
,
distributed_backend
,
distributed_backend
,
vllm_major_version
,
task
,
task
,
test_options
,
test_options
,
num_gpus_available
,
num_gpus_available
,
...
@@ -375,8 +398,8 @@ def test_tp_language_generation(
...
@@ -375,8 +398,8 @@ def test_tp_language_generation(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"task"
,
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"test_options"
),
"vllm_major_version"
,
"task"
,
"test_options"
),
[
[
params
for
model_name
,
settings
in
EMBEDDING_MODELS
.
items
()
params
for
model_name
,
settings
in
EMBEDDING_MODELS
.
items
()
for
params
in
settings
.
iter_params
(
model_name
)
for
params
in
settings
.
iter_params
(
model_name
)
...
@@ -388,6 +411,7 @@ def test_tp_language_embedding(
...
@@ -388,6 +411,7 @@ def test_tp_language_embedding(
model_name
:
str
,
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
vllm_major_version
:
str
,
task
:
TaskOption
,
task
:
TaskOption
,
test_options
:
PPTestOptions
,
test_options
:
PPTestOptions
,
num_gpus_available
,
num_gpus_available
,
...
@@ -395,6 +419,7 @@ def test_tp_language_embedding(
...
@@ -395,6 +419,7 @@ def test_tp_language_embedding(
_compare_tp
(
model_name
,
_compare_tp
(
model_name
,
parallel_setup
,
parallel_setup
,
distributed_backend
,
distributed_backend
,
vllm_major_version
,
task
,
task
,
test_options
,
test_options
,
num_gpus_available
,
num_gpus_available
,
...
@@ -402,8 +427,8 @@ def test_tp_language_embedding(
...
@@ -402,8 +427,8 @@ def test_tp_language_embedding(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"task"
,
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"test_options"
),
"vllm_major_version"
,
"task"
,
"test_options"
),
[
[
params
for
model_name
,
settings
in
MULTIMODAL_MODELS
.
items
()
params
for
model_name
,
settings
in
MULTIMODAL_MODELS
.
items
()
for
params
in
settings
.
iter_params
(
model_name
)
for
params
in
settings
.
iter_params
(
model_name
)
...
@@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
...
@@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
model_name
:
str
,
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
distributed_backend
:
str
,
vllm_major_version
:
str
,
task
:
TaskOption
,
task
:
TaskOption
,
test_options
:
PPTestOptions
,
test_options
:
PPTestOptions
,
num_gpus_available
,
num_gpus_available
,
...
@@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
...
@@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
_compare_tp
(
model_name
,
_compare_tp
(
model_name
,
parallel_setup
,
parallel_setup
,
distributed_backend
,
distributed_backend
,
vllm_major_version
,
task
,
task
,
test_options
,
test_options
,
num_gpus_available
,
num_gpus_available
,
...
...
vllm/executor/ray_utils.py
View file @
9605c125
...
@@ -35,7 +35,7 @@ try:
...
@@ -35,7 +35,7 @@ try:
class
RayWorkerWrapper
(
WorkerWrapperBase
):
class
RayWorkerWrapper
(
WorkerWrapperBase
):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
laz
l
iy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazi
l
y initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
...
@@ -118,7 +118,14 @@ try:
...
@@ -118,7 +118,14 @@ try:
)
->
"ModelRunnerOutput"
:
)
->
"ModelRunnerOutput"
:
self
.
setup_device_if_necessary
()
self
.
setup_device_if_necessary
()
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
)
if
isinstance
(
scheduler_output
,
tuple
):
scheduler_output
,
intermediate_tensors
=
scheduler_output
else
:
scheduler_output
,
intermediate_tensors
=
scheduler_output
,
None
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
isinstance
(
output
,
IntermediateTensors
):
output
=
scheduler_output
,
output
return
output
return
output
def
override_env_vars
(
self
,
vars
:
Dict
[
str
,
str
]):
def
override_env_vars
(
self
,
vars
:
Dict
[
str
,
str
]):
...
...
vllm/v1/core/kv_cache_utils.py
View file @
9605c125
...
@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
...
@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
available_memory
:
int
)
->
KVCacheConfig
:
available_memory
:
int
,
num_layers
:
int
)
->
KVCacheConfig
:
"""
"""
Generates the KV cache configuration for a model with one type of KV cache.
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Divide the available memory equally among all layers.
...
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
...
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
vllm_config: The global VllmConfig
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
Returns:
The generated KVCacheConfig
The generated KVCacheConfig
...
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
...
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert
len
(
page_sizes
)
==
1
assert
len
(
page_sizes
)
==
1
page_size
=
page_sizes
.
pop
()
page_size
=
page_sizes
.
pop
()
num_blocks
=
int
(
available_memory
//
page_size
//
len
(
kv_cache_spec
)
)
num_blocks
=
int
(
available_memory
//
page_size
//
num_layers
)
num_blocks
=
max
(
num_blocks
,
0
)
num_blocks
=
max
(
num_blocks
,
0
)
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
...
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
...
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return
kv_cache_config
return
kv_cache_config
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
def
get_kv_cache_configs
(
vllm_config
:
VllmConfig
,
available_memory
:
int
)
->
KVCacheConfig
:
kv_cache_specs
:
List
[
KVCacheSpec
],
available_memory
:
int
)
->
List
[
KVCacheConfig
]:
"""
"""
Generates the KV cache configuration for a model
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
TODO: support hybrid models with more than one type of KV cache.
Args:
Args:
vllm_config: The global VllmConfig
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec
s
: The kv cache spec
s
of the model
available_memory: Memory available for KV cache in bytes.
available_memory: Memory available for KV cache in bytes.
Returns:
Returns:
The generated KVCacheConfig
The generated KVCacheConfig
s
"""
"""
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
# Use the max number of layers to conservatively determine
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# the number of blocks.
# KV cache of all layers are the same, which is true for most models.
num_layers
=
max
(
len
(
kv_cache_spec
)
for
kv_cache_spec
in
kv_cache_specs
)
# Allocate the same amount of memory for each layer.
kv_cache_configs
=
[]
return
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
for
kv_cache_spec
in
kv_cache_specs
:
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
available_memory
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs
.
append
(
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
,
num_layers
))
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
return
kv_cache_configs
vllm/v1/engine/core.py
View file @
9605c125
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
...
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.utils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
s
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
)
...
@@ -73,20 +73,25 @@ class EngineCore:
...
@@ -73,20 +73,25 @@ class EngineCore:
start
=
time
.
time
()
start
=
time
.
time
()
# Get all kv cache needed by the model
# Get all kv cache needed by the model
kv_cache_spec
=
self
.
model_executor
.
get_kv_cache_spec
()
kv_cache_spec
s
=
self
.
model_executor
.
get_kv_cache_spec
s
()
# Profiles the peak memory usage of the model to determine how much
# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
# memory can be allocated for kv cache.
availble_gpu_memory
=
self
.
model_executor
.
determine_available_memory
()
avail
a
ble_gpu_memory
=
self
.
model_executor
.
determine_available_memory
()
# Get the kv cache tensor size
# Get the kv cache tensor size
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
kv_cache_configs
=
get_kv_cache_configs
(
vllm_config
,
kv_cache_specs
,
availble_gpu_memory
)
available_gpu_memory
)
num_gpu_blocks
=
kv_cache_config
.
num_blocks
num_gpu_blocks_set
=
set
(
config
.
num_blocks
for
config
in
kv_cache_configs
)
assert
len
(
num_gpu_blocks_set
)
==
1
,
(
f
"num_gpu_blocks need to be the same across workers, "
f
"but they are different:
{
num_gpu_blocks_set
}
"
)
num_gpu_blocks
=
num_gpu_blocks_set
.
pop
()
num_cpu_blocks
=
0
num_cpu_blocks
=
0
# Initialize kv cache and warmup the execution
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize
(
kv_cache_config
)
self
.
model_executor
.
initialize
(
kv_cache_config
s
)
elapsed
=
time
.
time
()
-
start
elapsed
=
time
.
time
()
-
start
logger
.
info
((
"init engine (profile, create kv cache, "
logger
.
info
((
"init engine (profile, create kv cache, "
...
...
vllm/v1/executor/abstract.py
View file @
9605c125
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Type
from
typing
import
List
,
Type
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
...
@@ -48,12 +48,12 @@ class Executor(ExecutorBase):
...
@@ -48,12 +48,12 @@ class Executor(ExecutorBase):
f
"
{
distributed_executor_backend
}
"
)
f
"
{
distributed_executor_backend
}
"
)
return
executor_class
return
executor_class
def
initialize
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize
(
self
,
kv_cache_config
s
:
List
[
KVCacheConfig
]
)
->
None
:
"""
"""
Initialize the KV caches and begin the model execution loop of the
Initialize the KV caches and begin the model execution loop of the
underlying workers.
underlying workers.
"""
"""
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
kv_cache_config
,
))
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
kv_cache_config
s
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
int
:
# in bytes
...
@@ -63,11 +63,9 @@ class Executor(ExecutorBase):
...
@@ -63,11 +63,9 @@ class Executor(ExecutorBase):
# operators can be applied to all workers.
# operators can be applied to all workers.
return
min
(
output
)
return
min
(
output
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
s
(
self
)
->
List
[
KVCacheSpec
]
:
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
for
x
in
output
:
return
output
assert
x
==
output
[
0
]
return
output
[
0
]
def
execute_model
(
def
execute_model
(
self
,
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9605c125
...
@@ -12,7 +12,7 @@ import torch.nn as nn
...
@@ -12,7 +12,7 @@ import torch.nn as nn
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
...
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
...
@@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
ModelRunnerOutput
:
)
->
ModelRunnerOutput
:
batch_changed
=
self
.
_update_states
(
scheduler_output
)
batch_changed
=
self
.
_update_states
(
scheduler_output
)
...
@@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
positions
,
positions
=
positions
,
kv_caches
=
self
.
kv_caches
,
kv_caches
=
self
.
kv_caches
,
attn_metadata
=
None
,
attn_metadata
=
None
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
if
not
get_pp_group
().
is_last_rank
:
return
hidden_states
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
@@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
else
:
else
:
positions
=
self
.
positions
[:
num_tokens
]
positions
=
self
.
positions
[:
num_tokens
]
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
num_tokens
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
with
set_forward_context
(
None
,
self
.
vllm_config
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
None
,
attn_metadata
=
None
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
return
hidden_states
return
hidden_states
...
@@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Trigger compilation for general shape.
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
dummy_kv_caches
)
dummy_kv_caches
)
if
not
get_pp_group
().
is_last_rank
:
return
hidden_states
hidden_states
=
hidden_states
[
logit_indices
]
hidden_states
=
hidden_states
[
logit_indices
]
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
# TODO(woosuk): Consider the memory usage of the sampler.
# TODO(woosuk): Consider the memory usage of the sampler.
...
...
vllm/v1/worker/gpu_worker.py
View file @
9605c125
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -194,8 +194,9 @@ class Worker:
...
@@ -194,8 +194,9 @@ class Worker:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
return
self
.
model_runner
.
get_kv_cache_spec
()
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_cache
(
self
,
kv_cache_config
s
:
List
[
KVCacheConfig
]
)
->
None
:
"""Allocate GPU KV cache with the specified kv_cache_config."""
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config
=
kv_cache_configs
[
self
.
rank
]
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
allocator
=
CuMemAllocator
.
get_instance
()
allocator
=
CuMemAllocator
.
get_instance
()
context
=
allocator
.
use_memory_pool
(
tag
=
"kv_cache"
)
context
=
allocator
.
use_memory_pool
(
tag
=
"kv_cache"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment