Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1774 additions
and
792 deletions
+1774
-792
vllm/config.py
vllm/config.py
+183
-85
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+32
-2
vllm/core/block/common.py
vllm/core/block/common.py
+10
-11
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+78
-4
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+12
-3
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+4
-4
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+28
-21
vllm/core/block/utils.py
vllm/core/block/utils.py
+56
-0
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+141
-60
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+75
-17
vllm/core/embedding_model_block_manager.py
vllm/core/embedding_model_block_manager.py
+84
-0
vllm/core/interfaces.py
vllm/core/interfaces.py
+10
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+80
-52
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+116
-36
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+178
-141
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+186
-0
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+114
-216
vllm/distributed/device_communicators/pynccl_utils.py
vllm/distributed/device_communicators/pynccl_utils.py
+0
-66
vllm/distributed/device_communicators/pynccl_wrapper.py
vllm/distributed/device_communicators/pynccl_wrapper.py
+278
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+109
-70
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
vllm/config.py
View file @
b9e12416
import
enum
import
enum
import
json
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
get_quantization_config
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
is_neuron
)
GPTQMarlinConfig
=
get_quantization_config
(
"gptq_marlin"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -24,6 +20,7 @@ if TYPE_CHECKING:
...
@@ -24,6 +20,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_GB
=
1
<<
30
_GB
=
1
<<
30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
class
ModelConfig
:
class
ModelConfig
:
...
@@ -48,6 +45,9 @@ class ModelConfig:
...
@@ -48,6 +45,9 @@ class ModelConfig:
code_revision: The specific revision to use for the model code on
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
commit id. If unspecified, will use the default version.
rope_scaling: Dictionary containing the scaling configuration for the
RoPE embeddings. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
tokenizer_revision: The specific tokenizer version to use. It can be a
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
the default version.
...
@@ -69,6 +69,10 @@ class ModelConfig:
...
@@ -69,6 +69,10 @@ class ModelConfig:
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
When a sequence has context length larger than this, we fall back
to eager mode
to eager mode
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
served_model_name: The model name used in metrics tag `model_name`,
...
@@ -87,6 +91,7 @@ class ModelConfig:
...
@@ -87,6 +91,7 @@ class ModelConfig:
seed
:
int
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
...
@@ -95,6 +100,7 @@ class ModelConfig:
...
@@ -95,6 +100,7 @@ class ModelConfig:
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
max_logprobs
:
int
=
5
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -105,6 +111,7 @@ class ModelConfig:
...
@@ -105,6 +111,7 @@ class ModelConfig:
self
.
seed
=
seed
self
.
seed
=
seed
self
.
revision
=
revision
self
.
revision
=
revision
self
.
code_revision
=
code_revision
self
.
code_revision
=
code_revision
self
.
rope_scaling
=
rope_scaling
self
.
tokenizer_revision
=
tokenizer_revision
self
.
tokenizer_revision
=
tokenizer_revision
self
.
quantization
=
quantization
self
.
quantization
=
quantization
self
.
quantization_param_path
=
quantization_param_path
self
.
quantization_param_path
=
quantization_param_path
...
@@ -116,18 +123,23 @@ class ModelConfig:
...
@@ -116,18 +123,23 @@ class ModelConfig:
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
or
max_context_len_to_capture
)
or
max_context_len_to_capture
)
self
.
max_logprobs
=
max_logprobs
self
.
max_logprobs
=
max_logprobs
self
.
disable_sliding_window
=
disable_sliding_window
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
)
code_revision
,
rope_scaling
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
self
.
max_model_len
=
_get_and_verify_max_len
(
max_model_len
)
hf_config
=
self
.
hf_text_config
,
max_model_len
=
max_model_len
,
disable_sliding_window
=
self
.
disable_sliding_window
,
sliding_window_len
=
self
.
get_hf_config_sliding_window
())
self
.
served_model_name
=
get_served_model_name
(
model
,
self
.
served_model_name
=
get_served_model_name
(
model
,
served_model_name
)
served_model_name
)
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_embedding_mode
()
self
.
_verify_quantization
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
self
.
_verify_cuda_graph
()
...
@@ -139,6 +151,22 @@ class ModelConfig:
...
@@ -139,6 +151,22 @@ class ModelConfig:
"either 'auto' or 'slow'."
)
"either 'auto' or 'slow'."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_embedding_mode
(
self
)
->
None
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
self
.
embedding_mode
=
any
(
ModelRegistry
.
is_embedding_model
(
arch
)
for
arch
in
architectures
)
def
_parse_quant_hf_config
(
self
):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
quant_cfg
is
None
:
# SparseML uses a "compression_config" with a "quantization_config".
compression_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
if
compression_cfg
is
not
None
:
quant_cfg
=
compression_cfg
.
get
(
"quantization_config"
,
None
)
return
quant_cfg
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
...
@@ -146,40 +174,19 @@ class ModelConfig:
...
@@ -146,40 +174,19 @@ class ModelConfig:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
# Parse quantization method from the HF model config, if available.
# Parse quantization method from the HF model config, if available.
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
quant_cfg
=
self
.
_parse_quant_hf_config
()
if
quant_cfg
is
not
None
:
if
quant_cfg
is
not
None
:
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
# Detect which checkpoint is it
is_format_marlin
=
(
quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
for
_
,
method
in
QUANTIZATION_METHODS
.
items
():
or
quant_cfg
.
get
(
"is_marlin_format"
,
False
))
quantization_override
=
method
.
override_quantization_method
(
quant_cfg
,
self
.
quantization
)
# Check which LinearMethod the GPTQ model should use.
if
quantization_override
:
if
quant_method
==
"gptq"
:
quant_method
=
quantization_override
# If serialized in Marlin format, use MarlinLinearMethod.
self
.
quantization
=
quantization_override
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
break
if
is_format_marlin
:
logger
.
info
(
"The model is serialized in Marlin format. "
"Using Marlin kernel."
)
quant_method
=
"marlin"
if
self
.
quantization
==
"gptq"
:
self
.
quantization
=
quant_method
# If convertible to Marlin format, use GPTQMarlinLinearMethod
# unless the user explicitly specified GPTQLinearMethod.
elif
GPTQMarlinConfig
.
is_marlin_compatible
(
quant_cfg
):
if
self
.
quantization
==
"gptq"
:
logger
.
warning
(
"The model is convertible to Marlin format, but "
"you specified quantization=gptq. Use "
"quantization=marlin for faster inference."
)
else
:
logger
.
info
(
"The model is convertible to Marlin format. "
"Using Marlin kernel."
)
quant_method
=
"gptq_marlin"
if
self
.
quantization
==
"marlin"
:
self
.
quantization
=
quant_method
# Verify quantization configurations.
# Verify quantization configurations.
if
self
.
quantization
is
None
:
if
self
.
quantization
is
None
:
...
@@ -201,7 +208,8 @@ class ModelConfig:
...
@@ -201,7 +208,8 @@ class ModelConfig:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_marlin"
]):
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
]):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
@@ -233,7 +241,7 @@ class ModelConfig:
...
@@ -233,7 +241,7 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
"must be divisible by pipeline parallel size "
f
"(
{
pipeline_parallel_size
}
)."
)
f
"(
{
pipeline_parallel_size
}
)."
)
def
get_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_
hf_config_
sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled.
"""Get the sliding window size, or None if disabled.
"""
"""
...
@@ -245,6 +253,15 @@ class ModelConfig:
...
@@ -245,6 +253,15 @@ class ModelConfig:
return
None
return
None
return
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
return
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
def
get_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if
self
.
disable_sliding_window
:
return
None
# Otherwise get the value from the hf config.
return
self
.
get_hf_config_sliding_window
()
def
get_vocab_size
(
self
)
->
int
:
def
get_vocab_size
(
self
)
->
int
:
return
self
.
hf_text_config
.
vocab_size
return
self
.
hf_text_config
.
vocab_size
...
@@ -349,6 +366,7 @@ class CacheConfig:
...
@@ -349,6 +366,7 @@ class CacheConfig:
self
.
enable_prefix_caching
=
enable_prefix_caching
self
.
enable_prefix_caching
=
enable_prefix_caching
self
.
_verify_args
()
self
.
_verify_args
()
self
.
_verify_cache_dtype
()
self
.
_verify_cache_dtype
()
self
.
_verify_prefix_caching
()
# Will be set after profiling.
# Will be set after profiling.
self
.
num_gpu_blocks
=
None
self
.
num_gpu_blocks
=
None
...
@@ -368,24 +386,28 @@ class CacheConfig:
...
@@ -368,24 +386,28 @@ class CacheConfig:
def
_verify_cache_dtype
(
self
)
->
None
:
def
_verify_cache_dtype
(
self
)
->
None
:
if
self
.
cache_dtype
==
"auto"
:
if
self
.
cache_dtype
==
"auto"
:
pass
pass
elif
self
.
cache_dtype
==
"fp8"
:
elif
self
.
cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
):
if
not
is_hip
():
nvcc_cuda_version
=
get_nvcc_cuda_version
()
if
nvcc_cuda_version
is
not
None
\
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
ValueError
(
"FP8 is not supported when cuda version is"
"lower than 11.8."
)
logger
.
info
(
logger
.
info
(
"Using fp8 data type to store kv cache. It reduces the GPU "
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"Meanwhile, it may cause accuracy drop without a proper "
"factors. FP8_E5M2 (without scaling) is only supported on "
"scaling factor"
)
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria."
)
else
:
else
:
raise
ValueError
(
f
"Unknown kv cache dtype:
{
self
.
cache_dtype
}
"
)
raise
ValueError
(
f
"Unknown kv cache dtype:
{
self
.
cache_dtype
}
"
)
def
_verify_prefix_caching
(
self
)
->
None
:
if
not
self
.
enable_prefix_caching
:
return
if
self
.
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching."
)
if
self
.
cache_dtype
==
"fp8"
:
raise
NotImplementedError
(
"Prefix caching is not supported for fp8 cache_dtype. "
"Run with --kv-cache-dtype auto to use prefix caching."
)
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
parallel_config
:
"ParallelConfig"
,
parallel_config
:
"ParallelConfig"
,
...
@@ -464,6 +486,7 @@ class LoadFormat(str, enum.Enum):
...
@@ -464,6 +486,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE
=
"npcache"
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
@
dataclass
@
dataclass
...
@@ -522,9 +545,7 @@ class ParallelConfig:
...
@@ -522,9 +545,7 @@ class ParallelConfig:
Args:
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
worker_use_ray: Deprecated, use distributed_executor_backend instead.
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
max_parallel_loading_workers: Maximum number of multiple batches
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
parallel and large models.
...
@@ -534,22 +555,28 @@ class ParallelConfig:
...
@@ -534,22 +555,28 @@ class ParallelConfig:
If None, will use synchronous tokenization.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
placement_group: ray distributed model workers placement group.
distributed_executor_backend: Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If either
pipeline_parallel_size or tensor_parallel_size is greater than 1,
will default to "ray" if Ray is installed or "mp" otherwise.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
pipeline_parallel_size
:
int
,
pipeline_parallel_size
:
int
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
worker_use_ray
:
bool
,
worker_use_ray
:
Optional
[
bool
]
=
None
,
max_parallel_loading_workers
:
Optional
[
int
]
=
None
,
max_parallel_loading_workers
:
Optional
[
int
]
=
None
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
=
None
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
=
None
,
ray_workers_use_nsight
:
bool
=
False
,
ray_workers_use_nsight
:
bool
=
False
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
self
.
pipeline_parallel_size
=
pipeline_parallel_size
self
.
pipeline_parallel_size
=
pipeline_parallel_size
self
.
tensor_parallel_size
=
tensor_parallel_size
self
.
tensor_parallel_size
=
tensor_parallel_size
self
.
worker_use_ray
=
worker_use_ray
self
.
distributed_executor_backend
=
distributed_executor_backend
self
.
max_parallel_loading_workers
=
max_parallel_loading_workers
self
.
max_parallel_loading_workers
=
max_parallel_loading_workers
self
.
disable_custom_all_reduce
=
disable_custom_all_reduce
self
.
disable_custom_all_reduce
=
disable_custom_all_reduce
self
.
tokenizer_pool_config
=
tokenizer_pool_config
self
.
tokenizer_pool_config
=
tokenizer_pool_config
...
@@ -557,14 +584,29 @@ class ParallelConfig:
...
@@ -557,14 +584,29 @@ class ParallelConfig:
self
.
placement_group
=
placement_group
self
.
placement_group
=
placement_group
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
self
.
world_size
=
pipeline_parallel_size
*
self
.
tensor_parallel_size
if
self
.
world_size
>
1
:
if
worker_use_ray
:
self
.
worker_use_ray
=
True
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"ray"
elif
self
.
distributed_executor_backend
!=
"ray"
:
raise
ValueError
(
f
"worker-use-ray can't be used with "
f
"distributed executor backend "
f
"'
{
self
.
distributed_executor_backend
}
'."
)
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
from
vllm.executor
import
ray_utils
ray_found
=
ray_utils
.
ray
is
not
None
self
.
distributed_executor_backend
=
"ray"
if
ray_found
else
"mp"
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
self
.
pipeline_parallel_size
>
1
:
if
self
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Pipeline parallelism is not supported yet."
)
"Pipeline parallelism is not supported yet."
)
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
None
):
raise
ValueError
(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'."
)
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
if
is_hip
():
if
is_hip
():
self
.
disable_custom_all_reduce
=
True
self
.
disable_custom_all_reduce
=
True
...
@@ -576,7 +618,8 @@ class ParallelConfig:
...
@@ -576,7 +618,8 @@ class ParallelConfig:
logger
.
info
(
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism."
)
"supported with pipeline parallelism."
)
if
self
.
ray_workers_use_nsight
and
not
self
.
worker_use_ray
:
if
self
.
ray_workers_use_nsight
and
(
not
self
.
distributed_executor_backend
==
"ray"
):
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
"run with Ray."
)
...
@@ -600,6 +643,7 @@ class SchedulerConfig:
...
@@ -600,6 +643,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -611,6 +655,7 @@ class SchedulerConfig:
...
@@ -611,6 +655,7 @@ class SchedulerConfig:
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
)
->
None
:
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
...
@@ -619,6 +664,10 @@ class SchedulerConfig:
...
@@ -619,6 +664,10 @@ class SchedulerConfig:
# It is the values that have the best balance between ITL
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
# and TTFT on A100. Note it is not optimized for throughput.
self
.
max_num_batched_tokens
=
512
self
.
max_num_batched_tokens
=
512
elif
embedding_mode
:
# For embedding, choose specific value for higher throughput
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
)
else
:
else
:
# If max_model_len is too short, use 2048 as the default value
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
# for higher throughput.
...
@@ -632,6 +681,7 @@ class SchedulerConfig:
...
@@ -632,6 +681,7 @@ class SchedulerConfig:
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -701,6 +751,7 @@ class SpeculativeConfig:
...
@@ -701,6 +751,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
...
@@ -729,6 +780,9 @@ class SpeculativeConfig:
...
@@ -729,6 +780,9 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
...
@@ -739,7 +793,7 @@ class SpeculativeConfig:
...
@@ -739,7 +793,7 @@ class SpeculativeConfig:
the necessary conditions are met, else None.
the necessary conditions are met, else None.
"""
"""
if
(
speculative_model
is
None
and
num_speculative_tokens
is
None
)
:
if
speculative_model
is
None
and
num_speculative_tokens
is
None
:
return
None
return
None
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
...
@@ -748,6 +802,12 @@ class SpeculativeConfig:
...
@@ -748,6 +802,12 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
"num_speculative_tokens to be provided, but found "
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
if
(
speculative_disable_by_batch_size
is
not
None
and
speculative_disable_by_batch_size
<
2
):
raise
ValueError
(
"Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
assert
(
speculative_model
is
not
None
assert
(
speculative_model
is
not
None
and
num_speculative_tokens
is
not
None
)
and
num_speculative_tokens
is
not
None
)
...
@@ -768,12 +828,15 @@ class SpeculativeConfig:
...
@@ -768,12 +828,15 @@ class SpeculativeConfig:
draft_quantization
=
None
draft_quantization
=
None
if
speculative_model
==
"[ngram]"
:
if
speculative_model
==
"[ngram]"
:
assert
(
ngram_prompt_lookup_max
is
not
None
and
ngram_prompt_lookup_max
>
0
)
if
ngram_prompt_lookup_min
is
None
:
if
ngram_prompt_lookup_min
is
None
:
ngram_prompt_lookup_min
=
0
ngram_prompt_lookup_min
=
1
else
:
if
ngram_prompt_lookup_max
is
None
or
ngram_prompt_lookup_max
<
1
:
assert
ngram_prompt_lookup_max
>
ngram_prompt_lookup_min
raise
ValueError
(
f
"
{
ngram_prompt_lookup_max
=
}
must be > 0"
)
if
ngram_prompt_lookup_min
<
1
:
raise
ValueError
(
f
"
{
ngram_prompt_lookup_min
=
}
must be > 0"
)
if
ngram_prompt_lookup_min
>
ngram_prompt_lookup_max
:
raise
ValueError
(
f
"
{
ngram_prompt_lookup_min
=
}
cannot be "
f
"larger than
{
ngram_prompt_lookup_max
=
}
"
)
# TODO: current we still need extract vocab_size from target model
# TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set
# config, in future, we may try refactor it out, and set
...
@@ -816,6 +879,7 @@ class SpeculativeConfig:
...
@@ -816,6 +879,7 @@ class SpeculativeConfig:
draft_model_config
,
draft_model_config
,
draft_parallel_config
,
draft_parallel_config
,
num_speculative_tokens
,
num_speculative_tokens
,
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
,
)
)
...
@@ -867,7 +931,8 @@ class SpeculativeConfig:
...
@@ -867,7 +931,8 @@ class SpeculativeConfig:
pipeline_parallel_size
=
target_parallel_config
.
pipeline_parallel_size
=
target_parallel_config
.
pipeline_parallel_size
,
pipeline_parallel_size
,
tensor_parallel_size
=
target_parallel_config
.
tensor_parallel_size
,
tensor_parallel_size
=
target_parallel_config
.
tensor_parallel_size
,
worker_use_ray
=
target_parallel_config
.
worker_use_ray
,
distributed_executor_backend
=
target_parallel_config
.
distributed_executor_backend
,
max_parallel_loading_workers
=
target_parallel_config
.
max_parallel_loading_workers
=
target_parallel_config
.
max_parallel_loading_workers
,
max_parallel_loading_workers
,
disable_custom_all_reduce
=
target_parallel_config
.
disable_custom_all_reduce
=
target_parallel_config
.
...
@@ -885,8 +950,9 @@ class SpeculativeConfig:
...
@@ -885,8 +950,9 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_min
:
int
,
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -895,12 +961,19 @@ class SpeculativeConfig:
...
@@ -895,12 +961,19 @@ class SpeculativeConfig:
draft_parallel_config: ParallelConfig for the draft model.
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
speculative_disable_by_batch_size
=
\
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -942,6 +1015,7 @@ class LoRAConfig:
...
@@ -942,6 +1015,7 @@ class LoRAConfig:
lora_extra_vocab_size
:
int
=
256
lora_extra_vocab_size
:
int
=
256
# This is a constant.
# This is a constant.
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
...
@@ -1034,7 +1108,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -1034,7 +1108,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
"bfloat16"
:
torch
.
bfloat16
,
}
}
_ROCM_NOT_SUPPORTED_DTYPE
=
[
"float"
,
"float32"
]
_ROCM_NOT_SUPPORTED_DTYPE
:
List
[
str
]
=
[]
#
def
_get_and_verify_dtype
(
def
_get_and_verify_dtype
(
...
@@ -1053,6 +1127,7 @@ def _get_and_verify_dtype(
...
@@ -1053,6 +1127,7 @@ def _get_and_verify_dtype(
if
config_dtype
==
torch
.
float32
:
if
config_dtype
==
torch
.
float32
:
# Following the common practice, we use float16 for float32
# Following the common practice, we use float16 for float32
# models.
# models.
logger
.
info
(
"Casting torch.float32 to torch.float16."
)
torch_dtype
=
torch
.
float16
torch_dtype
=
torch
.
float16
else
:
else
:
torch_dtype
=
config_dtype
torch_dtype
=
config_dtype
...
@@ -1065,21 +1140,15 @@ def _get_and_verify_dtype(
...
@@ -1065,21 +1140,15 @@ def _get_and_verify_dtype(
else
:
else
:
raise
ValueError
(
f
"Unknown dtype:
{
dtype
}
"
)
raise
ValueError
(
f
"Unknown dtype:
{
dtype
}
"
)
if
is_hip
()
and
torch_dtype
==
torch
.
float32
:
rocm_supported_dtypes
=
[
k
for
k
,
v
in
_STR_DTYPE_TO_TORCH_DTYPE
.
items
()
if
(
k
not
in
_ROCM_NOT_SUPPORTED_DTYPE
)
]
raise
ValueError
(
f
"dtype '
{
dtype
}
' is not supported in ROCm. "
f
"Supported dtypes are
{
rocm_supported_dtypes
}
"
)
# Verify the dtype.
# Verify the dtype.
if
torch_dtype
!=
config_dtype
:
if
torch_dtype
!=
config_dtype
:
if
torch_dtype
==
torch
.
float32
:
if
torch_dtype
==
torch
.
float32
:
# Upcasting to float32 is allowed.
# Upcasting to float32 is allowed.
logger
.
info
(
"Upcasting %s to %s."
,
config_dtype
,
torch_dtype
)
pass
pass
elif
config_dtype
==
torch
.
float32
:
elif
config_dtype
==
torch
.
float32
:
# Downcasting from float32 to float16 or bfloat16 is allowed.
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger
.
info
(
"Downcasting %s to %s."
,
config_dtype
,
torch_dtype
)
pass
pass
else
:
else
:
# Casting between float16 and bfloat16 is allowed with a warning.
# Casting between float16 and bfloat16 is allowed with a warning.
...
@@ -1091,6 +1160,8 @@ def _get_and_verify_dtype(
...
@@ -1091,6 +1160,8 @@ def _get_and_verify_dtype(
def
_get_and_verify_max_len
(
def
_get_and_verify_max_len
(
hf_config
:
PretrainedConfig
,
hf_config
:
PretrainedConfig
,
max_model_len
:
Optional
[
int
],
max_model_len
:
Optional
[
int
],
disable_sliding_window
:
bool
,
sliding_window_len
:
Optional
[
int
],
)
->
int
:
)
->
int
:
"""Get and verify the model's maximum length."""
"""Get and verify the model's maximum length."""
derived_max_model_len
=
float
(
"inf"
)
derived_max_model_len
=
float
(
"inf"
)
...
@@ -1110,6 +1181,7 @@ def _get_and_verify_max_len(
...
@@ -1110,6 +1181,7 @@ def _get_and_verify_max_len(
"max_seq_length"
,
"max_seq_length"
,
"seq_len"
,
"seq_len"
,
]
]
# Choose the smallest "max_length" from the possible keys.
max_len_key
=
None
max_len_key
=
None
for
key
in
possible_keys
:
for
key
in
possible_keys
:
max_len
=
getattr
(
hf_config
,
key
,
None
)
max_len
=
getattr
(
hf_config
,
key
,
None
)
...
@@ -1117,6 +1189,16 @@ def _get_and_verify_max_len(
...
@@ -1117,6 +1189,16 @@ def _get_and_verify_max_len(
max_len_key
=
key
if
max_len
<
derived_max_model_len
\
max_len_key
=
key
if
max_len
<
derived_max_model_len
\
else
max_len_key
else
max_len_key
derived_max_model_len
=
min
(
derived_max_model_len
,
max_len
)
derived_max_model_len
=
min
(
derived_max_model_len
,
max_len
)
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if
disable_sliding_window
and
sliding_window_len
is
not
None
:
max_len_key
=
"sliding_window"
\
if
sliding_window_len
<
derived_max_model_len
else
max_len_key
derived_max_model_len
=
min
(
derived_max_model_len
,
sliding_window_len
)
# If none of the keys were found in the config, use a default and
# log a warning.
if
derived_max_model_len
==
float
(
"inf"
):
if
derived_max_model_len
==
float
(
"inf"
):
if
max_model_len
is
not
None
:
if
max_model_len
is
not
None
:
# If max_model_len is specified, we use it.
# If max_model_len is specified, we use it.
...
@@ -1132,6 +1214,13 @@ def _get_and_verify_max_len(
...
@@ -1132,6 +1214,13 @@ def _get_and_verify_max_len(
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise
NotImplementedError
(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate."
)
assert
"factor"
in
rope_scaling
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_scaling
[
"type"
]
==
"yarn"
:
if
rope_scaling
[
"type"
]
==
"yarn"
:
...
@@ -1139,6 +1228,8 @@ def _get_and_verify_max_len(
...
@@ -1139,6 +1228,8 @@ def _get_and_verify_max_len(
"original_max_position_embeddings"
]
"original_max_position_embeddings"
]
derived_max_model_len
*=
scaling_factor
derived_max_model_len
*=
scaling_factor
# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if
max_model_len
is
None
:
if
max_model_len
is
None
:
max_model_len
=
int
(
derived_max_model_len
)
max_model_len
=
int
(
derived_max_model_len
)
elif
max_model_len
>
derived_max_model_len
:
elif
max_model_len
>
derived_max_model_len
:
...
@@ -1147,6 +1238,13 @@ def _get_and_verify_max_len(
...
@@ -1147,6 +1238,13 @@ def _get_and_verify_max_len(
# with model_max_length and allow this override when it's smaller.
# with model_max_length and allow this override when it's smaller.
model_max_length
=
getattr
(
hf_config
,
"model_max_length"
,
None
)
model_max_length
=
getattr
(
hf_config
,
"model_max_length"
,
None
)
if
model_max_length
is
not
None
and
max_model_len
<=
model_max_length
:
if
model_max_length
is
not
None
and
max_model_len
<=
model_max_length
:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise
NotImplementedError
(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate."
)
pass
pass
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
...
vllm/core/block/block_table.py
View file @
b9e12416
...
@@ -20,6 +20,10 @@ class BlockTable:
...
@@ -20,6 +20,10 @@ class BlockTable:
_blocks (Optional[List[Block]], optional): An optional list of existing
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
BlockTable is created.
max_block_sliding_window (Optional[int], optional): The number of
blocks to keep around for each sequance. If None, all blocks
are kept (eg., when sliding window is not used).
It should at least fit the sliding window size of the model.
Attributes:
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
_block_size (int): The maximum number of tokens that can be stored in a
...
@@ -37,6 +41,7 @@ class BlockTable:
...
@@ -37,6 +41,7 @@ class BlockTable:
block_size
:
int
,
block_size
:
int
,
block_allocator
:
DeviceAwareBlockAllocator
,
block_allocator
:
DeviceAwareBlockAllocator
,
_blocks
:
Optional
[
List
[
Block
]]
=
None
,
_blocks
:
Optional
[
List
[
Block
]]
=
None
,
max_block_sliding_window
:
Optional
[
int
]
=
None
,
):
):
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
self
.
_allocator
=
block_allocator
self
.
_allocator
=
block_allocator
...
@@ -44,6 +49,7 @@ class BlockTable:
...
@@ -44,6 +49,7 @@ class BlockTable:
_blocks
=
[]
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_max_block_sliding_window
=
max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
# may not be allocated.
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
...
@@ -89,7 +95,8 @@ class BlockTable:
...
@@ -89,7 +95,8 @@ class BlockTable:
def
append_token_ids
(
self
,
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
)
->
None
:
num_lookahead_slots
:
int
=
0
,
num_computed_slots
:
Optional
[
int
]
=
None
)
->
None
:
"""Appends a sequence of token IDs to the existing blocks in the
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
BlockTable.
...
@@ -104,13 +111,35 @@ class BlockTable:
...
@@ -104,13 +111,35 @@ class BlockTable:
Args:
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
token_ids (List[int]): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
blocks to drop at the front of the sequence.
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
,
"no blocks have been allocated"
assert
len
(
self
.
_blocks
)
>
0
assert
len
(
self
.
_blocks
)
>
0
# Drop blocks that are no longer needed due to sliding window
if
self
.
_max_block_sliding_window
is
not
None
:
null_block
=
self
.
_allocator
.
allocate_or_get_null_block
()
assert
num_computed_slots
is
not
None
end_block_idx
=
(
num_computed_slots
//
self
.
_block_size
)
-
self
.
_max_block_sliding_window
for
idx
in
range
(
0
,
end_block_idx
):
b
=
self
.
_blocks
[
idx
]
if
b
is
not
null_block
:
self
.
_allocator
.
free
(
b
)
self
.
_blocks
[
idx
]
=
null_block
# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
num_lookahead_slots
)
# Update the blocks with the new tokens
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
...
@@ -168,6 +197,7 @@ class BlockTable:
...
@@ -168,6 +197,7 @@ class BlockTable:
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
block_allocator
=
self
.
_allocator
,
block_allocator
=
self
.
_allocator
,
_blocks
=
forked_blocks
,
_blocks
=
forked_blocks
,
max_block_sliding_window
=
self
.
_max_block_sliding_window
,
)
)
def
free
(
self
)
->
None
:
def
free
(
self
)
->
None
:
...
...
vllm/core/block/common.py
View file @
b9e12416
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
...
@@ -111,7 +110,7 @@ class CopyOnWriteTracker:
...
@@ -111,7 +110,7 @@ class CopyOnWriteTracker:
refcounter
:
RefCounterProtocol
,
refcounter
:
RefCounterProtocol
,
allocator
:
BlockAllocator
,
allocator
:
BlockAllocator
,
):
):
self
.
_copy_on_writes
:
Dict
[
BlockId
,
List
[
BlockId
]]
=
defaultdict
(
list
)
self
.
_copy_on_writes
:
List
[
Tuple
[
BlockId
,
BlockId
]]
=
[]
self
.
_refcounter
=
refcounter
self
.
_refcounter
=
refcounter
self
.
_allocator
=
allocator
self
.
_allocator
=
allocator
...
@@ -152,25 +151,25 @@ class CopyOnWriteTracker:
...
@@ -152,25 +151,25 @@ class CopyOnWriteTracker:
# Track src/dst copy.
# Track src/dst copy.
assert
src_block_id
is
not
None
assert
src_block_id
is
not
None
assert
block_id
is
not
None
assert
block_id
is
not
None
self
.
_copy_on_writes
[
src_block_id
].
append
(
block_id
)
self
.
_copy_on_writes
.
append
((
src_block_id
,
block_id
)
)
return
block_id
return
block_id
def
clear_cows
(
self
)
->
Dict
[
BlockId
,
List
[
BlockId
]]:
def
clear_cows
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Clears the copy-on-write tracking information and returns the current
"""Clears the copy-on-write tracking information and returns the current
state.
state.
This method returns a
dictionary
mapping source block indices to
lists
This method returns a
list
mapping source block indices to
of
destination block indices for the current copy-on-write operations.
destination block indices for the current copy-on-write operations.
It then clears the internal tracking information.
It then clears the internal tracking information.
Returns:
Returns:
Dict
[BlockId,
List[
BlockId]]: A
dictionary
mapping source
List[Tuple
[BlockId, BlockId]]: A
list
mapping source
block indices to
lists of
destination block indices for the
block indices to destination block indices for the
current copy-on-write operations.
current copy-on-write operations.
"""
"""
cows
=
dict
(
self
.
_copy_on_writes
)
cows
=
self
.
_copy_on_writes
self
.
_copy_on_writes
.
clear
()
self
.
_copy_on_writes
=
[]
return
cows
return
cows
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
b9e12416
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
,
Tuple
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
DeviceAwareBlockAllocator
)
DeviceAwareBlockAllocator
)
...
@@ -105,11 +105,19 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -105,11 +105,19 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device
.
GPU
:
gpu_block_allocator
,
Device
.
GPU
:
gpu_block_allocator
,
}
}
self
.
_null_block
:
Optional
[
Block
]
=
None
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
block_id
in
allocator
.
all_block_ids
:
for
block_id
in
allocator
.
all_block_ids
:
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
def
allocate_or_get_null_block
(
self
)
->
Block
:
if
self
.
_null_block
is
None
:
self
.
_null_block
=
NullBlock
(
self
.
allocate_mutable
(
None
,
Device
.
GPU
))
return
self
.
_null_block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
device
:
Device
)
->
Block
:
"""Allocates a new mutable block on the specified device.
"""Allocates a new mutable block on the specified device.
...
@@ -149,6 +157,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -149,6 +157,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
Args:
block (Block): The block to be freed.
block (Block): The block to be freed.
"""
"""
# Null block should never be freed
if
isinstance
(
block
,
NullBlock
):
return
block_id
=
block
.
block_id
block_id
=
block
.
block_id
assert
block_id
is
not
None
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
...
@@ -165,6 +176,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -165,6 +176,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
original sequence.
"""
"""
# do not attempt to fork the null block
assert
not
isinstance
(
last_block
,
NullBlock
)
block_id
=
last_block
.
block_id
block_id
=
last_block
.
block_id
assert
block_id
is
not
None
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
...
@@ -185,13 +198,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -185,13 +198,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
return
self
.
_allocators
[
device
].
get_num_total_blocks
()
return
self
.
_allocators
[
device
].
get_num_total_blocks
()
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
source to destination block IDs.
Returns:
Returns:
Dict[int, List[int]]: A dictionary
mapping source block IDs to
lists
List[Tuple[int, int]]: A list
mapping source block IDs to
of
destination block IDs.
destination block IDs.
"""
"""
# CoW only supported on GPU
# CoW only supported on GPU
device
=
Device
.
GPU
device
=
Device
.
GPU
...
@@ -226,3 +239,64 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -226,3 +239,64 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
raise
NotImplementedError
raise
NotImplementedError
class
NullBlock
(
Block
):
"""
Null blocks are used as a placeholders for KV cache blocks that have
been dropped due to sliding window.
This implementation just wraps an ordinary block and prevents it from
being modified. It also allows for testing if a block is NullBlock
via isinstance().
"""
def
__init__
(
self
,
proxy
:
Block
):
super
().
__init__
()
self
.
_proxy
=
proxy
def
append_token_ids
(
self
,
token_ids
:
List
[
BlockId
]):
raise
ValueError
(
"null block should not be modified"
)
@
property
def
block_id
(
self
):
return
self
.
_proxy
.
block_id
@
block_id
.
setter
def
block_id
(
self
,
value
:
Optional
[
BlockId
]):
raise
ValueError
(
"null block should not be modified"
)
@
property
def
token_ids
(
self
)
->
List
[
BlockId
]:
return
self
.
_proxy
.
token_ids
@
property
def
num_empty_slots
(
self
)
->
BlockId
:
return
self
.
_proxy
.
num_empty_slots
@
property
def
is_full
(
self
):
return
self
.
_proxy
.
is_full
@
property
def
prev_block
(
self
):
return
self
.
_proxy
.
prev_block
@
property
def
computed
(
self
):
return
self
.
_proxy
.
computed
@
computed
.
setter
def
computed
(
self
,
value
):
self
.
_proxy
.
computed
=
value
@
property
def
last_accessed
(
self
)
->
float
:
return
self
.
_proxy
.
last_accessed
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
self
.
_proxy
.
last_accessed
=
last_accessed_ts
@
property
def
content_hash
(
self
):
return
self
.
_proxy
.
content_hash
vllm/core/block/interfaces.py
View file @
b9e12416
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
,
Protocol
from
typing
import
FrozenSet
,
List
,
Optional
,
Protocol
,
Tuple
from
vllm.utils
import
Device
from
vllm.utils
import
Device
...
@@ -122,7 +122,7 @@ class BlockAllocator(ABC):
...
@@ -122,7 +122,7 @@ class BlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -187,7 +187,7 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -187,7 +187,7 @@ class DeviceAwareBlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -203,3 +203,12 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -203,3 +203,12 @@ class DeviceAwareBlockAllocator(ABC):
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
pass
@
abstractmethod
def
allocate_or_get_null_block
(
self
)
->
Block
:
"""
Null blocks are used as a placeholders for KV cache blocks that have
been dropped due to sliding window.
There is at most one null block per allocator.
"""
pass
vllm/core/block/naive_block.py
View file @
b9e12416
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
from
typing
import
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
...
@@ -175,12 +175,12 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -175,12 +175,12 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
def
clear_copy_on_writes
(
self
)
->
Dict
[
BlockId
,
List
[
BlockId
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Returns the copy-on-write source->destination mapping and clears it.
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Returns:
Dict
[BlockId,
List[
BlockId]]: A
dictionary
mapping source
List[Tuple
[BlockId, BlockId]]: A
list
mapping source
block indices to
lists of
destination block indices.
block indices to destination block indices.
"""
"""
return
self
.
_cow_tracker
.
clear_cows
()
return
self
.
_cow_tracker
.
clear_cows
()
...
...
vllm/core/block/prefix_caching_block.py
View file @
b9e12416
"""Token blocks."""
"""Token blocks."""
from
itertools
import
takewhile
from
itertools
import
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
...
@@ -160,21 +160,17 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -160,21 +160,17 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# If the evictor has blocks available for eviction, evict a block
# If the evictor has blocks available for eviction, evict a block
# and return it.
# and return it.
if
self
.
evictor
.
num_blocks
>
0
:
if
self
.
evictor
.
num_blocks
>
0
:
# here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
# Here we may have scenario that several blocks have
# the same content hash, but due to the latter coming block
# is coming from mutable to immutable path, their physical
# block is added into evictor.
# However in this case, we shall not pop the _cached_blocks,
# as the same content is still used by others, which means
# we need to check ref before decide to pop the list.
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
refcount
=
self
.
_refcounter
.
get
(
_block_id
)
assert
self
.
_refcounter
.
get
(
_block_id
)
==
0
if
refcount
==
1
:
assert
_block_id
==
block_id
self
.
_cached_blocks
.
pop
(
content_hash_to_evict
)
assert
_block_id
==
block_id
self
.
_cached_blocks
.
pop
(
content_hash_to_evict
)
self
.
_refcounter
.
incr
(
block_id
)
self
.
_refcounter
.
incr
(
block_id
)
...
@@ -199,7 +195,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -199,7 +195,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def
_incr_refcount_cached_block
(
self
,
block
:
Block
,
def
_incr_refcount_cached_block
(
self
,
block
:
Block
,
block_id
:
BlockId
)
->
None
:
block_id
:
BlockId
)
->
None
:
# since block is already computed, mark it
# now _incr_refcount_cached_block comes from two place
# allocate_immutable/promote_to_immutable_block where hit
# _cached_blocks hash key.
# In both cases, it means that already exists a already
# computed block which shared with block now
block
.
computed
=
True
block
.
computed
=
True
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
...
@@ -228,13 +228,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -228,13 +228,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block
:
Block
)
->
None
:
block
:
Block
)
->
None
:
assert
isinstance
(
block
,
PrefixCachingBlock
)
assert
isinstance
(
block
,
PrefixCachingBlock
)
if
block
.
content_hash
is
None
:
# if we comes from promote_to_immutable_block, it means that
# block.content_hash is never None.
# However we need to release the same content block, so that
# physical block could get reused.
if
block
.
block_id
!=
block_id
or
block
.
content_hash
is
None
:
refcount
=
self
.
_refcounter
.
get
(
block_id
)
refcount
=
self
.
_refcounter
.
get
(
block_id
)
# We have fork case where block would get more than one ref,
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
# so we cannot free it from tracking if ref cnt large than 1
if
refcount
<=
1
:
assert
block
.
block_id
is
not
None
assert
block
.
block_id
is
not
None
refcount
=
self
.
_refcounter
.
get
(
block
.
block_id
)
if
refcount
==
1
:
del
self
.
_blocks
[
block
.
block_id
]
del
self
.
_blocks
[
block
.
block_id
]
return
self
.
_hashless_allocator
.
free
(
block
)
return
self
.
_hashless_allocator
.
free
(
block
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
...
@@ -317,7 +323,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -317,7 +323,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
if
block
.
content_hash
not
in
self
.
_cached_blocks
:
if
block
.
content_hash
not
in
self
.
_cached_blocks
:
self
.
_cached_blocks
[
block
.
content_hash
]
=
block
.
block_id
self
.
_cached_blocks
[
block
.
content_hash
]
=
block
.
block_id
else
:
else
:
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
self
.
_free_block_id_for_block
(
self
.
_cached_blocks
[
block
.
content_hash
],
block
)
self
.
_incr_refcount_cached_block
(
self
.
_incr_refcount_cached_block
(
block
,
self
.
_cached_blocks
[
block
.
content_hash
])
block
,
self
.
_cached_blocks
[
block
.
content_hash
])
...
@@ -337,12 +344,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -337,12 +344,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
"""
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
def
clear_copy_on_writes
(
self
)
->
Dict
[
BlockId
,
List
[
BlockId
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Returns the copy-on-write source->destination mapping and clears it.
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Returns:
Dict
[BlockId,
List[
BlockId]]: A
dictionary
mapping source
List[Tuple
[BlockId, BlockId]]: A
list
mapping source
block indices to
lists of
destination block indices.
block indices to destination block indices.
"""
"""
return
self
.
_cow_tracker
.
clear_cows
()
return
self
.
_cow_tracker
.
clear_cows
()
...
...
vllm/core/block/utils.py
0 → 100644
View file @
b9e12416
"""Block manager utils."""
from
vllm.sequence
import
SequenceGroup
# Exception strings for non-implemented block manager enc/dec scenarios
STR_NOT_IMPL_ENC_DEC_SWA
=
\
"Sliding window attention for encoder/decoder models "
+
\
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
=
\
"Prefix caching for encoder/decoder models "
+
\
"is not currently supported."
def
_get_block_mgr_sliding_window_attr
(
block_mgr
):
'''
BlockManagerV1 and BlockManagerV2 have slightly different
members related to sliding window attention (SWA). This
function extracts the appropriate member to use for determining
whether SWA is enabled.
Arguments:
* block_mgr: BlockManagerV1 or BlockManagerV2 instance
'''
if
hasattr
(
block_mgr
,
'block_sliding_window'
):
return
block_mgr
.
block_sliding_window
if
hasattr
(
block_mgr
,
'max_block_sliding_window'
):
return
block_mgr
.
max_block_sliding_window
raise
AttributeError
(
"Block manager instance has neither "
+
\
"block_sliding_window nor "
+
\
"max_block_sliding_window attributes."
)
def
check_no_caching_or_swa_for_blockmgr_encdec
(
block_mgr
,
seq_group
:
SequenceGroup
)
->
None
:
'''
Enforce that prefix caching & sliding-window attention (SWA)
are currently unsupported *specifically* for encoder/decoder models.
Raises NotImplementedError if unsupported scenario is detected.
Arguments:
* block_mgr: BlockSpaceManager instance
* seq_group: SequenceGroup passed to block_mgr
'''
if
seq_group
.
is_encoder_decoder
():
if
_get_block_mgr_sliding_window_attr
(
block_mgr
)
is
not
None
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_SWA
)
if
block_mgr
.
enable_caching
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
)
vllm/core/block_manager_v1.py
View file @
b9e12416
...
@@ -5,9 +5,10 @@ from itertools import count, takewhile
...
@@ -5,9 +5,10 @@ from itertools import count, takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
from
typing
import
Set
,
Tuple
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.core.block.utils
import
check_no_caching_or_swa_for_blockmgr_encdec
from
vllm.core.evictor_v1
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor_v1
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -255,14 +256,30 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -255,14 +256,30 @@ class BlockSpaceManagerV1(BlockSpaceManager):
Device
.
CPU
,
block_size
,
num_cpu_blocks
)
Device
.
CPU
,
block_size
,
num_cpu_blocks
)
# Mapping: seq_id -> BlockTable.
# Mapping: seq_id -> BlockTable.
self
.
block_tables
:
Dict
[
int
,
BlockTable
]
=
{}
self
.
block_tables
:
Dict
[
int
,
BlockTable
]
=
{}
# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
self
.
cross_block_tables
:
Dict
[
str
,
BlockTable
]
=
{}
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
    """Return how many logical token blocks `seq` currently holds.

    Args:
        seq: The sequence to measure, or None. The None case is handled
            explicitly so callers can pass an absent sequence
            (presumably the encoder sequence of a decoder-only group --
            verify against `SequenceGroup.get_encoder_seq`).

    Returns:
        0 when `seq` is None, otherwise `len(seq.logical_token_blocks)`.
    """
    if seq is None:
        return 0
    return len(seq.logical_token_blocks)
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# FIXME(woosuk): Here we assume that all sequences in the group share
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
# the same prompt. This may not be true for preempted sequences.
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
num_required_blocks
=
len
(
seq
.
logical_token_blocks
)
check_no_caching_or_swa_for_blockmgr_encdec
(
self
,
seq_group
)
self_num_required_blocks
=
self
.
_get_seq_num_required_blocks
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
])
cross_num_required_blocks
=
self
.
_get_seq_num_required_blocks
(
seq_group
.
get_encoder_seq
())
num_required_blocks
=
self_num_required_blocks
+
\
cross_num_required_blocks
if
self
.
block_sliding_window
is
not
None
:
if
self
.
block_sliding_window
is
not
None
:
num_required_blocks
=
min
(
num_required_blocks
,
num_required_blocks
=
min
(
num_required_blocks
,
self
.
block_sliding_window
)
self
.
block_sliding_window
)
num_free_gpu_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
num_free_gpu_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
...
@@ -276,11 +293,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -276,11 +293,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else
:
else
:
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def
allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_allocate_sequence
(
self
,
\
# NOTE: Here we assume that all sequences in the group have the same
seq
:
Sequence
,
\
# prompt.
ref_count
:
int
,
\
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
is_encoder_decoder
:
bool
=
True
)
->
BlockTable
:
# Allocate new physical token blocks that will store the prompt tokens.
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks
=
len
(
seq
.
logical_token_blocks
)
num_prompt_blocks
=
len
(
seq
.
logical_token_blocks
)
...
@@ -290,21 +306,46 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -290,21 +306,46 @@ class BlockSpaceManagerV1(BlockSpaceManager):
and
logical_idx
>=
self
.
block_sliding_window
):
and
logical_idx
>=
self
.
block_sliding_window
):
block
=
block_table
[
logical_idx
%
self
.
block_sliding_window
]
block
=
block_table
[
logical_idx
%
self
.
block_sliding_window
]
# Set the reference counts of the token blocks.
# Set the reference counts of the token blocks.
block
.
ref_count
=
seq_group
.
num_seqs
()
block
.
ref_count
=
ref_count
elif
self
.
enable_caching
:
elif
not
is_encoder_decoder
and
self
.
enable_caching
:
block
=
self
.
gpu_allocator
.
allocate
(
block
=
self
.
gpu_allocator
.
allocate
(
seq
.
hash_of_block
(
logical_idx
),
seq
.
hash_of_block
(
logical_idx
),
seq
.
num_hashed_tokens_of_block
(
logical_idx
))
seq
.
num_hashed_tokens_of_block
(
logical_idx
))
else
:
else
:
block
=
self
.
gpu_allocator
.
allocate
()
block
=
self
.
gpu_allocator
.
allocate
()
# Set the reference counts of the token blocks.
# Set the reference counts of the token blocks.
block
.
ref_count
=
seq_group
.
num_seqs
()
block
.
ref_count
=
ref_count
block_table
.
append
(
block
)
block_table
.
append
(
block
)
# Assign the block table for each sequence.
return
block_table
def
allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
is_encoder_decoder
=
seq_group
.
is_encoder_decoder
()
check_no_caching_or_swa_for_blockmgr_encdec
(
self
,
seq_group
)
# Allocate decoder sequences
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
block_table
:
BlockTable
=
\
self
.
_allocate_sequence
(
seq
,
seq_group
.
num_seqs
(),
is_encoder_decoder
)
# Assign the self-attention block tables for each sequence.
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
# Allocate encoder sequence
if
is_encoder_decoder
:
# A SequenceGroup has only a single encoder sequence (at most),
# thus allocate with a ref count of 1
block_table
=
self
.
_allocate_sequence
(
seq_group
.
get_encoder_seq
(),
1
,
is_encoder_decoder
)
# Assign the cross-attention block table for the SequenceGroup.
self
.
cross_block_tables
[
seq_group
.
request_id
]
=
block_table
def
can_append_slots
(
self
,
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
num_lookahead_slots
:
int
=
0
)
->
bool
:
...
@@ -386,7 +427,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -386,7 +427,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
)
->
Dict
[
int
,
List
[
int
]]:
)
->
List
[
Tuple
[
int
,
int
]]:
"""Allocate a physical slot for a new token."""
"""Allocate a physical slot for a new token."""
logical_blocks
=
seq
.
logical_token_blocks
logical_blocks
=
seq
.
logical_token_blocks
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
...
@@ -405,7 +446,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -405,7 +446,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate a new physical block.
# Allocate a new physical block.
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
block_table
.
append
(
new_block
)
block_table
.
append
(
new_block
)
return
{}
return
[]
# We want to append the token to the last physical block.
# We want to append the token to the last physical block.
last_block
=
block_table
[
-
1
]
last_block
=
block_table
[
-
1
]
...
@@ -418,7 +459,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -418,7 +459,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
maybe_new_block
=
self
.
_maybe_promote_last_block
(
maybe_new_block
=
self
.
_maybe_promote_last_block
(
seq
,
last_block
)
seq
,
last_block
)
block_table
[
-
1
]
=
maybe_new_block
block_table
[
-
1
]
=
maybe_new_block
return
{}
return
[]
else
:
else
:
# The last block is shared with other sequences.
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
# Copy on Write: Allocate a new block and copy the tokens.
...
@@ -426,7 +467,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -426,7 +467,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table
[
-
1
]
=
new_block
block_table
[
-
1
]
=
new_block
self
.
gpu_allocator
.
free
(
last_block
)
self
.
gpu_allocator
.
free
(
last_block
)
return
{
last_block
.
block_number
:
[
new_block
.
block_number
]
}
return
[(
last_block
.
block_number
,
new_block
.
block_number
)
]
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
# NOTE: fork does not allocate a new physical block.
# NOTE: fork does not allocate a new physical block.
...
@@ -443,13 +484,18 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -443,13 +484,18 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
_get_physical_blocks
(
def
_get_physical_blocks
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
PhysicalTokenBlock
]:
self
,
seq_group
:
SequenceGroup
)
->
List
[
PhysicalTokenBlock
]:
# NOTE: Here, we assume that the physical blocks are only shared by
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
# the sequences in the same group.
request_id
=
seq_group
.
request_id
blocks
:
Set
[
PhysicalTokenBlock
]
=
set
()
blocks
:
Set
[
PhysicalTokenBlock
]
=
set
()
for
seq
in
seq_group
.
get_seqs
():
for
seq
in
seq_group
.
get_seqs
():
if
seq
.
is_finished
():
if
seq
.
is_finished
():
continue
continue
blocks
.
update
(
self
.
block_tables
[
seq
.
seq_id
])
blocks
.
update
(
self
.
block_tables
[
seq
.
seq_id
])
# Cross-attention blocks
if
seq_group
.
is_encoder_decoder
():
blocks
.
update
(
self
.
cross_block_tables
[
request_id
])
return
list
(
blocks
)
return
list
(
blocks
)
def
can_swap_in
(
self
,
def
can_swap_in
(
self
,
...
@@ -457,8 +503,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -457,8 +503,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots
:
int
=
0
)
->
AllocStatus
:
num_lookahead_slots
:
int
=
0
)
->
AllocStatus
:
assert
(
num_lookahead_slots
==
0
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
),
"BlockSpaceManagerV1 does not support lookahead allocation"
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
num_swapped_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
SWAPPED
)
num_swapped_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
SWAPPED
)
if
seq_group
.
is_encoder_decoder
():
num_swapped_seqs
+=
1
num_free_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
num_free_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
# NOTE: Conservatively, we assume that every sequence will allocate
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
# at least one free block right after the swap-in.
...
@@ -471,66 +520,81 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -471,66 +520,81 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else
:
else
:
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def _swap_block_table(
        self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
        dest_allocator: BlockAllocatorBase,
        mapping: Dict[PhysicalTokenBlock,
                      PhysicalTokenBlock]) -> BlockTable:
    """Rebuild `block_table` out of blocks owned by `dest_allocator`.

    Every source block is translated through `mapping`. A source block
    seen for the first time gets a freshly allocated destination block
    (recorded in `mapping`); a source block that is already mapped reuses
    its destination block and bumps that block's ref count. Each visited
    source block is handed back to `src_allocator`.

    Args:
        block_table: Blocks to move, in table order.
        src_allocator: Allocator the current blocks are freed to.
        dest_allocator: Allocator the replacement blocks come from.
        mapping: Source block -> destination block; updated in place so
            it can be shared across several calls.

    Returns:
        A new block table with one destination block per source block.
    """
    swapped: BlockTable = []
    for src_block in block_table:
        try:
            dest_block = mapping[src_block]
        except KeyError:
            # First time this source block is seen: allocate its
            # counterpart, carrying over the hash metadata.
            dest_block = dest_allocator.allocate(
                src_block.block_hash, src_block.num_hashed_tokens)
            mapping[src_block] = dest_block
        else:
            # Already swapped via another table entry: share the block.
            dest_block.ref_count += 1
        swapped.append(dest_block)
        # Release the source-side block now that it has a counterpart.
        src_allocator.free(src_block)
    return swapped
def
swap_in
(
self
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
=
0
)
->
List
[
Tuple
[
int
,
int
]
]
:
assert
(
num_lookahead_slots
==
0
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
),
"BlockSpaceManagerV1 does not support lookahead allocation"
request_id
=
seq_group
.
request_id
# CPU block -> GPU block.
# CPU block -> GPU block.
# dict is efficient in lookup `if cpu_block in mapping`
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
new_block_table
:
BlockTable
=
[]
self
.
block_tables
[
seq
.
seq_id
]
=
\
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
self
.
_swap_block_table
(
self
.
block_tables
[
seq
.
seq_id
],
self
.
cpu_allocator
,
for
cpu_block
in
block_table
:
self
.
gpu_allocator
,
if
cpu_block
in
mapping
:
mapping
)
gpu_block
=
mapping
[
cpu_block
]
gpu_block
.
ref_count
+=
1
if
seq_group
.
is_encoder_decoder
():
else
:
self
.
cross_block_tables
[
request_id
]
=
\
gpu_block
=
self
.
gpu_allocator
.
allocate
(
self
.
_swap_block_table
(
self
.
cross_block_tables
[
request_id
],
cpu_block
.
block_hash
,
cpu_block
.
num_hashed_tokens
)
self
.
cpu_allocator
,
mapping
[
cpu_block
]
=
gpu_block
self
.
gpu_allocator
,
new_block_table
.
append
(
gpu_block
)
mapping
)
# Free the CPU block swapped in to GPU.
self
.
cpu_allocator
.
free
(
cpu_block
)
return
[(
cpu_block
.
block_number
,
gpu_block
.
block_number
)
self
.
block_tables
[
seq
.
seq_id
]
=
new_block_table
for
cpu_block
,
gpu_block
in
mapping
.
items
()]
block_number_mapping
=
{
cpu_block
.
block_number
:
gpu_block
.
block_number
for
cpu_block
,
gpu_block
in
mapping
.
items
()
}
return
block_number_mapping
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
return
len
(
blocks
)
<=
self
.
cpu_allocator
.
get_num_free_blocks
()
return
len
(
blocks
)
<=
self
.
cpu_allocator
.
get_num_free_blocks
()
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
request_id
=
seq_group
.
request_id
# GPU block -> CPU block.
# GPU block -> CPU block.
# dict is efficient in lookup `if gpu_block in mapping`
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
new_block_table
:
BlockTable
=
[]
self
.
block_tables
[
seq
.
seq_id
]
=
\
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
self
.
_swap_block_table
(
self
.
block_tables
[
seq
.
seq_id
],
self
.
gpu_allocator
,
for
gpu_block
in
block_table
:
self
.
cpu_allocator
,
if
gpu_block
in
mapping
:
mapping
)
cpu_block
=
mapping
[
gpu_block
]
cpu_block
.
ref_count
+=
1
if
seq_group
.
is_encoder_decoder
():
else
:
self
.
cross_block_tables
[
request_id
]
=
\
cpu_block
=
self
.
cpu_allocator
.
allocate
(
self
.
_swap_block_table
(
self
.
cross_block_tables
[
request_id
],
gpu_block
.
block_hash
,
gpu_block
.
num_hashed_tokens
)
self
.
gpu_allocator
,
mapping
[
gpu_block
]
=
cpu_block
self
.
cpu_allocator
,
new_block_table
.
append
(
cpu_block
)
mapping
)
# Free the GPU block swapped out to CPU.
self
.
gpu_allocator
.
free
(
gpu_block
)
return
[(
cpu_block
.
block_number
,
gpu_block
.
block_number
)
self
.
block_tables
[
seq
.
seq_id
]
=
new_block_table
for
cpu_block
,
gpu_block
in
mapping
.
items
()]
block_number_mapping
=
{
gpu_block
.
block_number
:
cpu_block
.
block_number
for
gpu_block
,
cpu_block
in
mapping
.
items
()
}
return
block_number_mapping
def
_free_block_table
(
self
,
block_table
:
BlockTable
)
->
None
:
def
_free_block_table
(
self
,
block_table
:
BlockTable
)
->
None
:
# when using a sliding window, each seq will only use up
# when using a sliding window, each seq will only use up
...
@@ -555,15 +619,32 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -555,15 +619,32 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
.
_free_block_table
(
block_table
)
self
.
_free_block_table
(
block_table
)
del
self
.
block_tables
[
seq
.
seq_id
]
del
self
.
block_tables
[
seq
.
seq_id
]
def free_cross(self, seq_group: SequenceGroup) -> None:
    """Release the cross-attention block table kept for `seq_group`.

    The table is looked up by the group's request ID. If no table is
    registered under that ID the call is a no-op.
    """
    request_id = seq_group.request_id
    if request_id in self.cross_block_tables:
        # Free the blocks first, then forget the table.
        self._free_block_table(self.cross_block_tables[request_id])
        del self.cross_block_tables[request_id]
    # Otherwise: already freed or hasn't been scheduled yet.
def
reset
(
self
)
->
None
:
def
reset
(
self
)
->
None
:
# Free decoder block tables
for
block_table
in
self
.
block_tables
.
values
():
for
block_table
in
self
.
block_tables
.
values
():
self
.
_free_block_table
(
block_table
)
self
.
_free_block_table
(
block_table
)
self
.
block_tables
.
clear
()
self
.
block_tables
.
clear
()
# Free cross-attention block tables
for
block_table
in
self
.
cross_block_tables
.
values
():
self
.
_free_block_table
(
block_table
)
self
.
cross_block_tables
.
clear
()
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
return
[
block
.
block_number
for
block
in
block_table
]
return
[
block
.
block_number
for
block
in
block_table
]
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
    """Return the block numbers of `seq_group`'s cross-attention table.

    The table is looked up by the group's request ID; a missing entry
    raises KeyError. Block numbers are returned in table order.
    """
    block_numbers: List[int] = []
    for physical_block in self.cross_block_tables[seq_group.request_id]:
        block_numbers.append(physical_block.block_number)
    return block_numbers
def
get_num_free_gpu_blocks
(
self
)
->
int
:
def
get_num_free_gpu_blocks
(
self
)
->
int
:
return
self
.
gpu_allocator
.
get_num_free_blocks
()
return
self
.
gpu_allocator
.
get_num_free_blocks
()
...
...
vllm/core/block_manager_v2.py
View file @
b9e12416
"""A block manager that manages token blocks."""
"""A block manager that manages token blocks."""
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.utils
import
check_no_caching_or_swa_for_blockmgr_encdec
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
from
vllm.utils
import
Device
from
vllm.utils
import
Device
SeqId
=
int
SeqId
=
int
EncoderSeqId
=
str
class
BlockSpaceManagerV2
(
BlockSpaceManager
):
class
BlockSpaceManagerV2
(
BlockSpaceManager
):
...
@@ -65,9 +68,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -65,9 +68,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
num_total_gpu_blocks
=
num_gpu_blocks
self
.
num_total_gpu_blocks
=
num_gpu_blocks
self
.
num_total_cpu_blocks
=
num_cpu_blocks
self
.
num_total_cpu_blocks
=
num_cpu_blocks
assert
sliding_window
is
None
,
"Sliding window not yet supported"
self
.
sliding_window
=
sliding_window
# max_block_sliding_window is the max number of blocks that need to be
self
.
block_sliding_window
=
None
# allocated
self
.
max_block_sliding_window
=
None
if
sliding_window
is
not
None
:
# +1 here because // rounds down
num_blocks
=
sliding_window
//
block_size
+
1
# +1 here because the last block may not be full,
# and so the sequence stretches one more block at the beginning
# For example, if sliding_window is 3 and block_size is 4,
# we may need 2 blocks when the second block only holds 1 token.
self
.
max_block_sliding_window
=
num_blocks
+
1
self
.
watermark
=
watermark
self
.
watermark
=
watermark
assert
watermark
>=
0.0
assert
watermark
>=
0.0
...
@@ -84,21 +96,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -84,21 +96,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
)
)
self
.
block_tables
:
Dict
[
SeqId
,
BlockTable
]
=
{}
self
.
block_tables
:
Dict
[
SeqId
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# FIXME(woosuk): Here we assume that all sequences in the group share
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
# the same prompt. This may not be true for preempted sequences.
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
check_no_caching_or_swa_for_blockmgr_encdec
(
self
,
seq_group
)
seq
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)[
0
]
num_required_blocks
=
BlockTable
.
get_num_required_blocks
(
num_required_blocks
=
BlockTable
.
get_num_required_blocks
(
seq
.
get_token_ids
(),
seq
.
get_token_ids
(),
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
)
)
assert
self
.
block_sliding_window
is
None
if
seq_group
.
is_encoder_decoder
():
if
self
.
block_sliding_window
is
not
None
:
num_required_blocks
+=
BlockTable
.
get_num_required_blocks
(
seq_group
.
get_encoder_seq
().
get_token_ids
(),
block_size
=
self
.
block_size
,
)
if
self
.
max_block_sliding_window
is
not
None
:
num_required_blocks
=
min
(
num_required_blocks
,
num_required_blocks
=
min
(
num_required_blocks
,
self
.
block_sliding_window
)
self
.
max_
block_sliding_window
)
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
device
=
Device
.
GPU
)
device
=
Device
.
GPU
)
...
@@ -112,7 +132,19 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -112,7 +132,19 @@ class BlockSpaceManagerV2(BlockSpaceManager):
else
:
else
:
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
    """Create a block table for `seq` and allocate its token blocks.

    The table is built with this manager's block size, block allocator
    and `max_block_sliding_window`, then allocated for the token IDs
    currently in `seq`.

    Returns:
        The newly allocated BlockTable. It is not registered in any of
        the manager's dictionaries here; that is left to the caller.
    """
    new_table = BlockTable(
        block_size=self.block_size,
        block_allocator=self.block_allocator,
        max_block_sliding_window=self.max_block_sliding_window,
    )
    token_ids = seq.get_token_ids()
    new_table.allocate(token_ids)
    return new_table
def
allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Allocate self-attention block tables for decoder sequences
waiting_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)
waiting_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
)
assert
not
(
set
(
seq
.
seq_id
for
seq
in
waiting_seqs
)
assert
not
(
set
(
seq
.
seq_id
for
seq
in
waiting_seqs
)
&
self
.
block_tables
.
keys
()),
"block table already exists"
&
self
.
block_tables
.
keys
()),
"block table already exists"
...
@@ -120,19 +152,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -120,19 +152,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# NOTE: Here we assume that all sequences in the group have the same
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
# prompt.
seq
=
waiting_seqs
[
0
]
seq
=
waiting_seqs
[
0
]
block_table
:
BlockTable
=
self
.
_allocate_sequence
(
seq
)
block_table
=
BlockTable
(
block_size
=
self
.
block_size
,
block_allocator
=
self
.
block_allocator
,
)
assert
self
.
block_sliding_window
is
None
block_table
.
allocate
(
seq
.
get_token_ids
())
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
# Assign the block table for each sequence.
# Assign the block table for each sequence.
for
seq
in
waiting_seqs
[
1
:]:
for
seq
in
waiting_seqs
[
1
:]:
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
# encoder prompt.
request_id
=
seq_group
.
request_id
assert
(
request_id
not
in
self
.
cross_block_tables
),
\
"block table already exists"
check_no_caching_or_swa_for_blockmgr_encdec
(
self
,
seq_group
)
if
seq_group
.
is_encoder_decoder
():
block_table
=
self
.
_allocate_sequence
(
seq_group
.
get_encoder_seq
())
self
.
cross_block_tables
[
request_id
]
=
block_table
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
num_lookahead_slots
:
int
)
->
bool
:
"""Determine if there is enough space in the GPU KV cache to continue
"""Determine if there is enough space in the GPU KV cache to continue
...
@@ -166,13 +208,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -166,13 +208,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
num_lookahead_slots
:
int
,
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
)
->
List
[
Tuple
[
int
,
int
]]:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
.
append_token_ids
(
block_table
.
append_token_ids
(
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
num_computed_slots
=
seq
.
data
.
get_num_computed_tokens
(),
)
)
# Return any new copy-on-writes.
# Return any new copy-on-writes.
...
@@ -186,12 +229,27 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -186,12 +229,27 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
block_tables
[
seq
.
seq_id
].
free
()
self
.
block_tables
[
seq
.
seq_id
].
free
()
del
self
.
block_tables
[
seq
.
seq_id
]
del
self
.
block_tables
[
seq
.
seq_id
]
def free_cross(self, seq_group: SequenceGroup) -> None:
    """Free and drop the cross-attention block table of `seq_group`.

    Tables are keyed by the group's request ID. Nothing happens when no
    table is registered for that ID.
    """
    key = seq_group.request_id
    cross_tables = self.cross_block_tables
    if key in cross_tables:
        # Free the table's blocks before removing the entry.
        cross_tables[key].free()
        del cross_tables[key]
    # Otherwise: already freed or hasn't been scheduled yet.
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
assert
seq
.
seq_id
in
self
.
block_tables
assert
seq
.
seq_id
in
self
.
block_tables
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
assert
all
(
b
is
not
None
for
b
in
block_ids
)
assert
all
(
b
is
not
None
for
b
in
block_ids
)
return
block_ids
# type: ignore
return
block_ids
# type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
    """Return the physical block IDs of `seq_group`'s cross-attention table.

    The table is looked up by the group's request ID. Asserts that a
    table exists for that ID and that none of its block IDs is None.
    """
    key = seq_group.request_id
    cross_tables = self.cross_block_tables
    assert key in cross_tables
    physical_ids = cross_tables[key].physical_block_ids
    # Every entry must already be backed by a physical block.
    assert not any(block_id is None for block_id in physical_ids)
    return physical_ids  # type: ignore
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
# Update the last accessed time of all the blocks accessed
# Update the last accessed time of all the blocks accessed
# in this step.
# in this step.
...
@@ -242,13 +300,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -242,13 +300,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
)
->
List
[
Tuple
[
int
,
int
]
]
:
raise
NotImplementedError
raise
NotImplementedError
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
return
False
return
False
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]
]
:
raise
NotImplementedError
raise
NotImplementedError
def
get_num_free_gpu_blocks
(
self
)
->
int
:
def
get_num_free_gpu_blocks
(
self
)
->
int
:
...
...
vllm/core/embedding_model_block_manager.py
0 → 100644
View file @
b9e12416
from
typing
import
List
,
Tuple
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
    """No-op BlockSpaceManager for embedding models.

    Exposes the BlockSpaceManager interface without tracking any blocks:
    capacity checks always succeed, allocate/fork/free do nothing, the
    free-block counters report a constant 1, and the methods that would
    normally return block tables or block mappings return None. Intended
    for embedding environments, where block management is unnecessary
    overhead.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        # Any keyword arguments are accepted and ignored; no state is kept.
        return None

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Always report that `seq_group` can be allocated."""
        return AllocStatus.OK

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Do nothing; there are no blocks to allocate."""
        return None

    def can_append_slots(self, seq_group: SequenceGroup,
                         num_lookahead_slots: int) -> bool:
        """Always report that slots can be appended."""
        return True

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int,
    ) -> List[Tuple[int, int]]:
        """Do nothing. NOTE: returns None, not a list, despite the
        annotation."""
        return None  # type: ignore

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Do nothing; there is no block table to share."""
        return None

    def can_swap_in(self, seq_group: SequenceGroup,
                    num_lookahead_slots: int) -> AllocStatus:
        """Always report that `seq_group` can be swapped in."""
        return AllocStatus.OK

    def swap_in(self, seq_group: SequenceGroup,
                num_lookahead_slots: int) -> List[Tuple[int, int]]:
        """Do nothing. NOTE: returns None, not a list, despite the
        annotation."""
        return None  # type: ignore

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Always report that `seq_group` can be swapped out."""
        return True

    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
        """Do nothing. NOTE: returns None, not a list, despite the
        annotation."""
        return None  # type: ignore

    def free(self, seq: Sequence) -> None:
        """Do nothing; no blocks are held for `seq`."""
        return None

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return None in place of a block table (no table exists)."""
        return None  # type: ignore

    def get_num_free_gpu_blocks(self) -> int:
        """Return the constant 1 as the free GPU block count."""
        return 1

    def get_num_free_cpu_blocks(self) -> int:
        """Return the constant 1 as the free CPU block count."""
        return 1

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Do nothing; there are no blocks whose access time to update."""
        return None

    def get_common_computed_block_ids(self,
                                      seq_group: SequenceGroup) -> List[int]:
        """Return None in place of a list of computed block IDs."""
        return None  # type: ignore

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Do nothing; there are no blocks to mark."""
        return None
vllm/core/interfaces.py
View file @
b9e12416
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
vllm.sequence
import
Sequence
,
SequenceGroup
from
vllm.sequence
import
Sequence
,
SequenceGroup
...
@@ -34,6 +35,11 @@ class BlockSpaceManager(ABC):
...
@@ -34,6 +35,11 @@ class BlockSpaceManager(ABC):
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
return
BlockSpaceManagerV2
return
BlockSpaceManagerV2
if
version
==
"embedding"
:
from
vllm.core.embedding_model_block_manager
import
(
EmbeddingModelBlockSpaceManager
)
return
EmbeddingModelBlockSpaceManager
raise
ValueError
(
f
"Unknown version
{
version
=
}
"
)
raise
ValueError
(
f
"Unknown version
{
version
=
}
"
)
@
abstractmethod
@
abstractmethod
...
@@ -54,7 +60,7 @@ class BlockSpaceManager(ABC):
...
@@ -54,7 +60,7 @@ class BlockSpaceManager(ABC):
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
num_lookahead_slots
:
int
,
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
)
->
List
[
Tuple
[
int
,
int
]]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -68,7 +74,7 @@ class BlockSpaceManager(ABC):
...
@@ -68,7 +74,7 @@ class BlockSpaceManager(ABC):
@
abstractmethod
@
abstractmethod
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
)
->
List
[
Tuple
[
int
,
int
]
]
:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -76,7 +82,7 @@ class BlockSpaceManager(ABC):
...
@@ -76,7 +82,7 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]
]
:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
b9e12416
...
@@ -13,7 +13,6 @@ from vllm.logger import init_logger
...
@@ -13,7 +13,6 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.utils
import
merge_dicts
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -118,18 +117,19 @@ class SchedulerOutputs:
...
@@ -118,18 +117,19 @@ class SchedulerOutputs:
num_prefill_groups
:
int
num_prefill_groups
:
int
# Total number of batched tokens.
# Total number of batched tokens.
num_batched_tokens
:
int
num_batched_tokens
:
int
# Blocks to swap in.
Dic
t of CPU -> GPU block number.
# Blocks to swap in.
Lis
t of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
# Blocks to swap out.
Dic
t of GPU -> CPU block number.
# Blocks to swap out.
Lis
t of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
# Blocks to copy. Source to
a list of
dest block
s
.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# Sequence groups that are going to be ignored.
# Sequence groups that are going to be ignored.
ignored_seq_groups
:
List
[
SequenceGroup
]
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# The number of requests in the running queue
# The number of requests in the running queue
running_queue_size
:
int
running_queue_size
:
int
preempted
:
int
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
...
@@ -175,9 +175,9 @@ class SchedulerRunningOutputs:
...
@@ -175,9 +175,9 @@ class SchedulerRunningOutputs:
# Sequences that are swapped out.
# Sequences that are swapped out.
swapped_out
:
List
[
SequenceGroup
]
swapped_out
:
List
[
SequenceGroup
]
# The blocks to swap out.
# The blocks to swap out.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
...
@@ -188,8 +188,8 @@ class SchedulerRunningOutputs:
...
@@ -188,8 +188,8 @@ class SchedulerRunningOutputs:
prefill_seq_groups
=
[],
prefill_seq_groups
=
[],
preempted
=
[],
preempted
=
[],
swapped_out
=
[],
swapped_out
=
[],
blocks_to_swap_out
=
{}
,
blocks_to_swap_out
=
[]
,
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
)
)
...
@@ -207,9 +207,9 @@ class SchedulerSwappedInOutputs:
...
@@ -207,9 +207,9 @@ class SchedulerSwappedInOutputs:
# phase. I.e., it means the prefill has been chunked.
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups
:
List
[
SequenceGroup
]
prefill_seq_groups
:
List
[
SequenceGroup
]
# The blocks to swap in.
# The blocks to swap in.
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# Infeasible sequence groups.
# Infeasible sequence groups.
...
@@ -220,8 +220,8 @@ class SchedulerSwappedInOutputs:
...
@@ -220,8 +220,8 @@ class SchedulerSwappedInOutputs:
return
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
[],
decode_seq_groups
=
[],
prefill_seq_groups
=
[],
prefill_seq_groups
=
[],
blocks_to_swap_in
=
{}
,
blocks_to_swap_in
=
[]
,
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
infeasible_seq_groups
=
[],
infeasible_seq_groups
=
[],
)
)
...
@@ -264,16 +264,14 @@ class Scheduler:
...
@@ -264,16 +264,14 @@ class Scheduler:
# LoRAs. This should be improved in the future.
# LoRAs. This should be improved in the future.
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
if
self
.
scheduler_config
.
chunked_prefill_enabled
:
version
=
"v1"
self
.
prompt_limit
=
self
.
scheduler_config
.
max_model_len
if
self
.
scheduler_config
.
use_v2_block_manager
:
else
:
version
=
"v2"
self
.
prompt_limit
=
min
(
if
self
.
scheduler_config
.
embedding_mode
:
self
.
scheduler_config
.
max_model_len
,
version
=
"embedding"
self
.
scheduler_config
.
max_num_batched_tokens
)
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
version
=
"v2"
if
self
.
scheduler_config
.
version
)
use_v2_block_manager
else
"v1"
)
# Create the block space manager.
# Create the block space manager.
self
.
block_manager
=
BlockSpaceManagerImpl
(
self
.
block_manager
=
BlockSpaceManagerImpl
(
...
@@ -306,6 +304,7 @@ class Scheduler:
...
@@ -306,6 +304,7 @@ class Scheduler:
self
.
artificial_preempt_cnt
=
(
ARTIFICIAL_PREEMPTION_MAX_CNT
self
.
artificial_preempt_cnt
=
(
ARTIFICIAL_PREEMPTION_MAX_CNT
if
self
.
enable_artificial_preemption
if
self
.
enable_artificial_preemption
else
0
)
else
0
)
self
.
num_cumulative_preemption
:
int
=
0
@
property
@
property
def
lora_enabled
(
self
)
->
bool
:
def
lora_enabled
(
self
)
->
bool
:
...
@@ -393,8 +392,8 @@ class Scheduler:
...
@@ -393,8 +392,8 @@ class Scheduler:
scheduling and SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
=
[]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
{}
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
...
@@ -510,8 +509,8 @@ class Scheduler:
...
@@ -510,8 +509,8 @@ class Scheduler:
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
=
[]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
{}
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
now
=
time
.
time
()
...
@@ -590,6 +589,21 @@ class Scheduler:
...
@@ -590,6 +589,21 @@ class Scheduler:
infeasible_seq_groups
=
infeasible_seq_groups
,
infeasible_seq_groups
=
infeasible_seq_groups
,
)
)
def
_get_prompt_limit
(
self
,
seq_group
:
SequenceGroup
)
->
int
:
if
self
.
scheduler_config
.
chunked_prefill_enabled
:
prompt_limit
=
self
.
scheduler_config
.
max_model_len
else
:
prompt_limit
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_num_batched_tokens
)
# Model is fine tuned with long context. Return the fine tuned max_len.
if
(
seq_group
.
lora_request
and
seq_group
.
lora_request
.
long_lora_max_len
):
assert
prompt_limit
<=
seq_group
.
lora_request
.
long_lora_max_len
return
seq_group
.
lora_request
.
long_lora_max_len
else
:
return
prompt_limit
def
_schedule_prefills
(
def
_schedule_prefills
(
self
,
self
,
waiting_queue
:
deque
,
waiting_queue
:
deque
,
...
@@ -644,11 +658,11 @@ class Scheduler:
...
@@ -644,11 +658,11 @@ class Scheduler:
num_prompt_tokens
=
waiting_seqs
[
0
].
get_len
()
num_prompt_tokens
=
waiting_seqs
[
0
].
get_len
()
assert
num_new_tokens
==
num_prompt_tokens
assert
num_new_tokens
==
num_prompt_tokens
if
num_new_tokens
>
self
.
prompt_limit
:
prompt_limit
=
self
.
_get_prompt_limit
(
seq_group
)
if
num_new_tokens
>
prompt_limit
:
logger
.
warning
(
logger
.
warning
(
"Input prompt (%d tokens) is too long"
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d"
,
num_new_tokens
,
" and exceeds limit of %d"
,
num_new_tokens
,
prompt_limit
)
self
.
prompt_limit
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
...
@@ -730,8 +744,8 @@ class Scheduler:
...
@@ -730,8 +744,8 @@ class Scheduler:
budget
.
add_num_seqs
(
seq_group
.
request_id
,
budget
.
add_num_seqs
(
seq_group
.
request_id
,
seq_group
.
get_max_num_running_seqs
())
seq_group
.
get_max_num_running_seqs
())
curr_loras
=
set
(
curr_loras
=
set
(
seq_group
.
lora_int_id
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
f
or
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
i
f
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
remaining_waiting
,
prefills
=
(
self
.
waiting
,
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
SchedulerPrefillOutputs
.
create_empty
())
...
@@ -781,6 +795,8 @@ class Scheduler:
...
@@ -781,6 +795,8 @@ class Scheduler:
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
))
# There should be no prefill from running queue because this policy
# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
# doesn't allow chunked prefills.
...
@@ -794,12 +810,13 @@ class Scheduler:
...
@@ -794,12 +810,13 @@ class Scheduler:
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
swapped_in
.
blocks_to_copy
)
,
swapped_in
.
blocks_to_copy
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
swapped_in
.
infeasible_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
running_queue_size
=
len
(
self
.
running
),
preempted
=
preempted
,
)
)
def
_schedule_chunked_prefill
(
self
):
def
_schedule_chunked_prefill
(
self
):
...
@@ -882,11 +899,13 @@ class Scheduler:
...
@@ -882,11 +899,13 @@ class Scheduler:
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
swapped_in
.
blocks_to_copy
)
,
swapped_in
.
blocks_to_copy
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
running_queue_size
=
len
(
self
.
running
),
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)),
)
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
def
_schedule
(
self
)
->
SchedulerOutputs
:
...
@@ -969,6 +988,7 @@ class Scheduler:
...
@@ -969,6 +988,7 @@ class Scheduler:
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
do_sample
=
do_sample
,
do_sample
=
do_sample
,
pooling_params
=
seq_group
.
pooling_params
,
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
...
@@ -1011,32 +1031,29 @@ class Scheduler:
...
@@ -1011,32 +1031,29 @@ class Scheduler:
def
_append_slots
(
def
_append_slots
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
)
->
None
:
"""Appends new slots to the sequences in the given sequence group.
"""Appends new slots to the sequences in the given sequence group.
Args:
Args:
seq_group (SequenceGroup): The sequence group containing the
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
block indices to lists of destination block indices. This
ints, the first int is the source block index, and the second
dictionary is updated with the new source and destination block
int is the destination block index. This list is updated with
indices for the appended slots.
the new source and destination block indices for the appended
slots.
"""
"""
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
blocks_to_copy
.
extend
(
cows
)
for
src
,
dests
in
cows
.
items
():
if
src
not
in
blocks_to_copy
:
blocks_to_copy
[
src
]
=
[]
blocks_to_copy
[
src
].
extend
(
dests
)
def
_preempt
(
def
_preempt
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
)
->
PreemptionMode
:
)
->
PreemptionMode
:
# If preemption mode is not specified, we determine the mode as follows:
# If preemption mode is not specified, we determine the mode as follows:
...
@@ -1055,6 +1072,17 @@ class Scheduler:
...
@@ -1055,6 +1072,17 @@ class Scheduler:
preemption_mode
=
PreemptionMode
.
RECOMPUTE
preemption_mode
=
PreemptionMode
.
RECOMPUTE
else
:
else
:
preemption_mode
=
PreemptionMode
.
SWAP
preemption_mode
=
PreemptionMode
.
SWAP
if
self
.
num_cumulative_preemption
%
50
==
0
:
logger
.
warning
(
"Sequence group %s is preempted by %s mode because there is "
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d"
,
seq_group
.
request_id
,
preemption_mode
,
self
.
num_cumulative_preemption
+
1
)
self
.
num_cumulative_preemption
+=
1
if
preemption_mode
==
PreemptionMode
.
RECOMPUTE
:
if
preemption_mode
==
PreemptionMode
.
RECOMPUTE
:
self
.
_preempt_by_recompute
(
seq_group
)
self
.
_preempt_by_recompute
(
seq_group
)
elif
preemption_mode
==
PreemptionMode
.
SWAP
:
elif
preemption_mode
==
PreemptionMode
.
SWAP
:
...
@@ -1077,24 +1105,24 @@ class Scheduler:
...
@@ -1077,24 +1105,24 @@ class Scheduler:
def
_preempt_by_swap
(
def
_preempt_by_swap
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
def
_swap_in
(
def
_swap_in
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
blocks_to_swap_in
.
update
(
mapping
)
blocks_to_swap_in
.
extend
(
mapping
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
def
_swap_out
(
def
_swap_out
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
if
not
self
.
block_manager
.
can_swap_out
(
seq_group
):
if
not
self
.
block_manager
.
can_swap_out
(
seq_group
):
# FIXME(woosuk): Abort the sequence group instead of aborting the
# FIXME(woosuk): Abort the sequence group instead of aborting the
...
@@ -1103,7 +1131,7 @@ class Scheduler:
...
@@ -1103,7 +1131,7 @@ class Scheduler:
"Aborted due to the lack of CPU swap space. Please increase "
"Aborted due to the lack of CPU swap space. Please increase "
"the swap space to avoid this error."
)
"the swap space to avoid this error."
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
blocks_to_swap_out
.
update
(
mapping
)
blocks_to_swap_out
.
extend
(
mapping
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
seq
.
status
=
SequenceStatus
.
SWAPPED
seq
.
status
=
SequenceStatus
.
SWAPPED
...
...
vllm/distributed/communication_op.py
View file @
b9e12416
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
.parallel_state
import
(
get_cpu_world_group
,
from
.parallel_state
import
(
get_cpu_world_group
,
get_pp_pynccl_communicator
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
is_pynccl_enabled_for_all_reduce
)
get_tp_ca_communicator
,
get_tp_pynccl_communicator
)
@
dataclass
class
GraphCaptureContext
:
stream
:
torch
.
cuda
.
Stream
@
contextmanager
def
graph_capture
():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
ca_comm
=
get_tp_ca_communicator
()
maybe_ca_context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor
# size is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
tp_pynccl_comm
=
get_tp_pynccl_communicator
()
pp_pynccl_comm
=
get_pp_pynccl_communicator
()
if
not
tp_pynccl_comm
:
maybe_tp_pynccl_context
=
nullcontext
()
else
:
maybe_tp_pynccl_context
=
tp_pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
if
not
pp_pynccl_comm
:
maybe_pp_pynccl_context
=
nullcontext
()
else
:
maybe_pp_pynccl_context
=
pp_pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
maybe_tp_pynccl_context
,
maybe_pp_pynccl_context
:
yield
graph_capture_context
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -23,18 +82,18 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
@@ -23,18 +82,18 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
TLDR: always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
"""
from
vllm.distributed.device_communicators
import
pynccl_utils
ca_comm
=
get_tp_ca_communicator
()
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
custom_all_reduce
)
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
get_tensor_model_parallel_world_size
()
==
1
:
if
get_tensor_model_parallel_world_size
()
==
1
:
return
input_
return
input_
out
=
custom_all_reduce
(
input_
)
if
ca_comm
is
not
None
:
if
out
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
return
out
if
out
is
not
None
:
if
is_pynccl_enabled_for_all_reduce
():
return
out
pynccl_utils
.
all_reduce
(
input_
)
pynccl_comm
=
get_tp_pynccl_communicator
()
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
else
:
else
:
torch
.
distributed
.
all_reduce
(
input_
,
torch
.
distributed
.
all_reduce
(
input_
,
group
=
get_tensor_model_parallel_group
())
group
=
get_tensor_model_parallel_group
())
...
@@ -137,7 +196,7 @@ def broadcast_object_list(obj_list: List[Any],
...
@@ -137,7 +196,7 @@ def broadcast_object_list(obj_list: List[Any],
return
obj_list
return
obj_list
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"device"
,
"dtype"
,
"size"
])
def
_split_tensor_dict
(
def
_split_tensor_dict
(
...
@@ -152,15 +211,13 @@ def _split_tensor_dict(
...
@@ -152,15 +211,13 @@ def _split_tensor_dict(
tensor_list
=
[]
tensor_list
=
[]
for
key
,
value
in
tensor_dict
.
items
():
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
if
isinstance
(
value
,
torch
.
Tensor
):
# Note(youkaichao): currently this only supports broadcasting
# Note: we cannot use `value.device` here,
# tensors on cuda. In the future, we can add device as a field in
# because it contains not only the device type but also the device
# TensorMetadata to support broadcasting tensors on different
# index (e.g. "cuda:0"). We only need the device type.
# devices.
# receiving side will set the device index.
assert
value
.
is_cuda
,
(
device
=
"cpu"
if
value
.
is_cpu
else
"cuda"
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
metadata_list
.
append
(
f
"support broadcasting tensors on cuda."
)
(
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
metadata_list
.
append
((
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
tensor_list
.
append
(
value
)
else
:
else
:
metadata_list
.
append
((
key
,
value
))
metadata_list
.
append
((
key
,
value
))
...
@@ -178,16 +235,16 @@ def broadcast_tensor_dict(
...
@@ -178,16 +235,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
dtypes).
"""
"""
# Bypass the function if we are using only 1 GPU.
if
(
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
(
group
=
group
)
==
1
):
return
tensor_dict
group
=
group
or
torch
.
distributed
.
group
.
WORLD
group
=
group
or
torch
.
distributed
.
group
.
WORLD
metadata_group
=
metadata_group
or
get_cpu_world_group
()
metadata_group
=
metadata_group
or
get_cpu_world_group
()
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
group
)
if
world_size
==
1
:
return
tensor_dict
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
if
rank
==
src
:
if
rank
==
src
:
metadata_list
:
List
[
Tuple
[
Any
,
Any
]]
=
[]
metadata_list
:
List
[
Tuple
[
Any
,
Any
]]
=
[]
...
@@ -203,11 +260,22 @@ def broadcast_tensor_dict(
...
@@ -203,11 +260,22 @@ def broadcast_tensor_dict(
group
=
metadata_group
)
group
=
metadata_group
)
async_handles
=
[]
async_handles
=
[]
for
tensor
in
tensor_list
:
for
tensor
in
tensor_list
:
async_handles
.
append
(
if
tensor
.
numel
()
==
0
:
torch
.
distributed
.
broadcast
(
tensor
,
# Skip broadcasting empty tensors.
src
=
src
,
continue
group
=
group
,
if
tensor
.
is_cpu
:
async_op
=
True
))
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
for
async_handle
in
async_handles
:
for
async_handle
in
async_handles
:
async_handle
.
wait
()
async_handle
.
wait
()
...
@@ -223,12 +291,24 @@ def broadcast_tensor_dict(
...
@@ -223,12 +291,24 @@ def broadcast_tensor_dict(
if
isinstance
(
value
,
TensorMetadata
):
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
torch
.
empty
(
value
.
size
,
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
dtype
=
value
.
dtype
,
device
=
"cuda"
)
device
=
value
.
device
)
async_handle
=
torch
.
distributed
.
broadcast
(
tensor
,
if
tensor
.
numel
()
==
0
:
src
=
src
,
# Skip broadcasting empty tensors.
async_op
=
True
,
tensor_dict
[
key
]
=
tensor
group
=
group
)
continue
async_handles
.
append
(
async_handle
)
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
tensor_dict
[
key
]
=
tensor
tensor_dict
[
key
]
=
tensor
else
:
else
:
tensor_dict
[
key
]
=
value
tensor_dict
[
key
]
=
value
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
b9e12416
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
(
get_local_rank
,
get_tensor_model_parallel_cpu_group
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
try
:
try
:
import
pynvml
import
pynvml
from
vllm._C
import
custom_ar
from
vllm._C
import
custom_ar
@
contextmanager
def
_nvml
():
try
:
pynvml
.
nvmlInit
()
yield
finally
:
pynvml
.
nvmlShutdown
()
except
ImportError
:
except
ImportError
:
# For AMD GPUs
# For AMD GPUs
custom_ar
=
None
custom_ar
=
None
pynvml
=
None
pynvml
=
None
logger
=
init_logger
(
__name__
)
@
contextmanager
def
_nvml
():
try
:
yield
finally
:
pass
_CA_HANDLE
:
Optional
[
"CustomAllreduce"
]
=
None
_IS_CAPTURING
=
False
logger
=
init_logger
(
__name__
)
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
def
init_custom_ar
()
->
None
:
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
global
_CA_HANDLE
if
_CA_HANDLE
is
not
None
:
return
rank
=
get_tensor_model_parallel_rank
()
world_size
=
get_tensor_model_parallel_world_size
()
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
_SUPPORTED_WORLD_SIZES
:
logger
.
warning
(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
_SUPPORTED_WORLD_SIZES
))
return
num_dev
=
torch
.
cuda
.
device_count
()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if
num_dev
<
world_size
:
logger
.
warning
(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set."
)
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
num_dev
))
# this checks hardware and driver support for NVLink
full_nvlink
=
_is_full_nvlink
(
device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warning
(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
,
full_nvlink
)
def
begin_capture
()
->
None
:
global
_IS_CAPTURING
_IS_CAPTURING
=
True
def
end_capture
()
->
None
:
global
_IS_CAPTURING
_IS_CAPTURING
=
False
def
is_capturing
()
->
bool
:
return
_IS_CAPTURING
and
_CA_HANDLE
is
not
None
def
get_handle
()
->
Optional
[
"CustomAllreduce"
]:
return
_CA_HANDLE
def
is_initialized
()
->
bool
:
return
_CA_HANDLE
is
not
None
@contextmanager
def capture():
    """Context manager wrapping a capture region.

    On entry the capture flag is raised via ``begin_capture()``. On exit
    (normal or exceptional) the flag is lowered via ``end_capture()`` and,
    when a handle exists, ``register_graph_buffers()`` is called on it.
    """
    try:
        begin_capture()
        yield
    finally:
        end_capture()
        ca_handle = get_handle()
        if ca_handle is not None:
            ca_handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
    """Run the custom allreduce on ``input`` when it is applicable.

    Returns:
        The out-of-place result tensor, or ``None`` when no handle exists
        or the handle's ``should_custom_ar`` rejects ``input``; the caller
        is then expected to use another allreduce path.
    """
    handle = get_handle()
    if handle is None:
        # no handle was ever created, nothing we can do here
        return None

    if not is_capturing():
        # note: outside of cuda graph context,
        # custom allreduce incurs a cost of cudaMemcpy, which should
        # be small(<=1% of overall latency) compared to the performance
        # gains of using custom kernels
        if handle.should_custom_ar(input):
            return handle.all_reduce_unreg(input)
        return None

    if torch.cuda.is_current_stream_capturing():
        # the stream is really recording: use the registered-buffer path
        if handle.should_custom_ar(input):
            return handle.all_reduce_reg(input)
        return None

    # capture region entered but the stream is not recording (warm up):
    # mimic the allocation pattern since custom allreduce is out-of-place
    if handle.should_custom_ar(input):
        return torch.empty_like(input)
    return None
@contextmanager
def _nvml():
    """Context manager (also usable as a decorator) bracketing NVML use.

    Calls ``pynvml.nvmlInit()`` on entry and ``pynvml.nvmlShutdown()`` on
    exit; the shutdown also runs when the body or the init itself raises.
    """
    try:
        pynvml.nvmlInit()
        yield
    finally:
        pynvml.nvmlShutdown()
@
_nvml
()
@
_nvml
()
...
@@ -173,7 +67,6 @@ def _is_full_nvlink(device_ids: List[int]) -> bool:
...
@@ -173,7 +67,6 @@ def _is_full_nvlink(device_ids: List[int]) -> bool:
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
from
vllm.distributed.utils
import
gpu_p2p_access_check
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
if
i
==
rank
:
if
i
==
rank
:
continue
continue
...
@@ -184,22 +77,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
...
@@ -184,22 +77,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
class
CustomAllreduce
:
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
# max_size: max supported allreduce size
# max_size: max supported allreduce size
def
__init__
(
self
,
def
__init__
(
self
,
rank
,
group
:
Optional
[
ProcessGroup
]
=
None
,
world_size
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
full_nvlink
,
max_size
=
8192
*
1024
)
->
None
:
max_size
=
8192
*
1024
)
->
None
:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self
.
_IS_CAPTURING
=
False
self
.
disabled
=
True
if
custom_ar
is
None
:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
group
=
group
or
get_tensor_model_parallel_cpu_group
()
self
.
group
=
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"CustomAllreduce should be attached to a non-NCCL group."
)
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
:
logger
.
warning
(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
))
return
if
device
is
None
:
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
torch
.
cuda
.
device_count
()))
physical_device_id
=
device_ids
[
device
.
index
]
tensor
=
torch
.
tensor
([
physical_device_id
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
gather_list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
for
_
in
range
(
world_size
)
]
dist
.
all_gather
(
gather_list
,
tensor
,
group
=
self
.
group
)
physical_device_ids
=
[
t
.
item
()
for
t
in
gather_list
]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink
=
_is_full_nvlink
(
physical_device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warning
(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self
.
disabled
=
False
# buffers memory are owned by this Python class and passed to C++
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
# allreduce results.
self
.
meta
=
torch
.
zeros
(
custom_ar
.
meta_size
()
+
max_size
,
self
.
meta
=
torch
.
zeros
(
custom_ar
.
meta_size
()
+
max_size
,
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
device
=
self
.
device
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
# are first copied into this buffer before allreduce is performed
self
.
buffer
=
torch
.
empty
(
max_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
self
.
buffer
=
torch
.
empty
(
max_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# This is a buffer for storing the tuples of pointers pointing to
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
...
@@ -207,8 +190,9 @@ class CustomAllreduce:
...
@@ -207,8 +190,9 @@ class CustomAllreduce:
# needs less than 10000 of registered tuples.
# needs less than 10000 of registered tuples.
self
.
rank_data
=
torch
.
empty
(
8
*
1024
*
1024
,
self
.
rank_data
=
torch
.
empty
(
8
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
device
=
self
.
device
)
self
.
max_size
=
max_size
self
.
max_size
=
max_size
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
full_nvlink
self
.
full_nvlink
=
full_nvlink
...
@@ -217,6 +201,21 @@ class CustomAllreduce:
...
@@ -217,6 +201,21 @@ class CustomAllreduce:
self
.
full_nvlink
)
self
.
full_nvlink
)
self
.
register_buffer
(
self
.
buffer
)
self
.
register_buffer
(
self
.
buffer
)
@contextmanager
def capture(self):
    """
    Bracket a CUDA graph capture region for this communicator.

    While the context is active ``self._IS_CAPTURING`` is ``True``. When
    it ends (normally or through an exception) the flag is cleared and,
    unless the communicator is disabled, `register_graph_buffers` is
    called, which records all the buffer addresses used in the CUDA graph.
    """
    try:
        self._IS_CAPTURING = True
        yield
    finally:
        self._IS_CAPTURING = False
        enabled = not self.disabled
        if enabled:
            self.register_graph_buffers()
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
data
=
inp
.
untyped_storage
().
_share_cuda_
()
data
=
inp
.
untyped_storage
().
_share_cuda_
()
shard_data
=
(
shard_data
=
(
...
@@ -226,14 +225,29 @@ class CustomAllreduce:
...
@@ -226,14 +225,29 @@ class CustomAllreduce:
return
self
.
_gather_ipc_meta
(
shard_data
)
return
self
.
_gather_ipc_meta
(
shard_data
)
def
_gather_ipc_meta
(
self
,
shard_data
):
def
_gather_ipc_meta
(
self
,
shard_data
):
all_data
:
List
[
Optional
[
Any
]]
=
[
None
]
*
self
.
world_size
# Note: don't use `[[None]] * self.world_size` here
dist
.
all_gather_object
(
all_data
,
shard_data
)
# because it will create a list of the same reference
all_data
:
List
[
Optional
[
Any
]]
=
[[
None
]
for
i
in
range
(
self
.
world_size
)]
all_data
[
self
.
rank
][
0
]
=
shard_data
ranks
=
dist
.
get_process_group_ranks
(
group
=
self
.
group
)
ranks
.
sort
()
for
i
,
rank
in
enumerate
(
ranks
):
dist
.
broadcast_object_list
(
all_data
[
i
],
src
=
rank
,
group
=
self
.
group
,
device
=
"cpu"
)
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles
=
[]
handles
=
[]
offsets
=
[]
offsets
=
[]
for
i
in
range
(
len
(
all_data
)):
for
i
in
range
(
len
(
all_data
)):
handles
.
append
(
all_data
[
i
][
0
])
# type: ignore
handles
.
append
(
all_data
[
i
][
0
]
[
0
]
)
# type: ignore
offsets
.
append
(
all_data
[
i
][
1
])
# type: ignore
offsets
.
append
(
all_data
[
i
][
0
][
1
])
# type: ignore
return
handles
,
offsets
return
handles
,
offsets
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
...
@@ -265,8 +279,31 @@ class CustomAllreduce:
...
@@ -265,8 +279,31 @@ class CustomAllreduce:
custom_ar
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
custom_ar
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
return
out
return
out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
    """Run the custom allreduce on ``input`` when it is applicable.

    Returns:
        The out-of-place result tensor, or ``None`` when this communicator
        is disabled or ``should_custom_ar`` rejects ``input``.
    """
    if self.disabled:
        return None

    if not self._IS_CAPTURING:
        # note: outside of cuda graph context,
        # custom allreduce incurs a cost of cudaMemcpy, which should
        # be small(<=1% of overall latency) compared to the performance
        # gains of using custom kernels
        if self.should_custom_ar(input):
            return self.all_reduce_unreg(input)
        return None

    if torch.cuda.is_current_stream_capturing():
        # the stream is really recording: use the registered-buffer path
        if self.should_custom_ar(input):
            return self.all_reduce_reg(input)
        return None

    # inside `capture()` but not recording (warm up): mimic the
    # allocation pattern since custom allreduce is out-of-place
    if self.should_custom_ar(input):
        return torch.empty_like(input)
    return None
def
close
(
self
):
def
close
(
self
):
if
self
.
_ptr
:
if
not
self
.
disabled
and
self
.
_ptr
:
custom_ar
.
dispose
(
self
.
_ptr
)
custom_ar
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
_ptr
=
0
...
...
vllm/distributed/device_communicators/custom_all_reduce_utils.py
0 → 100644
View file @
b9e12416
import
json
import
os
import
sys
import
tempfile
import
time
from
contextlib
import
contextmanager
from
typing
import
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@contextmanager
def mute_output():
    """Redirect ``sys.stderr`` and ``sys.stdout`` to ``os.devnull``.

    NOTE(review): the previous streams are never restored; once the
    context ends both names still point at the (then closed) devnull
    file. This looks intentional for the short-lived child processes
    that use it -- confirm before reusing it elsewhere.
    """
    with open(os.devnull, "w") as devnull:
        sys.stderr = sys.stdout = devnull
        yield
def producer(i: int,
             init_method: str,
             cuda_visible_devices: Optional[str] = None):
    """Child-process entry point: allocate a tensor on GPU ``i`` and share it.

    Joins a 2-process gloo group as rank 0, creates a zero tensor on
    ``cuda:{i}``, broadcasts the information needed to rebuild it, and
    after the barrier verifies that every element is now 1.

    Args:
        i: device index the tensor is created on.
        init_method: rendezvous URL passed to ``init_process_group``.
        cuda_visible_devices: if given, exported as ``CUDA_VISIBLE_DEVICES``
            before any CUDA work is done.

    Raises:
        RuntimeError: when the tensor does not hold the expected values
            after the barrier.
    """
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    with mute_output():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            world_size=2,
            rank=0,
        )
        # produce a tensor in GPU i
        data = torch.zeros((128, ), device=f"cuda:{i}")
        # get the information to reconstruct the shared tensor
        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
        args = list(args)
        dist.broadcast_object_list([(func, args)], src=0)
        dist.barrier()
        torch.cuda.synchronize()
        # explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would make this probe succeed unconditionally
        if not torch.all(data == 1).item():
            raise RuntimeError(
                "shared tensor was not updated by the peer process")
def consumer(j: int,
             init_method: str,
             cuda_visible_devices: Optional[str] = None):
    """Child-process entry point: open the producer's tensor on GPU ``j``.

    Joins the 2-process gloo group as rank 1, receives the rebuild
    information broadcast by rank 0, rebuilds the tensor with device id
    ``j``, adds 1 in place, and after the barrier verifies the result.

    Args:
        j: device index used to access the shared tensor.
        init_method: rendezvous URL passed to ``init_process_group``.
        cuda_visible_devices: if given, exported as ``CUDA_VISIBLE_DEVICES``
            before any CUDA work is done.

    Raises:
        RuntimeError: when the tensor does not hold the expected values
            after the barrier.
    """
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    with mute_output():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            world_size=2,
            rank=1,
        )
        torch.cuda.set_device(j)
        recv = [None]
        dist.broadcast_object_list(recv, src=0)
        func: Callable
        args: List
        func, args = recv[0]  # type: ignore
        # `args[6]` is the device id
        # by default pytorch will use `i` from the producer
        # here we need to set it to `j` to test P2P access
        args[6] = j
        data = func(*args)
        data += 1
        dist.barrier()
        torch.cuda.synchronize()
        # explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would make this probe succeed unconditionally
        if not torch.all(data == 1).item():
            raise RuntimeError(
                "shared tensor does not hold the expected values")
def can_actually_p2p(i: int, j: int) -> bool:
    """
    Check by a real cross-process access whether GPU `i` and GPU `j` can
    do P2P.

    Usually, checking if P2P access is enabled can be done by
    `torch.cuda.can_device_access_peer(i, j)`. However, sometimes
    the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
    Therefore, we have to perform a real P2P access to check if it is
    actually possible.

    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
    GPU i --> cuda context i --> tensor i --> process i

    We need to combine p2p and cuda IPC, so that:
    GPU i --> cuda context i --> tensor i --> process i
                                    |shared|
    GPU j --> cuda context j --> tensor j --> process j
    That is to say, process i creates a tensor in GPU i, passes IPC handle
    to process j, and process j accesses the tensor in GPU j. Any operation
    on the tensor in process j will be reflected in the tensor in process
    i, because they are the same memory segment.
    It is important to note that process j accesses the tensor in GPU j,
    not GPU i. That's why we need p2p access.

    Returns:
        `True` only when both spawned child processes exit with code 0.
    """ # noqa
    # pass the CUDA_VISIBLE_DEVICES to the child process
    # to make sure they see the same set of GPUs
    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)

    # `mkstemp` atomically creates a fresh, empty, uniquely named file, so
    # the rendezvous file can never collide across calls
    fd, temp_path = tempfile.mkstemp()
    os.close(fd)
    init_method = f"file://{temp_path}"

    try:
        # make sure the processes are spawned
        smp = mp.get_context("spawn")
        pi = smp.Process(target=producer,
                         args=(i, init_method, cuda_visible_devices))
        pj = smp.Process(target=consumer,
                         args=(j, init_method, cuda_visible_devices))
        pi.start()
        pj.start()
        pi.join()
        pj.join()
        return pi.exitcode == 0 and pj.exitcode == 0
    finally:
        # do not leave one rendezvous file behind per tested device pair
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            # already cleaned up by the children
            pass
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we test it every time, it will be very slow, because we need to create
# N * N * 2 processes, where N is the world size. This is very slow.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache
:
Optional
[
Dict
[
str
,
bool
]]
=
None
def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j.

    The answer comes from a per-process cache that is loaded from a JSON
    file on first use. If the file does not exist, the local master
    process (or the only process when torch.distributed is not
    initialized) generates it by probing every device pair with
    ``can_actually_p2p``.

    Raises:
        KeyError: if the pair ``"{i}->{j}"`` is not present in the cache.
    """
    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.cuda.device_count()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(dev) for dev in range(num_dev))
    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
    path = os.path.expanduser(
        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if ((not is_distributed or get_local_rank() == 0)
            and (not os.path.exists(path))):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache = {
            f"{_i}->{_j}": can_actually_p2p(_i, _j)
            for _i in range(num_dev) for _j in range(num_dev)
        }
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        # wait here so that nobody reads the file before it is written
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]
__all__
=
[
"gpu_p2p_access_check"
]
vllm/distributed/device_communicators/pynccl.py
View file @
b9e12416
# This file is a pure Python wrapper for the NCCL library.
from
contextlib
import
contextmanager
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import
ctypes
import
platform
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
# ===================== import region =====================
# ===================== import region =====================
...
@@ -28,217 +6,70 @@ import torch
...
@@ -28,217 +6,70 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
,
ncclRedOpTypeEnum
,
ncclUniqueId
)
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_nccl_library
,
nccl_integrity_check
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
so_file
=
find_nccl_library
()
try
:
# load the library in another process.
# if it core dumps, it will not crash the current process
nccl_integrity_check
(
so_file
)
nccl
=
ctypes
.
CDLL
(
so_file
)
except
Exception
as
e
:
logger
.
error
(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"One solution is to download libnccl2 version 2.18 from "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"and extract the libnccl.so.2 file. If you already have the "
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path."
,
so_file
,
platform
.
platform
())
raise
e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t
=
ctypes
.
c_int
_c_ncclGetErrorString
=
nccl
.
ncclGetErrorString
_c_ncclGetErrorString
.
restype
=
ctypes
.
c_char_p
_c_ncclGetErrorString
.
argtypes
=
[
ncclResult_t
]
def NCCL_CHECK(result: ncclResult_t) -> None:
    """Raise ``RuntimeError`` when an NCCL call reports a non-zero status.

    The message is taken from ``ncclGetErrorString`` and decoded as UTF-8.
    """
    if result == 0:
        return
    message = _c_ncclGetErrorString(result).decode("utf-8")
    raise RuntimeError(f"NCCL error: {message}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion
=
nccl
.
ncclGetVersion
_c_ncclGetVersion
.
restype
=
ctypes
.
c_int
_c_ncclGetVersion
.
argtypes
=
[
ctypes
.
POINTER
(
ctypes
.
c_int
)]
def ncclGetVersion() -> str:
    """Return the loaded NCCL library version as ``"major.minor.patch"``.

    The integer reported by ``ncclGetVersion`` is decoded arithmetically,
    so zero components are kept (22000 --> "2.20.0") instead of being
    stripped to an empty string.

    Raises:
        RuntimeError: propagated from ``NCCL_CHECK`` when the call fails.
    """
    version = ctypes.c_int()
    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
    code = version.value
    if code < 10000:
        # legacy encoding (NCCL <= 2.8): X * 1000 + Y * 100 + Z
        major, remainder = divmod(code, 1000)
    else:
        # something like 21903 --> "2.19.3" (X * 10000 + Y * 100 + Z)
        major, remainder = divmod(code, 10000)
    minor, patch = divmod(remainder, 100)
    return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
    """ctypes structure with a single 128-byte ``internal`` array field."""

    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId
=
nccl
.
ncclGetUniqueId
_c_ncclGetUniqueId
.
restype
=
ctypes
.
c_int
_c_ncclGetUniqueId
.
argtypes
=
[
ctypes
.
POINTER
(
NcclUniqueId
)]
def ncclGetUniqueId() -> NcclUniqueId:
    """Ask NCCL for a fresh unique id and return it.

    Raises:
        RuntimeError: propagated from ``NCCL_CHECK`` when the call fails.
    """
    uid = NcclUniqueId()
    status = _c_ncclGetUniqueId(ctypes.byref(uid))
    NCCL_CHECK(status)
    return uid
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank
=
nccl
.
ncclCommInitRank
_c_ncclCommInitRank
.
restype
=
ctypes
.
c_int
_c_ncclCommInitRank
.
argtypes
=
[
ctypes
.
POINTER
(
ctypes
.
c_void_p
),
ctypes
.
c_int
,
NcclUniqueId
,
ctypes
.
c_int
]
ncclDataType_t
=
ctypes
.
c_int
class ncclDataTypeEnum:
    """Integer codes for NCCL data types plus a torch dtype translator."""

    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Map a ``torch.dtype`` to its NCCL data type code.

        Raises:
            ValueError: if ``dtype`` has no entry in the table below.
        """
        # checked in order with `==`
        table = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_code in table:
            if dtype == torch_dtype:
                return nccl_code
        raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t
=
ctypes
.
c_int
class
ncclRedOpTypeEnum
:
ncclSum
=
0
ncclProd
=
1
ncclMax
=
2
ncclMin
=
3
ncclAvg
=
4
ncclNumOps
=
5
@
classmethod
class
PyNcclCommunicator
:
def
from_torch
(
cls
,
op
:
ReduceOp
)
->
int
:
if
op
==
ReduceOp
.
SUM
:
return
cls
.
ncclSum
if
op
==
ReduceOp
.
PRODUCT
:
return
cls
.
ncclProd
if
op
==
ReduceOp
.
MAX
:
return
cls
.
ncclMax
if
op
==
ReduceOp
.
MIN
:
return
cls
.
ncclMin
if
op
==
ReduceOp
.
AVG
:
return
cls
.
ncclAvg
raise
ValueError
(
f
"Unsupported op:
{
op
}
"
)
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#   cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce
=
nccl
.
ncclAllReduce
_c_ncclAllReduce
.
restype
=
ctypes
.
c_int
_c_ncclAllReduce
.
argtypes
=
[
ctypes
.
c_void_p
,
ctypes
.
c_void_p
,
ctypes
.
c_size_t
,
ncclRedOp_t
,
ncclDataType_t
,
ctypes
.
c_void_p
,
ctypes
.
c_void_p
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy
=
nccl
.
ncclCommDestroy
_c_ncclCommDestroy
.
restype
=
ctypes
.
c_int
_c_ncclCommDestroy
.
argtypes
=
[
ctypes
.
c_void_p
]
class
NCCLCommunicator
:
def
__init__
(
def
__init__
(
self
,
self
,
group
:
Optional
[
ProcessGroup
]
=
None
,
group
:
Optional
[
ProcessGroup
]
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
library_path
:
Optional
[
str
]
=
None
,
):
):
"""
"""
Args:
Args:
group: the process group to work on. If None, it will use the
group: the process group to work on. If None, it will use the
default process group.
default process group.
device: the device to bind the
NCCL
Communicator to. If None,
device: the device to bind the
PyNccl
Communicator to. If None,
it will be bind to f"cuda:{local_rank}".
it will be bind to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
is bind to a unique device.
"""
"""
assert
dist
.
is_initialized
()
assert
dist
.
is_initialized
()
group
=
get_cpu_world_group
()
if
group
is
None
else
group
group
=
get_cpu_world_group
()
if
group
is
None
else
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"
NCCL
Communicator should be attached to a non-NCCL group."
)
"
PyNccl
Communicator should be attached to a non-NCCL group."
)
self
.
group
=
group
self
.
group
=
group
# note: this rank is the rank in the group
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
# if world_size == 1, no need to create communicator
if
self
.
world_size
==
1
:
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
try
:
self
.
nccl
=
NCCLLibrary
(
library_path
)
except
Exception
:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
self
.
available
=
True
self
.
disabled
=
False
logger
.
info
(
"vLLM is using nccl==%s"
,
self
.
nccl
.
ncclGetVersion
())
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
self
.
unique_id
=
ncclGetUniqueId
()
# get the unique id from NCCL
self
.
unique_id
=
self
.
nccl
.
ncclGetUniqueId
()
else
:
else
:
self
.
unique_id
=
NcclUniqueId
()
# construct an empty unique id
self
.
unique_id
=
ncclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
ranks
=
dist
.
get_process_group_ranks
(
group
)
ranks
=
dist
.
get_process_group_ranks
(
group
)
# arg `src` in `broadcast` is the global rank
# arg `src` in `broadcast` is the global rank
...
@@ -246,7 +77,6 @@ class NCCLCommunicator:
...
@@ -246,7 +77,6 @@ class NCCLCommunicator:
byte_list
=
tensor
.
tolist
()
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
comm
=
ctypes
.
c_void_p
()
if
device
is
None
:
if
device
is
None
:
local_rank
=
get_local_rank
()
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
...
@@ -261,15 +91,27 @@ class NCCLCommunicator:
...
@@ -261,15 +91,27 @@ class NCCLCommunicator:
# `torch.cuda.device` is a context manager that changes the
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
# current cuda device to the specified one
with
torch
.
cuda
.
device
(
device
):
with
torch
.
cuda
.
device
(
device
):
NCCL_CHECK
(
self
.
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
self
.
unique_id
,
self
.
rank
))
self
.
stream
=
torch
.
cuda
.
Stream
()
self
.
stream
=
torch
.
cuda
.
Stream
()
# A small all_reduce for warmup.
data
=
torch
.
zeros
(
1
,
device
=
device
)
self
.
all_reduce
(
data
)
self
.
stream
.
synchronize
()
del
data
# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually
# when we are using CUDA graph.
self
.
disabled
=
True
def
all_reduce
(
self
,
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
):
stream
=
None
):
if
self
.
disabled
:
return
# nccl communicator created on a specific device
# nccl communicator created on a specific device
# will only work on tensors on the same device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
# otherwise it will cause "illegal memory access"
...
@@ -278,10 +120,66 @@ class NCCLCommunicator:
...
@@ -278,10 +120,66 @@ class NCCLCommunicator:
f
"but the input tensor is on
{
tensor
.
device
}
"
)
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
if
stream
is
None
:
stream
=
self
.
stream
stream
=
self
.
stream
NCCL_CHECK
(
self
.
nccl
.
ncclAllReduce
(
buffer_type
(
tensor
.
data_ptr
()),
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
tensor
.
numel
(),
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
cudaStream_t
(
stream
.
cuda_stream
))
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
)))
def send(self,
         tensor: torch.Tensor,
         dst: Optional[int] = None,
         stream=None):
    """Send ``tensor`` to rank ``dst`` through ``self.nccl.ncclSend``.

    Does nothing while the communicator is disabled.

    Args:
        tensor: tensor to send; must live on ``self.device``.
        dst: destination rank; defaults to the next rank in the group,
            wrapping around at ``world_size``.
        stream: stream to enqueue on; defaults to ``self.stream``.
    """
    if self.disabled:
        return
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    cuda_stream = self.stream if stream is None else stream
    peer = (self.rank + 1) % self.world_size if dst is None else dst
    self.nccl.ncclSend(
        buffer_type(tensor.data_ptr()),
        tensor.numel(),
        ncclDataTypeEnum.from_torch(tensor.dtype),
        peer,
        self.comm,
        cudaStream_t(cuda_stream.cuda_stream),
    )
def recv(self,
         tensor: torch.Tensor,
         src: Optional[int] = None,
         stream=None):
    """Receive into ``tensor`` from rank ``src`` through ``self.nccl.ncclRecv``.

    Does nothing while the communicator is disabled.

    Args:
        tensor: destination tensor; must live on ``self.device``.
        src: source rank; defaults to the previous rank in the group,
            wrapping around at ``world_size``.
        stream: stream to enqueue on; defaults to ``self.stream``.
    """
    if self.disabled:
        return
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    cuda_stream = self.stream if stream is None else stream
    peer = (self.rank - 1) % self.world_size if src is None else src
    self.nccl.ncclRecv(
        buffer_type(tensor.data_ptr()),
        tensor.numel(),
        ncclDataTypeEnum.from_torch(tensor.dtype),
        peer,
        self.comm,
        cudaStream_t(cuda_stream.cuda_stream),
    )
@contextmanager
def change_state(self,
                 enable: Optional[bool] = None,
                 stream: Optional[torch.cuda.Stream] = None):
    """
    A context manager to change the state of the communicator.

    Inside the ``with`` body ``self.disabled`` is ``not enable`` and
    ``self.stream`` is ``stream``. Both attributes are put back to their
    previous values when the body finishes, including when it raises.

    Args:
        enable: Whether the communicator should be enabled inside the
            block. ``None`` falls back to ``self.available``.
        stream: Stream to use inside the block. ``None`` keeps the
            current ``self.stream``.
    """
    if enable is None:
        # guess a default value when not specified
        enable = self.available

    if stream is None:
        stream = self.stream

    old_disable = self.disabled
    old_stream = self.stream

    self.stream = stream
    self.disabled = not enable
    try:
        yield
    finally:
        # Restore unconditionally so that an exception raised inside the
        # `with` body cannot leave the communicator in the temporary state.
        self.disabled = old_disable
        self.stream = old_stream
vllm/distributed/device_communicators/pynccl_utils.py
deleted
100644 → 0
View file @
e5d707db
import
contextlib
from
typing
import
Optional
import
torch
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
from
vllm.distributed.device_communicators.pynccl
import
(
NCCLCommunicator
,
ncclGetVersion
)
except
Exception
as
e
:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
logger
.
info
(
"Failed to import NCCL library: %s"
,
e
)
logger
.
info
(
"It is expected if you are not running on NVIDIA GPUs."
)
pass
# Module-level communicator instance. It stays None until
# init_process_group() assigns it, and destroy_process_group() resets it
# to None again.
comm: Optional["NCCLCommunicator"] = None
def is_initialized() -> bool:
    """Report whether the module-level NCCL communicator has been created."""
    initialized = comm is not None
    return initialized
@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Set the cuda stream for communication

    The module-level communicator uses ``stream`` for the duration of the
    ``with`` block. The stream it had before is put back on exit, including
    when the body raises.

    Args:
        stream: Stream to install on the communicator.

    Raises:
        AssertionError: if the communicator has not been initialized.
    """
    assert comm is not None
    # Keep a local reference: the module-level `comm` may be rebound (for
    # example by destroy_process_group) while the block is running.
    active_comm = comm
    previous_stream = active_comm.stream
    active_comm.stream = stream
    try:
        yield
    finally:
        active_comm.stream = previous_stream
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
    """Create the module-level NCCL communicator for ``group``.

    Logs the NCCL version reported by ``ncclGetVersion()`` and then stores a
    new ``NCCLCommunicator`` in the module-level ``comm``.

    Args:
        group: Process group forwarded to ``NCCLCommunicator``.

    Raises:
        AssertionError: if the communicator is already initialized.
    """
    global comm
    already_initialized = is_initialized()
    assert not already_initialized
    logger.info("vLLM is using nccl==%s", ncclGetVersion())
    comm = NCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduce ``input_`` through the module-level NCCL communicator.

    Args:
        input_: Tensor to reduce; it must be a cuda tensor.
        op: Reduction operation forwarded to ``comm.all_reduce``.

    Raises:
        AssertionError: if ``input_`` is not a cuda tensor or the
            communicator has not been initialized.
    """
    on_cuda = input_.is_cuda
    assert on_cuda, f"{input_} should be a cuda tensor"
    assert comm is not None
    comm.all_reduce(input_, op)
def destroy_process_group() -> None:
    """Drop the module-level NCCL communicator by resetting it to None."""
    global comm
    comm = None
def get_world_size() -> int:
    """Return ``world_size`` of the module-level NCCL communicator.

    Raises:
        AssertionError: if the communicator has not been initialized.
    """
    assert comm is not None
    world_size = comm.world_size
    return world_size
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
    """Return the module-level NCCL communicator, or None when unset."""
    backend = comm
    return backend
vllm/distributed/device_communicators/pynccl_wrapper.py
0 → 100644
View file @
b9e12416
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import
ctypes
import
platform
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.distributed
import
ReduceOp
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_nccl_library
logger
=
init_logger
(
__name__
)
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
# Status code returned by the wrapped NCCL calls; NCCLLibrary.NCCL_CHECK
# treats any value other than 0 as an error.
ncclResult_t = ctypes.c_int
# Communicator handle. It is a pointer type on the C side, so it is
# represented as a void pointer here.
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
    """ctypes structure holding an NCCL unique id.

    The structure has a single field, ``internal``, which is an array of
    128 bytes. The wrapper only passes it between NCCL calls and does not
    interpret its contents.
    """
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
# CUDA stream handle. It is a pointer type on the C side, so it is
# represented as a void pointer here.
cudaStream_t = ctypes.c_void_p
# Raw address used for the send/receive buffer arguments of NCCL calls.
buffer_type = ctypes.c_void_p
# C type of the `datatype` arguments; the integer values are listed in
# ncclDataTypeEnum.
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
    """Integer codes for NCCL data types.

    Several names are aliases that share one value (for example
    ``ncclInt8``/``ncclChar`` and ``ncclFloat16``/``ncclHalf``).
    """
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a torch dtype into the matching NCCL data type code.

        Raises:
            ValueError: if ``dtype`` is not one of the supported dtypes.
        """
        # Checked in order with `==`, first match wins.
        supported = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_code in supported:
            if dtype == torch_dtype:
                return nccl_code
        raise ValueError(f"Unsupported dtype: {dtype}")
# C type of the reduction-op arguments; the integer values are listed in
# ncclRedOpTypeEnum.
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
    """Integer codes for NCCL reduction operations."""
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a torch ``ReduceOp`` into the matching NCCL op code.

        Raises:
            ValueError: if ``op`` is not one of the supported operations.
        """
        # Checked in order with `==`, first match wins.
        supported = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_code in supported:
            if op == torch_op:
                return nccl_code
        raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
    """Signature of one C function to be looked up in the NCCL library."""
    # Symbol name to look up in the shared library.
    name: str
    # Value assigned to the ctypes function's `restype`.
    restype: Any
    # Value assigned to the ctypes function's `argtypes`.
    argtypes: List[Any]
class NCCLLibrary:
    """Pure-Python (ctypes) binding for a subset of the NCCL C API.

    Each shared library is loaded once per path; the loaded handle and the
    table of configured ctypes functions are cached at class level and
    shared by every instance created for the same path.

    The ``ncclXxx`` methods mirror the C functions of the same name. Those
    whose C counterpart returns ``ncclResult_t`` pass the result through
    ``NCCL_CHECK`` and therefore raise ``RuntimeError`` on a non-zero status.
    """
    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t  ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t  ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t, [
            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
            ctypes.c_int
        ]),
        # ncclResult_t  ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t  ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t  ncclRecv(
        #   void* recvbuff, size_t count, ncclDataType_t datatype,
        #   int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    #  to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load the NCCL shared library and configure its functions.

        Args:
            so_file: Path of the shared library. When falsy, the path
                returned by ``find_nccl_library()`` is used.

        Raises:
            Exception: whatever ``ctypes.CDLL`` raises when the library
                cannot be loaded; the failure is logged and re-raised.
        """
        so_file = so_file or find_nccl_library()

        try:
            # Key the check on the library cache itself, so a handle that
            # is already cached is never loaded a second time.
            if so_file not in NCCLLibrary.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                NCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = NCCLLibrary.path_to_library_cache[so_file]
        except Exception:
            logger.error(
                "Failed to load NCCL library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_NCCL_SO_PATH"
                " to point to the correct nccl library path.", so_file,
                platform.platform())
            raise

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        """Return NCCL's description of ``result`` decoded as UTF-8."""
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        """Raise ``RuntimeError`` when ``result`` is not 0."""
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetVersion(self) -> str:
        """Return the NCCL version as a ``"major.minor.patch"`` string."""
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        version_code = version.value
        # something like 21903 --> "2.19.3"
        # The code is major * 10000 + minor * 100 + patch. Releases up to
        # 2.8 used major * 1000 instead, which yields a code below 10000.
        # Decoding arithmetically keeps zero components ("2.20.0") intact.
        major_unit = 10000 if version_code >= 10000 else 1000
        major, remainder = divmod(version_code, major_unit)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        """Return a new ``ncclUniqueId`` filled in by NCCL."""
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        """Create and return the communicator handle for ``rank``."""
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        """Call ``ncclAllReduce`` and check its status."""
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        """Call ``ncclSend`` and check its status."""
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        """Call ``ncclRecv`` and check its status."""
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        """Call ``ncclCommDestroy`` and check its status."""
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
__all__
=
[
"NCCLLibrary"
,
"ncclDataTypeEnum"
,
"ncclRedOpTypeEnum"
,
"ncclUniqueId"
,
"ncclComm_t"
,
"cudaStream_t"
,
"buffer_type"
]
vllm/distributed/parallel_state.py
View file @
b9e12416
...
@@ -3,21 +3,27 @@
...
@@ -3,21 +3,27 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
"""Tensor and pipeline parallel groups."""
import
contextlib
from
typing
import
List
,
Optional
from
typing
import
Optional
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_ENABLE_CUSTOM_ALL_REDUCE
=
True
# Tensor model parallel group that the current rank belongs to.
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP
=
None
_TP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_CPU_GROUP
=
None
_TP_CPU_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_PYNCCL_COMMUNICATOR
=
None
_TP_CA_COMMUNICATOR
=
None
# Pipeline model parallel group that the current rank belongs to.
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
_PP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
_PP_CPU_GROUP
:
Optional
[
ProcessGroup
]
=
None
_PP_PYNCCL_COMMUNICATOR
=
None
# when people blindly call `torch.distributed.all_reduce` etc,
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# it will use this group. It is initialized with the `backend`
...
@@ -41,11 +47,31 @@ _CPU_WORLD_GROUP = None
...
@@ -41,11 +47,31 @@ _CPU_WORLD_GROUP = None
# A list of global ranks for each pipeline group to ease calculation of the
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
# source rank when broadcasting from the first or last pipeline stage.
_P
IPELINE
_GLOBAL_RANKS
=
None
_P
P
_GLOBAL_RANKS
:
Optional
[
List
[
int
]]
=
None
_LOCAL_RANK
=
-
1
_LOCAL_RANK
=
-
1
def set_custom_all_reduce(enable: bool):
    """Store ``enable`` in the module-level ``_ENABLE_CUSTOM_ALL_REDUCE``."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
def get_pp_pynccl_communicator():
    """Return the module-level ``_PP_PYNCCL_COMMUNICATOR`` (may be None)."""
    # Reading a module global needs no `global` declaration.
    return _PP_PYNCCL_COMMUNICATOR
def get_tp_pynccl_communicator():
    """Return the module-level ``_TP_PYNCCL_COMMUNICATOR`` (may be None)."""
    # Reading a module global needs no `global` declaration.
    return _TP_PYNCCL_COMMUNICATOR
def get_tp_ca_communicator():
    """Return the module-level ``_TP_CA_COMMUNICATOR`` (may be None)."""
    # Reading a module global needs no `global` declaration.
    return _TP_CA_COMMUNICATOR
def
get_local_rank
():
def
get_local_rank
():
global
_LOCAL_RANK
global
_LOCAL_RANK
return
_LOCAL_RANK
return
_LOCAL_RANK
...
@@ -80,10 +106,23 @@ def init_distributed_environment(
...
@@ -80,10 +106,23 @@ def init_distributed_environment(
# set the local rank
# set the local rank
# local_rank is not available in torch ProcessGroup,
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
# see https://github.com/pytorch/pytorch/issues/122816
if
local_rank
==
-
1
and
distributed_init_method
==
"env://"
:
if
local_rank
==
-
1
:
local_rank
=
envs
.
LOCAL_RANK
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if
distributed_init_method
==
"env://"
:
local_rank
=
envs
.
LOCAL_RANK
else
:
local_rank
=
rank
global
_LOCAL_RANK
global
_LOCAL_RANK
_LOCAL_RANK
=
local_rank
_LOCAL_RANK
=
local_rank
# A small all_reduce for warmup.
data
=
torch
.
zeros
(
1
)
if
torch
.
cuda
.
is_available
():
data
=
data
.
to
(
device
=
f
"cuda:
{
local_rank
}
"
)
torch
.
distributed
.
all_reduce
(
data
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
del
data
def
initialize_model_parallel
(
def
initialize_model_parallel
(
...
@@ -134,28 +173,55 @@ def initialize_model_parallel(
...
@@ -134,28 +173,55 @@ def initialize_model_parallel(
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups.
global
_TP_DEVICE_GROUP
,
_TP_CPU_GROUP
global
_TP_DEVICE_GROUP
,
_TP_CPU_GROUP
global
_TP_PYNCCL_COMMUNICATOR
,
_TP_CA_COMMUNICATOR
assert
_TP_DEVICE_GROUP
is
None
,
(
assert
_TP_DEVICE_GROUP
is
None
,
(
"tensor model parallel group is already initialized"
)
"tensor model parallel group is already initialized"
)
for
i
in
range
(
num_tensor_model_parallel_groups
):
for
i
in
range
(
num_tensor_model_parallel_groups
):
ranks
=
range
(
i
*
tensor_model_parallel_size
,
ranks
=
list
(
(
i
+
1
)
*
tensor_model_parallel_size
)
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
))
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_TP_DEVICE_GROUP
=
group
_TP_DEVICE_GROUP
=
group
_TP_CPU_GROUP
=
cpu_group
_TP_CPU_GROUP
=
cpu_group
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
if
tensor_model_parallel_size
>
1
:
_TP_PYNCCL_COMMUNICATOR
=
PyNcclCommunicator
(
group
=
_TP_CPU_GROUP
,
device
=
_LOCAL_RANK
,
)
# Initialize a custom fast all-reduce implementation.
if
_ENABLE_CUSTOM_ALL_REDUCE
:
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
CustomAllreduce
)
_TP_CA_COMMUNICATOR
=
CustomAllreduce
(
group
=
_TP_CPU_GROUP
,
device
=
_LOCAL_RANK
,
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
global
_PIPELINE_MODEL_PARALLEL_GROUP
global
_PP_DEVICE_GROUP
,
_PP_CPU_GROUP
global
_PIPELINE_GLOBAL_RANKS
global
_PP_PYNCCL_COMMUNICATOR
assert
_PIPELINE_MODEL_PARALLEL_GROUP
is
None
,
(
global
_PP_GLOBAL_RANKS
assert
_PP_DEVICE_GROUP
is
None
,
(
"pipeline model parallel group is already initialized"
)
"pipeline model parallel group is already initialized"
)
for
i
in
range
(
num_pipeline_model_parallel_groups
):
for
i
in
range
(
num_pipeline_model_parallel_groups
):
ranks
=
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
)
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
)
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_PIPELINE_MODEL_PARALLEL_GROUP
=
group
_PP_DEVICE_GROUP
=
group
_PIPELINE_GLOBAL_RANKS
=
ranks
_PP_CPU_GROUP
=
cpu_group
_PP_GLOBAL_RANKS
=
ranks
if
pipeline_model_parallel_size
>
1
:
_PP_PYNCCL_COMMUNICATOR
=
PyNcclCommunicator
(
group
=
_PP_CPU_GROUP
,
device
=
_LOCAL_RANK
,
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
...
@@ -188,8 +254,7 @@ def ensure_model_parallel_initialized(
...
@@ -188,8 +254,7 @@ def ensure_model_parallel_initialized(
def
model_parallel_is_initialized
():
def
model_parallel_is_initialized
():
"""Check if tensor and pipeline parallel groups are initialized."""
"""Check if tensor and pipeline parallel groups are initialized."""
return
(
_TP_DEVICE_GROUP
is
not
None
return
(
_TP_DEVICE_GROUP
is
not
None
and
_PP_DEVICE_GROUP
is
not
None
)
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
def
get_cpu_world_group
():
def
get_cpu_world_group
():
...
@@ -214,9 +279,16 @@ def get_tensor_model_parallel_cpu_group():
...
@@ -214,9 +279,16 @@ def get_tensor_model_parallel_cpu_group():
def
get_pipeline_model_parallel_group
():
def
get_pipeline_model_parallel_group
():
"""Get the pipeline model parallel group the caller rank belongs to."""
"""Get the pipeline model parallel group the caller rank belongs to."""
assert
_P
IPELINE_MODEL_PARALLEL
_GROUP
is
not
None
,
(
assert
_P
P_DEVICE
_GROUP
is
not
None
,
(
"pipeline model parallel group is not initialized"
)
"pipeline model parallel group is not initialized"
)
return
_PIPELINE_MODEL_PARALLEL_GROUP
return
_PP_DEVICE_GROUP
def get_pipeline_model_parallel_cpu_group():
    """Get the pipeline model parallel cpu group the caller rank belongs to.

    Raises:
        AssertionError: if ``_PP_CPU_GROUP`` is still None.
    """
    cpu_group = _PP_CPU_GROUP
    assert cpu_group is not None, (
        "pipeline model parallel cpu group is not initialized")
    return cpu_group
def
get_tensor_model_parallel_world_size
():
def
get_tensor_model_parallel_world_size
():
...
@@ -253,36 +325,36 @@ def get_tensor_model_parallel_src_rank():
...
@@ -253,36 +325,36 @@ def get_tensor_model_parallel_src_rank():
def
get_pipeline_model_parallel_first_rank
():
def
get_pipeline_model_parallel_first_rank
():
"""Return the global rank of the first process in the pipeline for the
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
current tensor parallel group"""
assert
_P
IPELINE
_GLOBAL_RANKS
is
not
None
,
(
assert
_P
P
_GLOBAL_RANKS
is
not
None
,
(
"Pipeline parallel group is not initialized"
)
"Pipeline parallel group is not initialized"
)
return
_P
IPELINE
_GLOBAL_RANKS
[
0
]
return
_P
P
_GLOBAL_RANKS
[
0
]
def
get_pipeline_model_parallel_last_rank
():
def
get_pipeline_model_parallel_last_rank
():
"""Return the global rank of the last process in the pipeline for the
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
current tensor parallel group"""
assert
_P
IPELINE
_GLOBAL_RANKS
is
not
None
,
(
assert
_P
P
_GLOBAL_RANKS
is
not
None
,
(
"Pipeline parallel group is not initialized"
)
"Pipeline parallel group is not initialized"
)
last_rank_local
=
get_pipeline_model_parallel_world_size
()
-
1
last_rank_local
=
get_pipeline_model_parallel_world_size
()
-
1
return
_P
IPELINE
_GLOBAL_RANKS
[
last_rank_local
]
return
_P
P
_GLOBAL_RANKS
[
last_rank_local
]
def
get_pipeline_model_parallel_next_rank
():
def
get_pipeline_model_parallel_next_rank
():
"""Return the global rank that follows the caller in the pipeline"""
"""Return the global rank that follows the caller in the pipeline"""
assert
_P
IPELINE
_GLOBAL_RANKS
is
not
None
,
(
assert
_P
P
_GLOBAL_RANKS
is
not
None
,
(
"Pipeline parallel group is not initialized"
)
"Pipeline parallel group is not initialized"
)
rank_in_pipeline
=
get_pipeline_model_parallel_rank
()
rank_in_pipeline
=
get_pipeline_model_parallel_rank
()
world_size
=
get_pipeline_model_parallel_world_size
()
world_size
=
get_pipeline_model_parallel_world_size
()
return
_P
IPELINE
_GLOBAL_RANKS
[(
rank_in_pipeline
+
1
)
%
world_size
]
return
_P
P
_GLOBAL_RANKS
[(
rank_in_pipeline
+
1
)
%
world_size
]
def
get_pipeline_model_parallel_prev_rank
():
def
get_pipeline_model_parallel_prev_rank
():
"""Return the global rank that precedes the caller in the pipeline"""
"""Return the global rank that precedes the caller in the pipeline"""
assert
_P
IPELINE
_GLOBAL_RANKS
is
not
None
,
(
assert
_P
P
_GLOBAL_RANKS
is
not
None
,
(
"Pipeline parallel group is not initialized"
)
"Pipeline parallel group is not initialized"
)
rank_in_pipeline
=
get_pipeline_model_parallel_rank
()
rank_in_pipeline
=
get_pipeline_model_parallel_rank
()
world_size
=
get_pipeline_model_parallel_world_size
()
world_size
=
get_pipeline_model_parallel_world_size
()
return
_P
IPELINE
_GLOBAL_RANKS
[(
rank_in_pipeline
-
1
)
%
world_size
]
return
_P
P
_GLOBAL_RANKS
[(
rank_in_pipeline
-
1
)
%
world_size
]
def
destroy_model_parallel
():
def
destroy_model_parallel
():
...
@@ -295,45 +367,12 @@ def destroy_model_parallel():
...
@@ -295,45 +367,12 @@ def destroy_model_parallel():
if
_TP_CPU_GROUP
:
if
_TP_CPU_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TP_CPU_GROUP
)
torch
.
distributed
.
destroy_process_group
(
_TP_CPU_GROUP
)
_TP_CPU_GROUP
=
None
_TP_CPU_GROUP
=
None
global
_PIPELINE_MODEL_PARALLEL_GROUP
global
_TP_PYNCCL_COMMUNICATOR
if
_PIPELINE_MODEL_PARALLEL_GROUP
:
_TP_PYNCCL_COMMUNICATOR
=
None
torch
.
distributed
.
destroy_process_group
(
_PIPELINE_MODEL_PARALLEL_GROUP
)
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
global
_PP_DEVICE_GROUP
global
_PIPELINE_GLOBAL_RANKS
if
_PP_DEVICE_GROUP
:
_PIPELINE_GLOBAL_RANKS
=
None
torch
.
distributed
.
destroy_process_group
(
_PP_DEVICE_GROUP
)
from
vllm.distributed.device_communicators
import
pynccl_utils
_PP_DEVICE_GROUP
=
None
global
_PP_GLOBAL_RANKS
# Destroy the pynccl states if any.
_PP_GLOBAL_RANKS
=
None
pynccl_utils
.
destroy_process_group
()
# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_PYNCCL_FOR_ALL_REDUCE
=
False
@contextlib.contextmanager
def with_pynccl_for_all_reduce():
    """use pynccl instead of torch.distributed for all reduce

    While the ``with`` block runs, ``_ENABLE_PYNCCL_FOR_ALL_REDUCE`` is True
    and pynccl is pointed at the current cuda stream. The previous flag
    value is put back on exit, including when the body raises. When the
    tensor model parallel world size is 1 the context manager does nothing.
    """
    # Imported lazily, as in the original code.
    from vllm.distributed.device_communicators import pynccl_utils

    global _ENABLE_PYNCCL_FOR_ALL_REDUCE

    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
        yield
        return

    old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
    _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
    try:
        stream = torch.cuda.current_stream()
        with pynccl_utils.set_pynccl_stream(stream):
            yield
    finally:
        # Always undo the flag so a failure inside the block cannot leave
        # pynccl all-reduce switched on for unrelated code.
        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
def is_pynccl_enabled_for_all_reduce():
    """Return the current value of ``_ENABLE_PYNCCL_FOR_ALL_REDUCE``."""
    # Reading a module global needs no `global` declaration.
    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
Prev
1
…
10
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment