Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
412 additions
and
307 deletions
+412
-307
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-10
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+7
-2
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+5
-2
vllm/config.py
vllm/config.py
+52
-17
vllm/connections.py
vllm/connections.py
+2
-2
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+1
-1
vllm/core/policy.py
vllm/core/policy.py
+0
-45
vllm/core/scheduler.py
vllm/core/scheduler.py
+55
-81
vllm/distributed/device_communicators/cuda_wrapper.py
vllm/distributed/device_communicators/cuda_wrapper.py
+22
-22
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+3
-1
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+21
-39
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+31
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+81
-39
vllm/distributed/utils.py
vllm/distributed/utils.py
+27
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+11
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+34
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+34
-25
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
e661d594
...
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"ROCFlashAttention does not support blocksparse attention."
)
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support attention logits soft "
"capping."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
e661d594
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"Torch SPDA does not support block-sparse attention."
)
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support logits soft cap."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/utils.py
View file @
e661d594
...
@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
self
.
input_builder
.
chunked_prefill_enabled
)
...
@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
...
@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
...
vllm/attention/backends/xformers.py
View file @
e661d594
...
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"XFormer does not support block-sparse attention."
)
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"XFormers does not support attention logits soft capping."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/layer.py
View file @
e661d594
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -82,7 +83,7 @@ class Attention(nn.Module):
...
@@ -82,7 +83,7 @@ class Attention(nn.Module):
impl_cls
=
attn_backend
.
get_impl_cls
()
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
)
blocksparse_params
,
logits_soft_cap
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/attention/ops/paged_attn.py
View file @
e661d594
...
@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
...
@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
...
@@ -31,7 +34,7 @@ class PagedAttention:
...
@@ -31,7 +34,7 @@ class PagedAttention:
@
staticmethod
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
return
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
...
vllm/config.py
View file @
e661d594
...
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
...
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
...
@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
...
@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS
=
[
_PP_SUPPORTED_MODELS
=
[
"AquilaModel"
,
"AquilaModel"
,
"AquilaForCausalLM"
,
"AquilaForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"InternLMForCausalLM"
,
"InternLMForCausalLM"
,
"LlamaForCausalLM"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
"LLaMAForCausalLM"
,
...
@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [
...
@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [
"Phi3ForCausalLM"
,
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
"GPT2LMHeadModel"
,
"MixtralForCausalLM"
,
"MixtralForCausalLM"
,
"NemotronForCausalLM"
,
"Qwen2ForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"QWenLMHeadModel"
,
]
]
...
@@ -195,13 +201,17 @@ class ModelConfig:
...
@@ -195,13 +201,17 @@ class ModelConfig:
def
_parse_quant_hf_config
(
self
):
def
_parse_quant_hf_config
(
self
):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
quant_cfg
is
None
:
if
quant_cfg
is
None
:
# compress-tensors uses a "compression_config" key
# compress
ed
-tensors uses a "compression_config" key
quant_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
quant_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
return
quant_cfg
return
quant_cfg
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
@@ -240,9 +250,7 @@ class ModelConfig:
...
@@ -240,9 +250,7 @@ class ModelConfig:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
if
self
.
quantization
not
in
optimized_quantization_methods
:
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
)):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
@@ -281,6 +289,10 @@ class ModelConfig:
...
@@ -281,6 +289,10 @@ class ModelConfig:
raise
ValueError
(
raise
ValueError
(
"BitAndBytes quantization with TP or PP is not supported yet."
)
"BitAndBytes quantization with TP or PP is not supported yet."
)
if
self
.
quantization
==
"bitsandbytes"
and
self
.
enforce_eager
is
False
:
raise
ValueError
(
"BitAndBytes with enforce_eager = False is not supported yet."
)
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled."""
"""Get the sliding window size, or None if disabled."""
...
@@ -590,9 +602,11 @@ class LoadConfig:
...
@@ -590,9 +602,11 @@ class LoadConfig:
mainly for profiling.
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
checkpoints.
"""
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
...
@@ -716,7 +730,7 @@ class ParallelConfig:
...
@@ -716,7 +730,7 @@ class ParallelConfig:
backend
)
backend
)
self
.
_verify_args
()
self
.
_verify_args
()
self
.
rank
=
0
self
.
rank
:
int
=
0
@
property
@
property
def
use_ray
(
self
)
->
bool
:
def
use_ray
(
self
)
->
bool
:
...
@@ -842,6 +856,7 @@ class SchedulerConfig:
...
@@ -842,6 +856,7 @@ class SchedulerConfig:
class
DeviceConfig
:
class
DeviceConfig
:
device
:
Optional
[
torch
.
device
]
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
if
device
==
"auto"
:
if
device
==
"auto"
:
...
@@ -892,6 +907,7 @@ class SpeculativeConfig:
...
@@ -892,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
disable_log_stats
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
...
@@ -1053,7 +1069,7 @@ class SpeculativeConfig:
...
@@ -1053,7 +1069,7 @@ class SpeculativeConfig:
draft_parallel_config
=
(
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
,
target_parallel_config
,
speculative_draft_tensor_parallel_size
))
speculative_draft_tensor_parallel_size
,
draft_hf_config
))
if
num_speculative_tokens
is
None
:
if
num_speculative_tokens
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1080,7 +1096,8 @@ class SpeculativeConfig:
...
@@ -1080,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
)
)
@
staticmethod
@
staticmethod
...
@@ -1121,15 +1138,23 @@ class SpeculativeConfig:
...
@@ -1121,15 +1138,23 @@ class SpeculativeConfig:
@
staticmethod
@
staticmethod
def
create_draft_parallel_config
(
def
create_draft_parallel_config
(
target_parallel_config
:
ParallelConfig
,
target_parallel_config
:
ParallelConfig
,
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
speculative_draft_tensor_parallel_size
:
Optional
[
int
],
draft_hf_config
:
PretrainedConfig
,
)
->
ParallelConfig
:
)
->
ParallelConfig
:
"""Create a parallel config for use by the draft worker.
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
This is mostly a copy of the target parallel config, except the tp_size.
"""
"""
if
speculative_draft_tensor_parallel_size
is
None
:
if
speculative_draft_tensor_parallel_size
is
None
:
speculative_draft_tensor_parallel_size
=
\
if
draft_hf_config
.
model_type
==
"mlp_speculator"
:
target_parallel_config
.
tensor_parallel_size
speculative_draft_tensor_parallel_size
=
1
if
target_parallel_config
.
tensor_parallel_size
>
1
:
logger
.
warning
(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1"
)
else
:
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
elif
speculative_draft_tensor_parallel_size
!=
1
:
elif
speculative_draft_tensor_parallel_size
!=
1
:
# TODO(wooyeon): allow tp values larger than 1
# TODO(wooyeon): allow tp values larger than 1
raise
ValueError
(
raise
ValueError
(
...
@@ -1166,6 +1191,7 @@ class SpeculativeConfig:
...
@@ -1166,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -1198,6 +1224,8 @@ class SpeculativeConfig:
...
@@ -1198,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
determined. If set to False, log probabilities will be
returned.
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
...
@@ -1212,6 +1240,7 @@ class SpeculativeConfig:
...
@@ -1212,6 +1240,7 @@ class SpeculativeConfig:
self
.
typical_acceptance_sampler_posterior_alpha
=
\
self
.
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -1281,7 +1310,7 @@ class LoRAConfig:
...
@@ -1281,7 +1310,7 @@ class LoRAConfig:
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
#
Keep this in sync with csrc/punica/bgmv/bgmv_config.h
#
TODO: Increase the range of rank
possible_max_ranks
=
(
8
,
16
,
32
,
64
)
possible_max_ranks
=
(
8
,
16
,
32
,
64
)
possible_lora_extra_vocab_size
=
(
0
,
256
,
512
)
possible_lora_extra_vocab_size
=
(
0
,
256
,
512
)
if
self
.
max_lora_rank
not
in
possible_max_ranks
:
if
self
.
max_lora_rank
not
in
possible_max_ranks
:
...
@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len(
...
@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len(
"Disabling sliding window is not supported for models "
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"model_max_length in the config. Please raise an issue "
"so we can investigate."
)
"so we can investigate."
)
pass
else
:
else
:
raise
ValueError
(
msg
=
(
f
"User-specified max_model_len (
{
max_model_len
}
) is greater "
f
"User-specified max_model_len (
{
max_model_len
}
) is greater "
"than the derived max_model_len "
f
"than the derived max_model_len
(
{
max_len_key
}
=
"
f
"
(
{
max_len_key
}
=
{
derived_max_model_len
}
or model_max_length="
f
"
{
derived_max_model_len
}
or model_max_length="
f
"
{
model_max_length
}
in model's config.json). This may lead "
f
"
{
model_max_length
}
in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the "
"to incorrect model outputs or CUDA errors."
)
"value is correct and within the model context size."
)
if
envs
.
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
logger
.
warning
(
"%s Make sure the value is correct and within the "
"model context size."
,
msg
)
else
:
raise
ValueError
(
f
"
{
msg
}
To allow overriding this maximum, set "
"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1"
)
return
int
(
max_model_len
)
return
int
(
max_model_len
)
...
...
vllm/connections.py
View file @
e661d594
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
MutableMapping
,
Optional
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
import
aiohttp
import
aiohttp
...
@@ -40,7 +40,7 @@ class HTTPConnection:
...
@@ -40,7 +40,7 @@ class HTTPConnection:
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'."
)
"must have scheme 'http' or 'https'."
)
def
_headers
(
self
,
**
extras
:
str
)
->
Mapping
[
str
,
str
]:
def
_headers
(
self
,
**
extras
:
str
)
->
Mutable
Mapping
[
str
,
str
]:
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
def
get_response
(
def
get_response
(
...
...
vllm/core/block_manager_v1.py
View file @
e661d594
...
@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
if
self
.
enable_caching
:
if
self
.
enable_caching
:
for
seq
in
seq_group
.
seqs_dict
.
value
s
():
for
seq
in
seq_group
.
get_seq
s
():
self
.
compute_full_blocks_in_seq
(
seq
)
self
.
compute_full_blocks_in_seq
(
seq
)
vllm/core/policy.py
deleted
100644 → 0
View file @
6b16ea2e
from
collections
import
deque
from
typing
import
Deque
from
vllm.sequence
import
SequenceGroup
class
Policy
:
def
get_priority
(
self
,
now
:
float
,
seq_group
:
SequenceGroup
,
)
->
float
:
raise
NotImplementedError
def
sort_by_priority
(
self
,
now
:
float
,
seq_groups
:
Deque
[
SequenceGroup
],
)
->
Deque
[
SequenceGroup
]:
return
deque
(
sorted
(
seq_groups
,
key
=
lambda
seq_group
:
self
.
get_priority
(
now
,
seq_group
),
reverse
=
True
,
))
class
FCFS
(
Policy
):
def
get_priority
(
self
,
now
:
float
,
seq_group
:
SequenceGroup
,
)
->
float
:
return
now
-
seq_group
.
metrics
.
arrival_time
class
PolicyFactory
:
_POLICY_REGISTRY
=
{
'fcfs'
:
FCFS
}
@
classmethod
def
get_policy
(
cls
,
policy_name
:
str
,
**
kwargs
)
->
Policy
:
return
cls
.
_POLICY_REGISTRY
[
policy_name
](
**
kwargs
)
vllm/core/scheduler.py
View file @
e661d594
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -313,6 +312,7 @@ class Scheduler:
...
@@ -313,6 +312,7 @@ class Scheduler:
# Sequence groups finished requests ids since last step iteration.
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
# can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self
.
_finished_requests_ids
:
List
[
str
]
=
list
()
self
.
_finished_requests_ids
:
List
[
str
]
=
list
()
# Time at previous scheduling step
# Time at previous scheduling step
self
.
prev_time
=
0.0
self
.
prev_time
=
0.0
...
@@ -344,6 +344,16 @@ class Scheduler:
...
@@ -344,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
# Add sequence groups to the waiting queue.
self
.
waiting
.
append
(
seq_group
)
self
.
waiting
.
append
(
seq_group
)
def
_add_seq_group_to_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the running queue.
# Only for testing purposes.
self
.
running
.
append
(
seq_group
)
def
_add_seq_group_to_swapped
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self
.
swapped
.
append
(
seq_group
)
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a sequence group with the given ID.
"""Aborts a sequence group with the given ID.
...
@@ -374,6 +384,7 @@ class Scheduler:
...
@@ -374,6 +384,7 @@ class Scheduler:
for
aborted_group
in
aborted_groups
:
for
aborted_group
in
aborted_groups
:
# Remove the sequence group from the state queue.
# Remove the sequence group from the state queue.
state_queue
.
remove
(
aborted_group
)
state_queue
.
remove
(
aborted_group
)
# Remove the aborted request from the Mamba cache.
self
.
_finished_requests_ids
.
append
(
aborted_group
.
request_id
)
self
.
_finished_requests_ids
.
append
(
aborted_group
.
request_id
)
for
seq
in
aborted_group
.
get_seqs
():
for
seq
in
aborted_group
.
get_seqs
():
if
seq
.
is_finished
():
if
seq
.
is_finished
():
...
@@ -396,32 +407,26 @@ class Scheduler:
...
@@ -396,32 +407,26 @@ class Scheduler:
def
_schedule_running
(
def
_schedule_running
(
self
,
self
,
running_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerRunningOutputs
]
:
)
->
SchedulerRunningOutputs
:
"""Schedule sequence groups that are running.
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Running queue should include decode and chunked prefill requests.
Args:
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining running queue (should be always 0) after
SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
...
@@ -434,10 +439,9 @@ class Scheduler:
...
@@ -434,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
running_queue
=
self
.
running
now
=
time
.
time
()
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
while
running_queue
:
while
running_queue
:
seq_group
=
running_queue
[
0
]
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
@@ -501,7 +505,7 @@ class Scheduler:
...
@@ -501,7 +505,7 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
curr_loras
.
add
(
seq_group
.
lora_int_id
)
return
running_queue
,
SchedulerRunningOutputs
(
return
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
preempted
=
preempted
,
preempted
=
preempted
,
...
@@ -513,12 +517,10 @@ class Scheduler:
...
@@ -513,12 +517,10 @@ class Scheduler:
def
_schedule_swapped
(
def
_schedule_swapped
(
self
,
self
,
swapped_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerSwappedInOutputs
]
:
)
->
SchedulerSwappedInOutputs
:
"""Schedule sequence groups that are swapped out.
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
It schedules swapped requests as long as it fits `budget` and
...
@@ -526,20 +528,16 @@ class Scheduler:
...
@@ -526,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
...
@@ -547,10 +545,10 @@ class Scheduler:
...
@@ -547,10 +545,10 @@ class Scheduler:
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
swapped_queue
=
self
.
swapped
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
while
swapped_queue
:
while
swapped_queue
:
seq_group
=
swapped_queue
[
0
]
seq_group
=
swapped_queue
[
0
]
...
@@ -615,7 +613,7 @@ class Scheduler:
...
@@ -615,7 +613,7 @@ class Scheduler:
swapped_queue
.
extendleft
(
leftover_swapped
)
swapped_queue
.
extendleft
(
leftover_swapped
)
return
swapped_queue
,
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
decode_seq_groups
,
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
@@ -642,11 +640,10 @@ class Scheduler:
...
@@ -642,11 +640,10 @@ class Scheduler:
def
_schedule_prefills
(
def
_schedule_prefills
(
self
,
self
,
waiting_queue
:
deque
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
curr_loras
:
Optional
[
Set
[
int
]],
enable_chunking
:
bool
=
False
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerPrefillOutputs
]
:
)
->
SchedulerPrefillOutputs
:
"""Schedule sequence groups that are in prefill stage.
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...
@@ -658,8 +655,6 @@ class Scheduler:
...
@@ -658,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
curr_loras: Currently batched lora request ids. The argument is
...
@@ -670,14 +665,12 @@ class Scheduler:
...
@@ -670,14 +665,12 @@ class Scheduler:
all tokens.
all tokens.
Returns:
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue
=
self
.
waiting
waiting_queue
=
deque
([
s
for
s
in
waiting_queue
])
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
...
@@ -756,7 +749,7 @@ class Scheduler:
...
@@ -756,7 +749,7 @@ class Scheduler:
if
len
(
seq_groups
)
>
0
:
if
len
(
seq_groups
)
>
0
:
self
.
prev_prompt
=
True
self
.
prev_prompt
=
True
return
waiting_queue
,
SchedulerPrefillOutputs
(
return
SchedulerPrefillOutputs
(
seq_groups
=
seq_groups
,
seq_groups
=
seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
...
@@ -783,53 +776,43 @@ class Scheduler:
...
@@ -783,53 +776,43 @@ class Scheduler:
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
remaining_waiting
,
prefills
=
(
self
.
waiting
,
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
SchedulerPrefillOutputs
.
create_empty
())
running_scheduled
=
SchedulerRunningOutputs
.
create_empty
()
remaining_running
,
running_scheduled
=
(
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
# If any requests are swapped, prioritized swapped requests.
# If any requests are swapped, prioritized swapped requests.
if
not
self
.
swapped
:
if
not
self
.
swapped
:
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
prefills
=
self
.
_schedule_prefills
(
budget
,
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
False
)
curr_loras
,
enable_chunking
=
False
)
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
# Don't schedule decodes if prefills are scheduled.
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
# only contains decode requests, not chunked prefills.
if
len
(
prefills
.
seq_groups
)
==
0
:
if
len
(
prefills
.
seq_groups
)
==
0
:
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
running_scheduled
=
self
.
_schedule_running
(
budget
,
self
.
running
,
curr_loras
,
budget
,
enable_chunking
=
False
)
curr_loras
,
fcfs_policy
,
enable_chunking
=
False
)
# If any sequence group is preempted, do not swap in any sequence
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
# group. because it means there's no slot for new running requests.
if
len
(
running_scheduled
.
preempted
)
+
len
(
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
))
len
(
running_scheduled
.
swapped_out
))
...
@@ -875,42 +858,32 @@ class Scheduler:
...
@@ -875,42 +858,32 @@ class Scheduler:
)
)
curr_loras
:
Set
[
int
]
=
set
()
curr_loras
:
Set
[
int
]
=
set
()
remaining_waiting
,
prefills
=
(
self
.
waiting
,
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
SchedulerPrefillOutputs
.
create_empty
())
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
# Decoding should be always scheduled first by fcfs.
# Decoding should be always scheduled first by fcfs.
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
curr_loras
,
self
.
running
,
enable_chunking
=
True
)
budget
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
True
)
# Schedule swapped out requests.
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
# If preemption happens, it means we don't have space for swap-in.
if
len
(
running_scheduled
.
preempted
)
+
len
(
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
# Schedule new prefills.
# Schedule new prefills.
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
prefills
=
self
.
_schedule_prefills
(
budget
,
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
True
)
curr_loras
,
enable_chunking
=
True
)
assert
(
budget
.
num_batched_tokens
<=
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
...
@@ -921,7 +894,6 @@ class Scheduler:
...
@@ -921,7 +894,6 @@ class Scheduler:
self
.
running
.
extend
(
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
...
@@ -1029,7 +1001,6 @@ class Scheduler:
...
@@ -1029,7 +1001,6 @@ class Scheduler:
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# between engine and worker.
# the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but
...
@@ -1058,13 +1029,16 @@ class Scheduler:
...
@@ -1058,13 +1029,16 @@ class Scheduler:
self
.
block_manager
.
free
(
seq
)
self
.
block_manager
.
free
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
def
free_finished_seq_groups
(
self
)
->
None
:
for
queue
in
[
self
.
running
,
self
.
swapped
,
self
.
waiting
]:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
self
.
_finished_requests_ids
+=
[
for
seq_group
in
self
.
running
:
seq_group
.
request_id
for
seq_group
in
queue
if
seq_group
.
is_finished
():
if
seq_group
.
is_finished
()
# Add the finished requests to the finished requests list.
]
# This list will be used to update the Mamba cache in the
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
# next step.
if
not
seq_group
.
is_finished
())
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
remaining
.
append
(
seq_group
)
self
.
running
=
remaining
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
self
.
block_manager
.
allocate
(
seq_group
)
...
...
vllm/distributed/device_communicators/cuda_wrapper.py
View file @
e661d594
...
@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions.
...
@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions.
"""
"""
import
ctypes
import
ctypes
import
glob
import
os
import
sys
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
...
@@ -36,24 +33,25 @@ class Function:
...
@@ -36,24 +33,25 @@ class Function:
argtypes
:
List
[
Any
]
argtypes
:
List
[
Any
]
def
get_pytorch_default_cudart_library_path
()
->
str
:
def
find_loaded_library
(
lib_name
)
->
Optional
[
str
]:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
"""
lib_folder
=
"cuda_runtime"
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
lib_name
=
"libcudart.so.*[0-9]"
the file `/proc/self/maps` contains the memory maps of the process, which includes the
lib_path
=
None
shared libraries loaded by the process. We can use this file to find the path of the
for
path
in
sys
.
path
:
a loaded library.
nvidia_path
=
os
.
path
.
join
(
path
,
"nvidia"
)
"""
# noqa
if
not
os
.
path
.
exists
(
nvidia_path
):
found
=
False
continue
with
open
(
"/proc/self/maps"
)
as
f
:
candidate_lib_paths
=
glob
.
glob
(
for
line
in
f
:
os
.
path
.
join
(
nvidia_path
,
lib_folder
,
"lib"
,
lib_name
))
if
lib_name
in
line
:
if
candidate_lib_paths
and
not
lib_path
:
found
=
True
lib_path
=
candidate_lib_paths
[
0
]
break
if
lib_path
:
if
not
found
:
break
# the library is not loaded in the current process
if
not
lib_path
:
return
None
raise
ValueError
(
f
"
{
lib_name
}
not found in the system path
{
sys
.
path
}
"
)
start
=
line
.
index
(
"/"
)
return
lib_path
path
=
line
[
start
:].
strip
()
return
path
class
CudaRTLibrary
:
class
CudaRTLibrary
:
...
@@ -100,7 +98,9 @@ class CudaRTLibrary:
...
@@ -100,7 +98,9 @@ class CudaRTLibrary:
def
__init__
(
self
,
so_file
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
so_file
:
Optional
[
str
]
=
None
):
if
so_file
is
None
:
if
so_file
is
None
:
so_file
=
get_pytorch_default_cudart_library_path
()
so_file
=
find_loaded_library
(
"libcudart.so"
)
assert
so_file
is
not
None
,
\
"libcudart.so is not loaded in the current process"
if
so_file
not
in
CudaRTLibrary
.
path_to_library_cache
:
if
so_file
not
in
CudaRTLibrary
.
path_to_library_cache
:
lib
=
ctypes
.
CDLL
(
so_file
)
lib
=
ctypes
.
CDLL
(
so_file
)
CudaRTLibrary
.
path_to_library_cache
[
so_file
]
=
lib
CudaRTLibrary
.
path_to_library_cache
[
so_file
]
=
lib
...
...
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
e661d594
...
@@ -145,6 +145,7 @@ def can_actually_p2p(
...
@@ -145,6 +145,7 @@ def can_actually_p2p(
p_tgt
.
start
()
p_tgt
.
start
()
p_src
.
join
()
p_src
.
join
()
p_tgt
.
join
()
p_tgt
.
join
()
assert
p_src
.
exitcode
==
0
and
p_tgt
.
exitcode
==
0
result
:
List
[
bool
]
=
[]
result
:
List
[
bool
]
=
[]
for
src
,
tgt
in
zip
(
batch_src
,
batch_tgt
):
for
src
,
tgt
in
zip
(
batch_src
,
batch_tgt
):
a
=
result_queue
.
get
()
a
=
result_queue
.
get
()
...
@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
...
@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# wrap raised exception to provide more information
# wrap raised exception to provide more information
raise
RuntimeError
(
raise
RuntimeError
(
f
"Error happened when batch testing "
f
"Error happened when batch testing "
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
"
)
from
e
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
:
\n
"
f
"
{
returned
.
stderr
.
decode
()
}
"
)
from
e
result
=
pickle
.
loads
(
returned
.
stdout
)
result
=
pickle
.
loads
(
returned
.
stdout
)
for
_i
,
_j
,
r
in
zip
(
batch_src
,
batch_tgt
,
result
):
for
_i
,
_j
,
r
in
zip
(
batch_src
,
batch_tgt
,
result
):
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
r
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
r
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
e661d594
...
@@ -9,7 +9,7 @@ from unittest.mock import patch
...
@@ -9,7 +9,7 @@ from unittest.mock import patch
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
zmq
import
P
UB
,
REP
,
REQ
,
SUB
,
SUBSCRIB
E
,
Context
# type: ignore
from
zmq
import
S
UB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOS
E
,
Context
# type: ignore
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -153,9 +153,7 @@ class Handle:
...
@@ -153,9 +153,7 @@ class Handle:
buffer
:
Optional
[
ShmRingBuffer
]
=
None
buffer
:
Optional
[
ShmRingBuffer
]
=
None
local_subscribe_port
:
Optional
[
int
]
=
None
local_subscribe_port
:
Optional
[
int
]
=
None
local_sync_port
:
Optional
[
int
]
=
None
remote_subscribe_port
:
Optional
[
int
]
=
None
remote_subscribe_port
:
Optional
[
int
]
=
None
remote_sync_port
:
Optional
[
int
]
=
None
class
MessageQueue
:
class
MessageQueue
:
...
@@ -189,38 +187,36 @@ class MessageQueue:
...
@@ -189,38 +187,36 @@ class MessageQueue:
self
.
buffer
=
ShmRingBuffer
(
n_local_reader
,
max_chunk_bytes
,
self
.
buffer
=
ShmRingBuffer
(
n_local_reader
,
max_chunk_bytes
,
max_chunks
)
max_chunks
)
self
.
local_socket
=
context
.
socket
(
PUB
)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self
.
local_socket
=
context
.
socket
(
XPUB
)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self
.
local_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
local_subscribe_port
=
get_open_port
()
local_subscribe_port
=
get_open_port
()
self
.
local_socket
.
bind
(
f
"tcp://*:
{
local_subscribe_port
}
"
)
self
.
local_socket
.
bind
(
f
"tcp://*:
{
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REP
)
local_sync_port
=
get_open_port
()
self
.
local_sync_socket
.
bind
(
f
"tcp://*:
{
local_sync_port
}
"
)
self
.
current_idx
=
0
self
.
current_idx
=
0
else
:
else
:
self
.
buffer
=
None
# type: ignore
self
.
buffer
=
None
# type: ignore
local_subscribe_port
=
None
local_subscribe_port
=
None
local_sync_port
=
None
self
.
local_socket
=
None
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
current_idx
=
-
1
self
.
current_idx
=
-
1
if
n_remote_reader
>
0
:
if
n_remote_reader
>
0
:
# for remote readers, we will:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
# create a publish-subscribe socket to communicate large data
self
.
remote_socket
=
context
.
socket
(
PUB
)
self
.
remote_socket
=
context
.
socket
(
XPUB
)
self
.
remote_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
remote_subscribe_port
=
get_open_port
()
remote_subscribe_port
=
get_open_port
()
self
.
remote_socket
.
bind
(
f
"tcp://*:
{
remote_subscribe_port
}
"
)
self
.
remote_socket
.
bind
(
f
"tcp://*:
{
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REP
)
remote_sync_port
=
get_open_port
()
self
.
remote_sync_socket
.
bind
(
f
"tcp://*:
{
remote_sync_port
}
"
)
else
:
else
:
remote_subscribe_port
=
None
remote_subscribe_port
=
None
remote_sync_port
=
None
self
.
remote_socket
=
None
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
self
.
_is_writer
=
True
self
.
_is_writer
=
True
self
.
_is_local_reader
=
False
self
.
_is_local_reader
=
False
...
@@ -233,9 +229,7 @@ class MessageQueue:
...
@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks
=
local_reader_ranks
,
local_reader_ranks
=
local_reader_ranks
,
buffer
=
self
.
buffer
,
buffer
=
self
.
buffer
,
local_subscribe_port
=
local_subscribe_port
,
local_subscribe_port
=
local_subscribe_port
,
local_sync_port
=
local_sync_port
,
remote_subscribe_port
=
remote_subscribe_port
,
remote_subscribe_port
=
remote_subscribe_port
,
remote_sync_port
=
remote_sync_port
,
)
)
logger
.
info
(
"vLLM message queue communication handle: %s"
,
self
.
handle
)
logger
.
info
(
"vLLM message queue communication handle: %s"
,
self
.
handle
)
...
@@ -264,12 +258,7 @@ class MessageQueue:
...
@@ -264,12 +258,7 @@ class MessageQueue:
self
.
local_socket
.
connect
(
self
.
local_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_subscribe_port
}
"
)
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REQ
)
self
.
local_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_sync_port
}
"
)
self
.
remote_socket
=
None
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
else
:
else
:
self
.
buffer
=
None
# type: ignore
self
.
buffer
=
None
# type: ignore
self
.
current_idx
=
-
1
self
.
current_idx
=
-
1
...
@@ -278,17 +267,12 @@ class MessageQueue:
...
@@ -278,17 +267,12 @@ class MessageQueue:
self
.
_is_remote_reader
=
True
self
.
_is_remote_reader
=
True
self
.
local_socket
=
None
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
remote_socket
=
context
.
socket
(
SUB
)
self
.
remote_socket
=
context
.
socket
(
SUB
)
self
.
remote_socket
.
setsockopt_string
(
SUBSCRIBE
,
""
)
self
.
remote_socket
.
setsockopt_string
(
SUBSCRIBE
,
""
)
self
.
remote_socket
.
connect
(
self
.
remote_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_subscribe_port
}
"
)
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REQ
)
self
.
remote_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_sync_port
}
"
)
return
self
return
self
def
wait_until_ready
(
self
):
def
wait_until_ready
(
self
):
...
@@ -300,29 +284,27 @@ class MessageQueue:
...
@@ -300,29 +284,27 @@ class MessageQueue:
# local readers
# local readers
for
i
in
range
(
self
.
n_local_reader
):
for
i
in
range
(
self
.
n_local_reader
):
recv
=
self
.
local_sync_socket
.
recv
()
# wait for subscription messages from all local readers
assert
recv
==
b
"READY"
self
.
local_socket
.
recv
()
self
.
local_sync_socket
.
send
(
b
"READY"
)
if
self
.
n_local_reader
>
0
:
if
self
.
n_local_reader
>
0
:
# send a message to all local readers
# to make sure the publish channel is working
self
.
local_socket
.
send
(
b
"READY"
)
self
.
local_socket
.
send
(
b
"READY"
)
# remote readers
# remote readers
for
i
in
range
(
self
.
n_remote_reader
):
for
i
in
range
(
self
.
n_remote_reader
):
recv
=
self
.
remote_sync_socket
.
recv
()
# wait for subscription messages from all remote readers
assert
recv
==
b
"READY"
self
.
remote_socket
.
recv
()
self
.
remote_sync_socket
.
send
(
b
"READY"
)
if
self
.
n_remote_reader
>
0
:
if
self
.
n_remote_reader
>
0
:
# send a message to all remote readers
# to make sure the publish channel is working
self
.
remote_socket
.
send
(
b
"READY"
)
self
.
remote_socket
.
send
(
b
"READY"
)
elif
self
.
_is_local_reader
:
elif
self
.
_is_local_reader
:
self
.
local_sync_socket
.
send
(
b
"READY"
)
# wait for the writer to send a message
recv
=
self
.
local_sync_socket
.
recv
()
assert
recv
==
b
"READY"
recv
=
self
.
local_socket
.
recv
()
recv
=
self
.
local_socket
.
recv
()
assert
recv
==
b
"READY"
assert
recv
==
b
"READY"
elif
self
.
_is_remote_reader
:
elif
self
.
_is_remote_reader
:
self
.
remote_sync_socket
.
send
(
b
"READY"
)
# wait for the writer to send a message
recv
=
self
.
remote_sync_socket
.
recv
()
assert
recv
==
b
"READY"
recv
=
self
.
remote_socket
.
recv
()
recv
=
self
.
remote_socket
.
recv
()
assert
recv
==
b
"READY"
assert
recv
==
b
"READY"
...
...
vllm/distributed/device_communicators/tpu_communicator.py
0 → 100644
View file @
e661d594
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
class
TpuCommunicator
:
def
__init__
(
self
,
group
:
ProcessGroup
):
if
not
current_platform
.
is_tpu
():
self
.
disabled
=
True
return
self
.
disabled
=
False
local_rank
=
dist
.
get_rank
(
group
)
world_size
=
dist
.
get_world_size
(
group
)
pjrt
.
initialize_multiprocess
(
local_rank
,
world_size
)
xr
.
_init_world_size_ordinal
()
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
x
)
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"TPUs only support dim=-1 for all-gather."
return
xm
.
all_gather
(
x
,
dim
=
dim
)
vllm/distributed/parallel_state.py
View file @
e661d594
...
@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
...
@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def
_split_tensor_dict
(
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]
,
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]
prefix
:
str
=
""
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
by its metadata.
2. A list of tensors.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
"""
metadata_list
:
List
[
Tuple
[
str
,
Any
]]
=
[]
metadata_list
:
List
[
Tuple
[
str
,
Any
]]
=
[]
tensor_list
=
[]
tensor_list
:
List
[
torch
.
Tensor
]
=
[]
for
key
,
value
in
tensor_dict
.
items
():
for
key
,
value
in
tensor_dict
.
items
():
assert
"%"
not
in
key
,
(
"Avoid having '%' in key "
"as it is used as a separator for nested entries."
)
if
isinstance
(
value
,
torch
.
Tensor
):
if
isinstance
(
value
,
torch
.
Tensor
):
# Note: we cannot use `value.device` here,
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# because it contains not only the device type but also the device
...
@@ -68,31 +62,13 @@ def _split_tensor_dict(
...
@@ -68,31 +62,13 @@ def _split_tensor_dict(
# receiving side will set the device index.
# receiving side will set the device index.
device
=
value
.
device
.
type
device
=
value
.
device
.
type
metadata_list
.
append
(
metadata_list
.
append
(
(
prefix
+
key
,
TensorMetadata
(
device
,
value
.
dtype
,
(
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
value
.
size
())))
tensor_list
.
append
(
value
)
tensor_list
.
append
(
value
)
elif
isinstance
(
value
,
dict
):
if
len
(
value
)
==
0
:
metadata_list
.
append
((
prefix
+
key
,
value
))
inner_metadata_list
,
inner_tensor_list
=
_split_tensor_dict
(
value
,
prefix
+
key
+
"%"
)
metadata_list
.
extend
(
inner_metadata_list
)
tensor_list
.
extend
(
inner_tensor_list
)
else
:
else
:
metadata_list
.
append
((
prefix
+
key
,
value
))
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
return
metadata_list
,
tensor_list
def
_update_nested_dict
(
nested_dict
,
flattened_key
,
value
):
key_splits
=
flattened_key
.
split
(
"%"
)
cur_dict
=
nested_dict
for
k
in
key_splits
[:
-
1
]:
if
k
not
in
cur_dict
:
cur_dict
[
k
]
=
{}
cur_dict
=
cur_dict
[
k
]
cur_dict
[
key_splits
[
-
1
]]
=
value
class
GroupCoordinator
:
class
GroupCoordinator
:
"""
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup wrapper for a group of processes.
...
@@ -133,6 +109,7 @@ class GroupCoordinator:
...
@@ -133,6 +109,7 @@ class GroupCoordinator:
torch_distributed_backend
:
Union
[
str
,
Backend
],
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_pynccl
:
bool
,
use_pynccl
:
bool
,
use_custom_allreduce
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
use_message_queue_broadcaster
:
bool
=
False
,
):
):
...
@@ -164,6 +141,7 @@ class GroupCoordinator:
...
@@ -164,6 +141,7 @@ class GroupCoordinator:
self
.
use_pynccl
=
use_pynccl
self
.
use_pynccl
=
use_pynccl
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_tpu_communicator
=
use_tpu_communicator
# lazy import to avoid documentation build error
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
@@ -190,6 +168,12 @@ class GroupCoordinator:
...
@@ -190,6 +168,12 @@ class GroupCoordinator:
else
:
else
:
self
.
ca_comm
=
None
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
from
vllm.distributed.device_communicators.shm_broadcast
import
(
from
vllm.distributed.device_communicators.shm_broadcast
import
(
MessageQueue
)
MessageQueue
)
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
...
@@ -243,6 +227,13 @@ class GroupCoordinator:
...
@@ -243,6 +227,13 @@ class GroupCoordinator:
ca_comm
=
self
.
ca_comm
ca_comm
=
self
.
ca_comm
maybe_ca_context
=
nullcontext
(
maybe_ca_context
=
nullcontext
(
)
if
ca_comm
is
None
else
ca_comm
.
capture
()
)
if
ca_comm
is
None
else
ca_comm
.
capture
()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream
=
torch
.
cuda
.
current_stream
()
if
curr_stream
!=
stream
:
stream
.
wait_stream
(
curr_stream
)
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# operations. The current status is:
...
@@ -282,6 +273,12 @@ class GroupCoordinator:
...
@@ -282,6 +273,12 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
return
input_
return
input_
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_reduce
(
input_
)
if
ca_comm
is
not
None
:
if
ca_comm
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
out
=
ca_comm
.
custom_all_reduce
(
input_
)
if
out
is
not
None
:
if
out
is
not
None
:
...
@@ -289,6 +286,9 @@ class GroupCoordinator:
...
@@ -289,6 +286,9 @@ class GroupCoordinator:
pynccl_comm
=
self
.
pynccl_comm
pynccl_comm
=
self
.
pynccl_comm
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
pynccl_comm
.
all_reduce
(
input_
)
elif
input_
.
is_cpu
:
import
intel_extension_for_pytorch
as
ipex
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
else
:
else
:
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
return
input_
...
@@ -300,6 +300,12 @@ class GroupCoordinator:
...
@@ -300,6 +300,12 @@ class GroupCoordinator:
return
input_
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_gather
(
input_
,
dim
)
if
dim
<
0
:
if
dim
<
0
:
# Convert negative dim to positive.
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
dim
+=
input_
.
dim
()
...
@@ -536,7 +542,7 @@ class GroupCoordinator:
...
@@ -536,7 +542,7 @@ class GroupCoordinator:
device
=
value
.
device
)
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
# Skip broadcasting empty tensors.
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
continue
continue
if
tensor
.
is_cpu
:
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
# use metadata_group for CPU tensors
...
@@ -553,9 +559,9 @@ class GroupCoordinator:
...
@@ -553,9 +559,9 @@ class GroupCoordinator:
group
=
group
,
group
=
group
,
async_op
=
True
)
async_op
=
True
)
async_handles
.
append
(
handle
)
async_handles
.
append
(
handle
)
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
else
:
else
:
_update_nested_dict
(
tensor_dict
,
key
,
value
)
tensor_dict
[
key
]
=
value
for
async_handle
in
async_handles
:
for
async_handle
in
async_handles
:
async_handle
.
wait
()
async_handle
.
wait
()
return
tensor_dict
return
tensor_dict
...
@@ -563,7 +569,8 @@ class GroupCoordinator:
...
@@ -563,7 +569,8 @@ class GroupCoordinator:
def
send_tensor_dict
(
def
send_tensor_dict
(
self
,
self
,
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]],
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]],
dst
:
Optional
[
int
]
=
None
dst
:
Optional
[
int
]
=
None
,
all_gather_group
:
Optional
[
"GroupCoordinator"
]
=
None
,
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Send the input tensor dictionary.
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
NOTE: `dst` is the local rank of the source rank.
...
@@ -572,6 +579,11 @@ class GroupCoordinator:
...
@@ -572,6 +579,11 @@ class GroupCoordinator:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
tensor_dict
return
tensor_dict
all_gather_size
=
(
1
if
all_gather_group
is
None
else
all_gather_group
.
world_size
)
all_gather_rank
=
(
0
if
all_gather_group
is
None
else
all_gather_group
.
rank_in_group
)
group
=
self
.
device_group
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
metadata_group
=
self
.
cpu_group
...
@@ -592,6 +604,12 @@ class GroupCoordinator:
...
@@ -592,6 +604,12 @@ class GroupCoordinator:
if
tensor
.
numel
()
==
0
:
if
tensor
.
numel
()
==
0
:
# Skip sending empty tensors.
# Skip sending empty tensors.
continue
continue
# send-allgather: send only a slice, then do allgather.
if
(
all_gather_group
is
not
None
and
tensor
.
numel
()
%
all_gather_size
==
0
):
tensor
=
tensor
.
reshape
(
all_gather_size
,
-
1
)[
all_gather_rank
]
if
tensor
.
is_cpu
:
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
# use metadata_group for CPU tensors
torch
.
distributed
.
send
(
tensor
,
torch
.
distributed
.
send
(
tensor
,
...
@@ -606,7 +624,8 @@ class GroupCoordinator:
...
@@ -606,7 +624,8 @@ class GroupCoordinator:
def
recv_tensor_dict
(
def
recv_tensor_dict
(
self
,
self
,
src
:
Optional
[
int
]
=
None
src
:
Optional
[
int
]
=
None
,
all_gather_group
:
Optional
[
"GroupCoordinator"
]
=
None
,
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Recv the input tensor dictionary.
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
NOTE: `src` is the local rank of the source rank.
...
@@ -615,6 +634,11 @@ class GroupCoordinator:
...
@@ -615,6 +634,11 @@ class GroupCoordinator:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
None
return
None
all_gather_size
=
(
1
if
all_gather_group
is
None
else
all_gather_group
.
world_size
)
all_gather_rank
=
(
0
if
all_gather_group
is
None
else
all_gather_group
.
rank_in_group
)
group
=
self
.
device_group
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
metadata_group
=
self
.
cpu_group
...
@@ -631,8 +655,18 @@ class GroupCoordinator:
...
@@ -631,8 +655,18 @@ class GroupCoordinator:
device
=
value
.
device
)
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
# Skip broadcasting empty tensors.
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
continue
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather
=
(
all_gather_group
is
not
None
and
tensor
.
numel
()
%
all_gather_size
==
0
)
if
use_all_gather
:
orig_shape
=
tensor
.
shape
tensor
=
tensor
.
reshape
(
all_gather_size
,
-
1
)[
all_gather_rank
]
if
tensor
.
is_cpu
:
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
# use metadata_group for CPU tensors
torch
.
distributed
.
recv
(
tensor
,
torch
.
distributed
.
recv
(
tensor
,
...
@@ -643,9 +677,15 @@ class GroupCoordinator:
...
@@ -643,9 +677,15 @@ class GroupCoordinator:
torch
.
distributed
.
recv
(
tensor
,
torch
.
distributed
.
recv
(
tensor
,
src
=
self
.
ranks
[
src
],
src
=
self
.
ranks
[
src
],
group
=
group
)
group
=
group
)
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
if
use_all_gather
:
# do the allgather
tensor
=
all_gather_group
.
all_gather
(
# type: ignore
tensor
,
dim
=
0
)
tensor
=
tensor
.
reshape
(
orig_shape
)
tensor_dict
[
key
]
=
tensor
else
:
else
:
_update_nested_dict
(
tensor_dict
,
key
,
value
)
tensor_dict
[
key
]
=
value
return
tensor_dict
return
tensor_dict
def
barrier
(
self
):
def
barrier
(
self
):
...
@@ -673,8 +713,8 @@ class GroupCoordinator:
...
@@ -673,8 +713,8 @@ class GroupCoordinator:
size
:
torch
.
Size
,
size
:
torch
.
Size
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""Receives a tensor from the src rank."""
"""Receives a tensor from the s
ou
rc
e
rank."""
"""NOTE: `src` is the local rank of the
destination
rank."""
"""NOTE: `src` is the local rank of the
source
rank."""
if
src
is
None
:
if
src
is
None
:
src
=
(
self
.
rank_in_group
-
1
)
%
self
.
world_size
src
=
(
self
.
rank_in_group
-
1
)
%
self
.
world_size
...
@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int,
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
False
,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
)
)
...
@@ -735,6 +776,7 @@ def init_model_parallel_group(
...
@@ -735,6 +776,7 @@ def init_model_parallel_group(
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
True
,
use_pynccl
=
True
,
use_custom_allreduce
=
use_custom_allreduce
,
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
)
)
...
...
vllm/distributed/utils.py
View file @
e661d594
...
@@ -6,6 +6,11 @@ from typing import Sequence, Tuple
...
@@ -6,6 +6,11 @@ from typing import Sequence, Tuple
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
ensure_divisibility
(
numerator
,
denominator
):
def
ensure_divisibility
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator."""
"""Ensure that numerator is divisible by the denominator."""
...
@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
...
@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
If the number of layers is not divisible by the number of partitions,
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
the last partition will have the remaining layers.
"""
"""
layers_per_partition
=
num_hidden_layers
//
pp_size
partition_list_str
=
envs
.
VLLM_PP_LAYER_PARTITION
start_layer
=
pp_rank
*
layers_per_partition
if
partition_list_str
is
not
None
:
end_layer
=
start_layer
+
layers_per_partition
try
:
partitions
=
[
int
(
layer
)
for
layer
in
partition_list_str
.
split
(
","
)
]
except
ValueError
as
err
:
raise
ValueError
(
"Invalid partition string: {}"
.
format
(
partition_list_str
))
from
err
if
len
(
partitions
)
!=
pp_size
:
raise
ValueError
(
f
"
{
len
(
partitions
)
=
}
does not match
{
pp_size
=
}
."
)
if
sum
(
partitions
)
!=
num_hidden_layers
:
raise
ValueError
(
f
"
{
sum
(
partitions
)
=
}
does not match
{
num_hidden_layers
=
}
."
)
start_layer
=
sum
(
partitions
[:
pp_rank
])
end_layer
=
start_layer
+
partitions
[
pp_rank
]
else
:
layers_per_partition
=
num_hidden_layers
//
pp_size
start_layer
=
pp_rank
*
layers_per_partition
end_layer
=
start_layer
+
layers_per_partition
if
pp_rank
==
pp_size
-
1
:
if
pp_rank
==
pp_size
-
1
:
end_layer
=
num_hidden_layers
end_layer
=
num_hidden_layers
return
(
start_layer
,
end_layer
)
return
(
start_layer
,
end_layer
)
vllm/engine/arg_utils.py
View file @
e661d594
...
@@ -632,9 +632,9 @@ class EngineArgs:
...
@@ -632,9 +632,9 @@ class EngineArgs:
'--preemption-mode'
,
'--preemption-mode'
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
'If
\'
recompute
\'
, the engine performs preemption by
block
'
help
=
'If
\'
recompute
\'
, the engine performs preemption by '
'
swapp
ing; If
\'
swap
\'
, the engine performs preemption by
block
'
'
recomput
ing; If
\'
swap
\'
, the engine performs preemption by '
'swapping.'
)
'
block
swapping.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--served-model-name"
,
"--served-model-name"
,
...
@@ -676,8 +676,8 @@ class EngineArgs:
...
@@ -676,8 +676,8 @@ class EngineArgs:
# bitsandbytes quantization needs a specific model loader
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
...
@@ -754,10 +754,14 @@ class EngineArgs:
...
@@ -754,10 +754,14 @@ class EngineArgs:
use_sliding_window
=
(
model_config
.
get_sliding_window
()
use_sliding_window
=
(
model_config
.
get_sliding_window
()
is
not
None
)
is
not
None
)
use_spec_decode
=
self
.
speculative_model
is
not
None
use_spec_decode
=
self
.
speculative_model
is
not
None
has_seqlen_agnostic_layers
=
(
model_config
.
contains_seqlen_agnostic_layers
(
parallel_config
))
if
(
is_gpu
and
not
use_sliding_window
and
not
use_spec_decode
if
(
is_gpu
and
not
use_sliding_window
and
not
use_spec_decode
and
not
self
.
enable_lora
and
not
self
.
enable_lora
and
not
self
.
enable_prompt_adapter
and
not
self
.
enable_prompt_adapter
and
not
self
.
enable_prefix_caching
):
and
not
self
.
enable_prefix_caching
and
not
has_seqlen_agnostic_layers
):
self
.
enable_chunked_prefill
=
True
self
.
enable_chunked_prefill
=
True
logger
.
warning
(
logger
.
warning
(
"Chunked prefill is enabled by default for models with "
"Chunked prefill is enabled by default for models with "
...
@@ -788,6 +792,7 @@ class EngineArgs:
...
@@ -788,6 +792,7 @@ class EngineArgs:
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
disable_log_stats
=
self
.
disable_log_stats
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
\
draft_token_acceptance_method
=
\
...
...
vllm/engine/async_llm_engine.py
View file @
e661d594
...
@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
...
@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
...
@@ -407,11 +408,15 @@ class AsyncLLMEngine:
...
@@ -407,11 +408,15 @@ class AsyncLLMEngine:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
if
distributed_executor_backend
==
"ray"
:
executor_class
=
TPUExecutorAsync
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
else
:
assert
distributed_executor_backend
is
None
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
executor_class
=
TPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
assert
distributed_executor_backend
is
None
,
(
"Distributed execution is not supported with the CPU backend."
)
from
vllm.executor.cpu_executor
import
CPUExecutorAsync
from
vllm.executor.cpu_executor
import
CPUExecutorAsync
executor_class
=
CPUExecutorAsync
executor_class
=
CPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"openvino"
:
elif
engine_config
.
device_config
.
device_type
==
"openvino"
:
...
@@ -924,6 +929,14 @@ class AsyncLLMEngine:
...
@@ -924,6 +929,14 @@ class AsyncLLMEngine:
else
:
else
:
return
self
.
engine
.
get_model_config
()
return
self
.
engine
.
get_model_config
()
async
def
get_parallel_config
(
self
)
->
ParallelConfig
:
"""Get the parallel configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_parallel_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_parallel_config
()
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
"""Get the decoding configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
...
@@ -932,6 +945,22 @@ class AsyncLLMEngine:
...
@@ -932,6 +945,22 @@ class AsyncLLMEngine:
else
:
else
:
return
self
.
engine
.
get_decoding_config
()
return
self
.
engine
.
get_decoding_config
()
async
def
get_scheduler_config
(
self
)
->
SchedulerConfig
:
"""Get the scheduling configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_scheduler_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_scheduler_config
()
async
def
get_lora_config
(
self
)
->
LoRAConfig
:
"""Get the lora configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_lora_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_lora_config
()
async
def
do_log_stats
(
async
def
do_log_stats
(
self
,
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
...
...
vllm/engine/llm_engine.py
View file @
e661d594
...
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
...
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
...
@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
ge
t_tokenizer_
g
ro
up
)
AnyTokenizer
,
BaseTokenizerGroup
,
ini
t_tokenizer_
f
ro
m_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -408,8 +406,14 @@ class LLMEngine:
...
@@ -408,8 +406,14 @@ class LLMEngine:
from
vllm.executor.neuron_executor
import
NeuronExecutor
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutor
if
distributed_executor_backend
==
"ray"
:
executor_class
=
TPUExecutor
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutor
executor_class
=
RayTPUExecutor
else
:
assert
distributed_executor_backend
is
None
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
executor_class
=
CPUExecutor
...
@@ -485,29 +489,21 @@ class LLMEngine:
...
@@ -485,29 +489,21 @@ class LLMEngine:
return
self
.
tokenizer
return
self
.
tokenizer
def
get_tokenizer
(
def
get_tokenizer
(
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
# def get_tokenizer_for_seq(self,
# def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# sequence: Sequence) -> "PreTrainedTokenizer":
# return self.get_tokenizer_group().get_lora_tokenizer(
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
# sequence.lora_request)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
def
_init_tokenizer
(
self
)
->
BaseTokenizerGroup
:
init_kwargs
=
dict
(
return
init_tokenizer_from_configs
(
tokenizer_id
=
self
.
model_config
.
tokenizer
,
model_config
=
self
.
model_config
,
enable_lora
=
bool
(
self
.
lora_config
),
scheduler_config
=
self
.
scheduler_config
,
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
parallel_config
=
self
.
parallel_config
,
max_input_length
=
None
,
enable_lora
=
bool
(
self
.
lora_config
))
tokenizer_mode
=
self
.
model_config
.
tokenizer_mode
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
return
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
@@ -769,10 +765,22 @@ class LLMEngine:
...
@@ -769,10 +765,22 @@ class LLMEngine:
"""Gets the model configuration."""
"""Gets the model configuration."""
return
self
.
model_config
return
self
.
model_config
def
get_parallel_config
(
self
)
->
ParallelConfig
:
"""Gets the parallel configuration."""
return
self
.
parallel_config
def
get_decoding_config
(
self
)
->
DecodingConfig
:
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Gets the decoding configuration."""
"""Gets the decoding configuration."""
return
self
.
decoding_config
return
self
.
decoding_config
def
get_scheduler_config
(
self
)
->
SchedulerConfig
:
"""Gets the scheduler configuration."""
return
self
.
scheduler_config
def
get_lora_config
(
self
)
->
LoRAConfig
:
"""Gets the LoRA configuration."""
return
self
.
lora_config
def
get_num_unfinished_requests
(
self
)
->
int
:
def
get_num_unfinished_requests
(
self
)
->
int
:
"""Gets the number of unfinished requests."""
"""Gets the number of unfinished requests."""
return
sum
(
scheduler
.
get_num_unfinished_seq_groups
()
return
sum
(
scheduler
.
get_num_unfinished_seq_groups
()
...
@@ -963,8 +971,9 @@ class LLMEngine:
...
@@ -963,8 +971,9 @@ class LLMEngine:
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
"""Forced log when no requests active."""
if
self
.
log_stats
:
if
self
.
log_stats
:
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
)
for
logger
in
self
.
stat_loggers
.
values
():
for
logger
in
self
.
stat_loggers
.
values
():
logger
.
log
(
s
elf
.
_get_stats
(
scheduler_outputs
,
model_output
)
)
logger
.
log
(
s
tats
)
def
_get_stats
(
def
_get_stats
(
self
,
self
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment