Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
"vllm/vscode:/vscode.git/clone" did not exist on "b7cbc254169128a4203d111f3b87edaa17839a32"
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
412 additions
and
307 deletions
+412
-307
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+10
-10
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+7
-2
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+5
-2
vllm/config.py
vllm/config.py
+52
-17
vllm/connections.py
vllm/connections.py
+2
-2
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+1
-1
vllm/core/policy.py
vllm/core/policy.py
+0
-45
vllm/core/scheduler.py
vllm/core/scheduler.py
+55
-81
vllm/distributed/device_communicators/cuda_wrapper.py
vllm/distributed/device_communicators/cuda_wrapper.py
+22
-22
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+3
-1
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+21
-39
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+31
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+81
-39
vllm/distributed/utils.py
vllm/distributed/utils.py
+27
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+11
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+34
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+34
-25
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
e661d594
...
...
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"ROCFlashAttention does not support blocksparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support attention logits soft "
"capping."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
e661d594
...
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support logits soft cap."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/utils.py
View file @
e661d594
...
...
@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
...
...
@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
...
...
@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
...
...
vllm/attention/backends/xformers.py
View file @
e661d594
...
...
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"XFormer does not support block-sparse attention."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"XFormers does not support attention logits soft capping."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/layer.py
View file @
e661d594
...
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
...
...
@@ -82,7 +83,7 @@ class Attention(nn.Module):
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
)
blocksparse_params
,
logits_soft_cap
)
def
forward
(
self
,
...
...
vllm/attention/ops/paged_attn.py
View file @
e661d594
...
...
@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
...
...
@@ -31,7 +34,7 @@ class PagedAttention:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
return
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
...
...
vllm/config.py
View file @
e661d594
...
...
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import
torch
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
...
...
@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS
=
[
"AquilaModel"
,
"AquilaForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"InternLMForCausalLM"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
...
...
@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
"MixtralForCausalLM"
,
"NemotronForCausalLM"
,
"Qwen2ForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"QWenLMHeadModel"
,
]
...
...
@@ -195,13 +201,17 @@ class ModelConfig:
def
_parse_quant_hf_config
(
self
):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
quant_cfg
is
None
:
# compress-tensors uses a "compression_config" key
# compress
ed
-tensors uses a "compression_config" key
quant_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
return
quant_cfg
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
@@ -240,9 +250,7 @@ class ModelConfig:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
if
(
self
.
quantization
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
)):
if
self
.
quantization
not
in
optimized_quantization_methods
:
logger
.
warning
(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
...
...
@@ -281,6 +289,10 @@ class ModelConfig:
raise
ValueError
(
"BitAndBytes quantization with TP or PP is not supported yet."
)
if
self
.
quantization
==
"bitsandbytes"
and
self
.
enforce_eager
is
False
:
raise
ValueError
(
"BitAndBytes with enforce_eager = False is not supported yet."
)
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled."""
...
...
@@ -590,9 +602,11 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
...
...
@@ -716,7 +730,7 @@ class ParallelConfig:
backend
)
self
.
_verify_args
()
self
.
rank
=
0
self
.
rank
:
int
=
0
@
property
def
use_ray
(
self
)
->
bool
:
...
...
@@ -842,6 +856,7 @@ class SchedulerConfig:
class
DeviceConfig
:
device
:
Optional
[
torch
.
device
]
def
__init__
(
self
,
device
:
str
=
"auto"
)
->
None
:
if
device
==
"auto"
:
...
...
@@ -892,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
disable_log_stats
:
bool
,
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
...
...
@@ -1053,7 +1069,7 @@ class SpeculativeConfig:
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
,
speculative_draft_tensor_parallel_size
))
speculative_draft_tensor_parallel_size
,
draft_hf_config
))
if
num_speculative_tokens
is
None
:
raise
ValueError
(
...
...
@@ -1080,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
)
@
staticmethod
...
...
@@ -1121,15 +1138,23 @@ class SpeculativeConfig:
@
staticmethod
def
create_draft_parallel_config
(
target_parallel_config
:
ParallelConfig
,
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
speculative_draft_tensor_parallel_size
:
Optional
[
int
],
draft_hf_config
:
PretrainedConfig
,
)
->
ParallelConfig
:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if
speculative_draft_tensor_parallel_size
is
None
:
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
if
draft_hf_config
.
model_type
==
"mlp_speculator"
:
speculative_draft_tensor_parallel_size
=
1
if
target_parallel_config
.
tensor_parallel_size
>
1
:
logger
.
warning
(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1"
)
else
:
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
elif
speculative_draft_tensor_parallel_size
!=
1
:
# TODO(wooyeon): allow tp values larger than 1
raise
ValueError
(
...
...
@@ -1166,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
):
"""Create a SpeculativeConfig object.
...
...
@@ -1198,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
...
...
@@ -1212,6 +1240,7 @@ class SpeculativeConfig:
self
.
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
_verify_args
()
...
...
@@ -1281,7 +1310,7 @@ class LoRAConfig:
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
def
__post_init__
(
self
):
#
Keep this in sync with csrc/punica/bgmv/bgmv_config.h
#
TODO: Increase the range of rank
possible_max_ranks
=
(
8
,
16
,
32
,
64
)
possible_lora_extra_vocab_size
=
(
0
,
256
,
512
)
if
self
.
max_lora_rank
not
in
possible_max_ranks
:
...
...
@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate."
)
pass
else
:
raise
ValueError
(
msg
=
(
f
"User-specified max_model_len (
{
max_model_len
}
) is greater "
"than the derived max_model_len "
f
"
(
{
max_len_key
}
=
{
derived_max_model_len
}
or model_max_length="
f
"than the derived max_model_len
(
{
max_len_key
}
=
"
f
"
{
derived_max_model_len
}
or model_max_length="
f
"
{
model_max_length
}
in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size."
)
"to incorrect model outputs or CUDA errors."
)
if
envs
.
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
logger
.
warning
(
"%s Make sure the value is correct and within the "
"model context size."
,
msg
)
else
:
raise
ValueError
(
f
"
{
msg
}
To allow overriding this maximum, set "
"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1"
)
return
int
(
max_model_len
)
...
...
vllm/connections.py
View file @
e661d594
from
pathlib
import
Path
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
MutableMapping
,
Optional
from
urllib.parse
import
urlparse
import
aiohttp
...
...
@@ -40,7 +40,7 @@ class HTTPConnection:
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'."
)
def
_headers
(
self
,
**
extras
:
str
)
->
Mapping
[
str
,
str
]:
def
_headers
(
self
,
**
extras
:
str
)
->
Mutable
Mapping
[
str
,
str
]:
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
def
get_response
(
...
...
vllm/core/block_manager_v1.py
View file @
e661d594
...
...
@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
if
self
.
enable_caching
:
for
seq
in
seq_group
.
seqs_dict
.
value
s
():
for
seq
in
seq_group
.
get_seq
s
():
self
.
compute_full_blocks_in_seq
(
seq
)
vllm/core/policy.py
deleted
100644 → 0
View file @
6b16ea2e
from
collections
import
deque
from
typing
import
Deque
from
vllm.sequence
import
SequenceGroup
class Policy:
    """Base class for policies that rank sequence groups by priority."""

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the priority of ``seq_group`` at time ``now``.

        The base implementation always raises NotImplementedError; subclasses
        provide the actual ranking.
        """
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        """Return a new deque of ``seq_groups`` in descending priority order.

        Priorities come from ``get_priority(now, seq_group)``. The input deque
        is not modified.
        """

        def priority_of(group):
            return self.get_priority(now, group)

        ranked = sorted(seq_groups, key=priority_of, reverse=True)
        return deque(ranked)
class FCFS(Policy):
    """Policy whose priority is the time elapsed since a group arrived."""

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return ``now`` minus the arrival time stored in the group's metrics."""
        arrival_time = seq_group.metrics.arrival_time
        elapsed = now - arrival_time
        return elapsed
class PolicyFactory:
    """Builds Policy instances from their registered names."""

    # Maps a policy name to the class that implements it.
    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under ``policy_name``.

        Args:
            policy_name: Key into the policy registry.
            **kwargs: Forwarded unchanged to the policy class constructor.

        Raises:
            KeyError: If ``policy_name`` is not in the registry.
        """
        policy_cls = cls._POLICY_REGISTRY[policy_name]
        return policy_cls(**kwargs)
vllm/core/scheduler.py
View file @
e661d594
...
...
@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -313,6 +312,7 @@ class Scheduler:
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self
.
_finished_requests_ids
:
List
[
str
]
=
list
()
# Time at previous scheduling step
self
.
prev_time
=
0.0
...
...
@@ -344,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self
.
waiting
.
append
(
seq_group
)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the running queue.

    Only for testing purposes.
    """
    running_queue = self.running
    running_queue.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the swapped queue.

    Only for testing purposes.
    """
    swapped_queue = self.swapped
    swapped_queue.append(seq_group)
def
abort_seq_group
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a sequence group with the given ID.
...
...
@@ -374,6 +384,7 @@ class Scheduler:
for
aborted_group
in
aborted_groups
:
# Remove the sequence group from the state queue.
state_queue
.
remove
(
aborted_group
)
# Remove the aborted request from the Mamba cache.
self
.
_finished_requests_ids
.
append
(
aborted_group
.
request_id
)
for
seq
in
aborted_group
.
get_seqs
():
if
seq
.
is_finished
():
...
...
@@ -396,32 +407,26 @@ class Scheduler:
def
_schedule_running
(
self
,
running_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerRunningOutputs
]
:
)
->
SchedulerRunningOutputs
:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
...
...
@@ -434,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
now
=
time
.
time
()
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
running_queue
=
self
.
running
while
running_queue
:
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
...
@@ -501,7 +505,7 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
return
running_queue
,
SchedulerRunningOutputs
(
return
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
preempted
=
preempted
,
...
...
@@ -513,12 +517,10 @@ class Scheduler:
def
_schedule_swapped
(
self
,
swapped_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
policy
:
Policy
,
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerSwappedInOutputs
]
:
)
->
SchedulerSwappedInOutputs
:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
...
...
@@ -526,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
...
...
@@ -547,10 +545,10 @@ class Scheduler:
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
swapped_queue
=
self
.
swapped
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
while
swapped_queue
:
seq_group
=
swapped_queue
[
0
]
...
...
@@ -615,7 +613,7 @@ class Scheduler:
swapped_queue
.
extendleft
(
leftover_swapped
)
return
swapped_queue
,
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
@@ -642,11 +640,10 @@ class Scheduler:
def
_schedule_prefills
(
self
,
waiting_queue
:
deque
,
budget
:
SchedulingBudget
,
curr_loras
:
Optional
[
Set
[
int
]],
enable_chunking
:
bool
=
False
,
)
->
Tuple
[
deque
,
SchedulerPrefillOutputs
]
:
)
->
SchedulerPrefillOutputs
:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...
...
@@ -658,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
...
...
@@ -670,14 +665,12 @@ class Scheduler:
all tokens.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
"""
ignored_seq_groups
:
List
[
SequenceGroup
]
=
[]
seq_groups
:
List
[
SequenceGroup
]
=
[]
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue
=
deque
([
s
for
s
in
waiting_queue
])
waiting_queue
=
self
.
waiting
leftover_waiting_sequences
:
Deque
[
SequenceGroup
]
=
deque
()
while
self
.
_passed_delay
(
time
.
time
())
and
waiting_queue
:
...
...
@@ -756,7 +749,7 @@ class Scheduler:
if
len
(
seq_groups
)
>
0
:
self
.
prev_prompt
=
True
return
waiting_queue
,
SchedulerPrefillOutputs
(
return
SchedulerPrefillOutputs
(
seq_groups
=
seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
))
...
...
@@ -783,53 +776,43 @@ class Scheduler:
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
if
seq_group
.
lora_int_id
>
0
)
if
self
.
lora_enabled
else
None
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
running_scheduled
=
SchedulerRunningOutputs
.
create_empty
()
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
# If any requests are swapped, prioritize swapped requests.
if
not
self
.
swapped
:
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
False
)
prefills
=
self
.
_schedule_prefills
(
budget
,
curr_loras
,
enable_chunking
=
False
)
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if
len
(
prefills
.
seq_groups
)
==
0
:
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
self
.
running
,
budget
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
False
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
curr_loras
,
enable_chunking
=
False
)
# If any sequence group is preempted, do not swap in any sequence
# group, because it means there's no slot for new running requests.
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
decode_seq_groups
])
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
preempted
=
(
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
))
...
...
@@ -875,42 +858,32 @@ class Scheduler:
)
curr_loras
:
Set
[
int
]
=
set
()
remaining_waiting
,
prefills
=
(
self
.
waiting
,
SchedulerPrefillOutputs
.
create_empty
())
remaining_running
,
running_scheduled
=
(
self
.
running
,
SchedulerRunningOutputs
.
create_empty
())
remaining_swapped
,
swapped_in
=
(
self
.
swapped
,
SchedulerSwappedInOutputs
.
create_empty
())
prefills
=
SchedulerPrefillOutputs
.
create_empty
()
swapped_in
=
SchedulerSwappedInOutputs
.
create_empty
()
# Decoding should be always scheduled first by fcfs.
fcfs_policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
remaining_running
,
running_scheduled
=
self
.
_schedule_running
(
self
.
running
,
budget
,
curr_loras
,
fcfs_policy
,
enable_chunking
=
True
)
running_scheduled
=
self
.
_schedule_running
(
budget
,
curr_loras
,
enable_chunking
=
True
)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if
len
(
running_scheduled
.
preempted
)
+
len
(
running_scheduled
.
swapped_out
)
==
0
:
remaining_swapped
,
swapped_in
=
self
.
_schedule_swapped
(
self
.
swapped
,
budget
,
curr_loras
,
fcfs_policy
)
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
# Schedule new prefills.
remaining_waiting
,
prefills
=
self
.
_schedule_prefills
(
self
.
waiting
,
budget
,
curr_loras
,
enable_chunking
=
True
)
prefills
=
self
.
_schedule_prefills
(
budget
,
curr_loras
,
enable_chunking
=
True
)
assert
(
budget
.
num_batched_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
)
assert
budget
.
num_curr_seqs
<=
self
.
scheduler_config
.
max_num_seqs
# Update waiting requests.
self
.
waiting
=
remaining_waiting
self
.
waiting
.
extendleft
(
running_scheduled
.
preempted
)
# Update new running requests.
self
.
running
=
remaining_running
self
.
running
.
extend
([
s
.
seq_group
for
s
in
prefills
.
seq_groups
])
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
running_scheduled
.
decode_seq_groups
])
...
...
@@ -921,7 +894,6 @@ class Scheduler:
self
.
running
.
extend
(
[
s
.
seq_group
for
s
in
swapped_in
.
prefill_seq_groups
])
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
...
...
@@ -1029,7 +1001,6 @@ class Scheduler:
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
...
...
@@ -1058,13 +1029,16 @@ class Scheduler:
self
.
block_manager
.
free
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
for
queue
in
[
self
.
running
,
self
.
swapped
,
self
.
waiting
]:
self
.
_finished_requests_ids
+=
[
seq_group
.
request_id
for
seq_group
in
queue
if
seq_group
.
is_finished
()
]
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
if
not
seq_group
.
is_finished
())
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
if
seq_group
.
is_finished
():
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
remaining
.
append
(
seq_group
)
self
.
running
=
remaining
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
...
...
vllm/distributed/device_communicators/cuda_wrapper.py
View file @
e661d594
...
...
@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions.
"""
import
ctypes
import
glob
import
os
import
sys
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
...
...
@@ -36,24 +33,25 @@ class Function:
argtypes
:
List
[
Any
]
def get_pytorch_default_cudart_library_path() -> str:
    """Locate the CUDA runtime library under an ``nvidia`` folder on sys.path.

    Every entry of ``sys.path`` that contains an ``nvidia`` directory is
    searched for ``nvidia/cuda_runtime/lib/libcudart.so.*[0-9]``. The search
    stops at the first entry that yields a match.

    Returns:
        The first path reported by ``glob`` for the first matching entry.

    Raises:
        ValueError: If no entry of ``sys.path`` contains a matching file.
    """
    # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
    lib_folder = "cuda_runtime"
    lib_name = "libcudart.so.*[0-9]"
    for path in sys.path:
        nvidia_path = os.path.join(path, "nvidia")
        if not os.path.exists(nvidia_path):
            continue
        pattern = os.path.join(nvidia_path, lib_folder, "lib", lib_name)
        candidate_lib_paths = glob.glob(pattern)
        if candidate_lib_paths:
            # First hit wins; later sys.path entries are not inspected.
            return candidate_lib_paths[0]
    raise ValueError(f"{lib_name} not found in the system path {sys.path}")
def find_loaded_library(lib_name: str) -> Optional[str]:
    """Return the path of a library already loaded into this process.

    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which
    includes the shared libraries loaded by the process. We can use this file
    to find the path of a loaded library.

    Args:
        lib_name: Substring to look for in the lines of `/proc/self/maps`,
            e.g. "libcudart.so".

    Returns:
        The text from the first "/" to the end of the first line that contains
        both `lib_name` and a "/", with surrounding whitespace stripped.
        None if there is no such line.
    """
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name not in line:
                continue
            start = line.find("/")
            if start == -1:
                # The match is in an entry that has no file path (for
                # example "[heap]"); keep scanning instead of raising
                # ValueError as a bare `line.index("/")` would.
                continue
            return line[start:].strip()
    # the library is not loaded in the current process
    return None
class
CudaRTLibrary
:
...
...
@@ -100,7 +98,9 @@ class CudaRTLibrary:
def
__init__
(
self
,
so_file
:
Optional
[
str
]
=
None
):
if
so_file
is
None
:
so_file
=
get_pytorch_default_cudart_library_path
()
so_file
=
find_loaded_library
(
"libcudart.so"
)
assert
so_file
is
not
None
,
\
"libcudart.so is not loaded in the current process"
if
so_file
not
in
CudaRTLibrary
.
path_to_library_cache
:
lib
=
ctypes
.
CDLL
(
so_file
)
CudaRTLibrary
.
path_to_library_cache
[
so_file
]
=
lib
...
...
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
e661d594
...
...
@@ -145,6 +145,7 @@ def can_actually_p2p(
p_tgt
.
start
()
p_src
.
join
()
p_tgt
.
join
()
assert
p_src
.
exitcode
==
0
and
p_tgt
.
exitcode
==
0
result
:
List
[
bool
]
=
[]
for
src
,
tgt
in
zip
(
batch_src
,
batch_tgt
):
a
=
result_queue
.
get
()
...
...
@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# wrap raised exception to provide more information
raise
RuntimeError
(
f
"Error happened when batch testing "
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
"
)
from
e
f
"peer-to-peer access from
{
batch_src
}
to
{
batch_tgt
}
:
\n
"
f
"
{
returned
.
stderr
.
decode
()
}
"
)
from
e
result
=
pickle
.
loads
(
returned
.
stdout
)
for
_i
,
_j
,
r
in
zip
(
batch_src
,
batch_tgt
,
result
):
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
r
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
e661d594
...
...
@@ -9,7 +9,7 @@ from unittest.mock import patch
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
zmq
import
P
UB
,
REP
,
REQ
,
SUB
,
SUBSCRIB
E
,
Context
# type: ignore
from
zmq
import
S
UB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOS
E
,
Context
# type: ignore
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
...
...
@@ -153,9 +153,7 @@ class Handle:
buffer
:
Optional
[
ShmRingBuffer
]
=
None
local_subscribe_port
:
Optional
[
int
]
=
None
local_sync_port
:
Optional
[
int
]
=
None
remote_subscribe_port
:
Optional
[
int
]
=
None
remote_sync_port
:
Optional
[
int
]
=
None
class
MessageQueue
:
...
...
@@ -189,38 +187,36 @@ class MessageQueue:
self
.
buffer
=
ShmRingBuffer
(
n_local_reader
,
max_chunk_bytes
,
max_chunks
)
self
.
local_socket
=
context
.
socket
(
PUB
)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self
.
local_socket
=
context
.
socket
(
XPUB
)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self
.
local_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
local_subscribe_port
=
get_open_port
()
self
.
local_socket
.
bind
(
f
"tcp://*:
{
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REP
)
local_sync_port
=
get_open_port
()
self
.
local_sync_socket
.
bind
(
f
"tcp://*:
{
local_sync_port
}
"
)
self
.
current_idx
=
0
else
:
self
.
buffer
=
None
# type: ignore
local_subscribe_port
=
None
local_sync_port
=
None
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
current_idx
=
-
1
if
n_remote_reader
>
0
:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self
.
remote_socket
=
context
.
socket
(
PUB
)
self
.
remote_socket
=
context
.
socket
(
XPUB
)
self
.
remote_socket
.
setsockopt
(
XPUB_VERBOSE
,
True
)
remote_subscribe_port
=
get_open_port
()
self
.
remote_socket
.
bind
(
f
"tcp://*:
{
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REP
)
remote_sync_port
=
get_open_port
()
self
.
remote_sync_socket
.
bind
(
f
"tcp://*:
{
remote_sync_port
}
"
)
else
:
remote_subscribe_port
=
None
remote_sync_port
=
None
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
self
.
_is_writer
=
True
self
.
_is_local_reader
=
False
...
...
@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks
=
local_reader_ranks
,
buffer
=
self
.
buffer
,
local_subscribe_port
=
local_subscribe_port
,
local_sync_port
=
local_sync_port
,
remote_subscribe_port
=
remote_subscribe_port
,
remote_sync_port
=
remote_sync_port
,
)
logger
.
info
(
"vLLM message queue communication handle: %s"
,
self
.
handle
)
...
...
@@ -264,12 +258,7 @@ class MessageQueue:
self
.
local_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_subscribe_port
}
"
)
self
.
local_sync_socket
=
context
.
socket
(
REQ
)
self
.
local_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
local_sync_port
}
"
)
self
.
remote_socket
=
None
self
.
remote_sync_socket
=
None
else
:
self
.
buffer
=
None
# type: ignore
self
.
current_idx
=
-
1
...
...
@@ -278,17 +267,12 @@ class MessageQueue:
self
.
_is_remote_reader
=
True
self
.
local_socket
=
None
self
.
local_sync_socket
=
None
self
.
remote_socket
=
context
.
socket
(
SUB
)
self
.
remote_socket
.
setsockopt_string
(
SUBSCRIBE
,
""
)
self
.
remote_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_subscribe_port
}
"
)
self
.
remote_sync_socket
=
context
.
socket
(
REQ
)
self
.
remote_sync_socket
.
connect
(
f
"tcp://
{
handle
.
connect_ip
}
:
{
handle
.
remote_sync_port
}
"
)
return
self
def
wait_until_ready
(
self
):
...
...
@@ -300,29 +284,27 @@ class MessageQueue:
# local readers
for
i
in
range
(
self
.
n_local_reader
):
recv
=
self
.
local_sync_socket
.
recv
()
assert
recv
==
b
"READY"
self
.
local_sync_socket
.
send
(
b
"READY"
)
# wait for subscription messages from all local readers
self
.
local_socket
.
recv
()
if
self
.
n_local_reader
>
0
:
# send a message to all local readers
# to make sure the publish channel is working
self
.
local_socket
.
send
(
b
"READY"
)
# remote readers
for
i
in
range
(
self
.
n_remote_reader
):
recv
=
self
.
remote_sync_socket
.
recv
()
assert
recv
==
b
"READY"
self
.
remote_sync_socket
.
send
(
b
"READY"
)
# wait for subscription messages from all remote readers
self
.
remote_socket
.
recv
()
if
self
.
n_remote_reader
>
0
:
# send a message to all remote readers
# to make sure the publish channel is working
self
.
remote_socket
.
send
(
b
"READY"
)
elif
self
.
_is_local_reader
:
self
.
local_sync_socket
.
send
(
b
"READY"
)
recv
=
self
.
local_sync_socket
.
recv
()
assert
recv
==
b
"READY"
# wait for the writer to send a message
recv
=
self
.
local_socket
.
recv
()
assert
recv
==
b
"READY"
elif
self
.
_is_remote_reader
:
self
.
remote_sync_socket
.
send
(
b
"READY"
)
recv
=
self
.
remote_sync_socket
.
recv
()
assert
recv
==
b
"READY"
# wait for the writer to send a message
recv
=
self
.
remote_socket
.
recv
()
assert
recv
==
b
"READY"
...
...
vllm/distributed/device_communicators/tpu_communicator.py
0 → 100644
View file @
e661d594
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
torch_xla._internal
import
pjrt
class TpuCommunicator:
    """Thin wrapper around torch_xla collectives for a process group.

    The ``disabled`` attribute is ``True`` when the current platform is not
    a TPU; in that case no torch_xla initialization is performed.
    """

    def __init__(self, group: ProcessGroup):
        """Initialize the torch_xla runtime for ``group`` when on a TPU.

        Args:
            group: Process group whose rank and world size are passed to
                ``pjrt.initialize_multiprocess``.
        """
        self.disabled = not current_platform.is_tpu()
        if self.disabled:
            return

        rank = dist.get_rank(group)
        num_ranks = dist.get_world_size(group)
        pjrt.initialize_multiprocess(rank, num_ranks)
        xr._init_world_size_ordinal()

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of ``xm.all_reduce`` with ``xm.REDUCE_SUM``."""
        reduced = xm.all_reduce(xm.REDUCE_SUM, x)
        return reduced

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Return the result of ``xm.all_gather`` on ``x``.

        Only ``dim == -1`` is accepted; any other value fails the assertion
        before the collective is issued.
        """
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        gathered = xm.all_gather(x, dim=dim)
        return gathered
vllm/distributed/parallel_state.py
View file @
e661d594
...
...
@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]
,
prefix
:
str
=
""
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list
:
List
[
Tuple
[
str
,
Any
]]
=
[]
tensor_list
=
[]
tensor_list
:
List
[
torch
.
Tensor
]
=
[]
for
key
,
value
in
tensor_dict
.
items
():
assert
"%"
not
in
key
,
(
"Avoid having '%' in key "
"as it is used as a separator for nested entries."
)
if
isinstance
(
value
,
torch
.
Tensor
):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
...
...
@@ -68,31 +62,13 @@ def _split_tensor_dict(
# receiving side will set the device index.
device
=
value
.
device
.
type
metadata_list
.
append
(
(
prefix
+
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
(
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
elif
isinstance
(
value
,
dict
):
if
len
(
value
)
==
0
:
metadata_list
.
append
((
prefix
+
key
,
value
))
inner_metadata_list
,
inner_tensor_list
=
_split_tensor_dict
(
value
,
prefix
+
key
+
"%"
)
metadata_list
.
extend
(
inner_metadata_list
)
tensor_list
.
extend
(
inner_tensor_list
)
else
:
metadata_list
.
append
((
prefix
+
key
,
value
))
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
def
_update_nested_dict
(
nested_dict
,
flattened_key
,
value
):
key_splits
=
flattened_key
.
split
(
"%"
)
cur_dict
=
nested_dict
for
k
in
key_splits
[:
-
1
]:
if
k
not
in
cur_dict
:
cur_dict
[
k
]
=
{}
cur_dict
=
cur_dict
[
k
]
cur_dict
[
key_splits
[
-
1
]]
=
value
class
GroupCoordinator
:
"""
PyTorch ProcessGroup wrapper for a group of processes.
...
...
@@ -133,6 +109,7 @@ class GroupCoordinator:
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_pynccl
:
bool
,
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
):
...
...
@@ -164,6 +141,7 @@ class GroupCoordinator:
self
.
use_pynccl
=
use_pynccl
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_tpu_communicator
=
use_tpu_communicator
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
...
@@ -190,6 +168,12 @@ class GroupCoordinator:
else
:
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
from
vllm.distributed.device_communicators.shm_broadcast
import
(
MessageQueue
)
self
.
mq_broadcaster
:
Optional
[
MessageQueue
]
=
None
...
...
@@ -243,6 +227,13 @@ class GroupCoordinator:
ca_comm
=
self
.
ca_comm
maybe_ca_context
=
nullcontext
(
)
if
ca_comm
is
None
else
ca_comm
.
capture
()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream
=
torch
.
cuda
.
current_stream
()
if
curr_stream
!=
stream
:
stream
.
wait_stream
(
curr_stream
)
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
...
...
@@ -282,6 +273,12 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_reduce
(
input_
)
if
ca_comm
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
if
out
is
not
None
:
...
...
@@ -289,6 +286,9 @@ class GroupCoordinator:
pynccl_comm
=
self
.
pynccl_comm
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
elif
input_
.
is_cpu
:
import
intel_extension_for_pytorch
as
ipex
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
...
...
@@ -300,6 +300,12 @@ class GroupCoordinator:
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_gather
(
input_
,
dim
)
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
...
...
@@ -536,7 +542,7 @@ class GroupCoordinator:
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
...
...
@@ -553,9 +559,9 @@ class GroupCoordinator:
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
else
:
_update_nested_dict
(
tensor_dict
,
key
,
value
)
tensor_dict
[
key
]
=
value
for
async_handle
in
async_handles
:
async_handle
.
wait
()
return
tensor_dict
...
...
@@ -563,7 +569,8 @@ class GroupCoordinator:
def
send_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]],
dst
:
Optional
[
int
]
=
None
dst
:
Optional
[
int
]
=
None
,
all_gather_group
:
Optional
[
"GroupCoordinator"
]
=
None
,
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the destination rank.
...
...
@@ -572,6 +579,11 @@ class GroupCoordinator:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
tensor_dict
all_gather_size
=
(
1
if
all_gather_group
is
None
else
all_gather_group
.
world_size
)
all_gather_rank
=
(
0
if
all_gather_group
is
None
else
all_gather_group
.
rank_in_group
)
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
...
...
@@ -592,6 +604,12 @@ class GroupCoordinator:
if
tensor
.
numel
()
==
0
:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if
(
all_gather_group
is
not
None
and
tensor
.
numel
()
%
all_gather_size
==
0
):
tensor
=
tensor
.
reshape
(
all_gather_size
,
-
1
)[
all_gather_rank
]
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
torch
.
distributed
.
send
(
tensor
,
...
...
@@ -606,7 +624,8 @@ class GroupCoordinator:
def
recv_tensor_dict
(
self
,
src
:
Optional
[
int
]
=
None
src
:
Optional
[
int
]
=
None
,
all_gather_group
:
Optional
[
"GroupCoordinator"
]
=
None
,
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
...
...
@@ -615,6 +634,11 @@ class GroupCoordinator:
if
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
:
return
None
all_gather_size
=
(
1
if
all_gather_group
is
None
else
all_gather_group
.
world_size
)
all_gather_rank
=
(
0
if
all_gather_group
is
None
else
all_gather_group
.
rank_in_group
)
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
...
...
@@ -631,8 +655,18 @@ class GroupCoordinator:
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
tensor_dict
[
key
]
=
tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather
=
(
all_gather_group
is
not
None
and
tensor
.
numel
()
%
all_gather_size
==
0
)
if
use_all_gather
:
orig_shape
=
tensor
.
shape
tensor
=
tensor
.
reshape
(
all_gather_size
,
-
1
)[
all_gather_rank
]
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
torch
.
distributed
.
recv
(
tensor
,
...
...
@@ -643,9 +677,15 @@ class GroupCoordinator:
torch
.
distributed
.
recv
(
tensor
,
src
=
self
.
ranks
[
src
],
group
=
group
)
_update_nested_dict
(
tensor_dict
,
key
,
tensor
)
if
use_all_gather
:
# do the allgather
tensor
=
all_gather_group
.
all_gather
(
# type: ignore
tensor
,
dim
=
0
)
tensor
=
tensor
.
reshape
(
orig_shape
)
tensor_dict
[
key
]
=
tensor
else
:
_update_nested_dict
(
tensor_dict
,
key
,
value
)
tensor_dict
[
key
]
=
value
return
tensor_dict
def
barrier
(
self
):
...
...
@@ -673,8 +713,8 @@ class GroupCoordinator:
size
:
torch
.
Size
,
dtype
:
torch
.
dtype
,
src
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the
destination
rank."""
"""Receives a tensor from the s
ou
rc
e
rank."""
"""NOTE: `src` is the local rank of the
source
rank."""
if
src
is
None
:
src
=
(
self
.
rank_in_group
-
1
)
%
self
.
world_size
...
...
@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int,
torch_distributed_backend
=
backend
,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
)
...
...
@@ -735,6 +776,7 @@ def init_model_parallel_group(
torch_distributed_backend
=
backend
,
use_pynccl
=
True
,
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
)
...
...
vllm/distributed/utils.py
View file @
e661d594
...
...
@@ -6,6 +6,11 @@ from typing import Sequence, Tuple
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
ensure_divisibility
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator."""
...
...
@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
"""
layers_per_partition
=
num_hidden_layers
//
pp_size
start_layer
=
pp_rank
*
layers_per_partition
end_layer
=
start_layer
+
layers_per_partition
partition_list_str
=
envs
.
VLLM_PP_LAYER_PARTITION
if
partition_list_str
is
not
None
:
try
:
partitions
=
[
int
(
layer
)
for
layer
in
partition_list_str
.
split
(
","
)
]
except
ValueError
as
err
:
raise
ValueError
(
"Invalid partition string: {}"
.
format
(
partition_list_str
))
from
err
if
len
(
partitions
)
!=
pp_size
:
raise
ValueError
(
f
"
{
len
(
partitions
)
=
}
does not match
{
pp_size
=
}
."
)
if
sum
(
partitions
)
!=
num_hidden_layers
:
raise
ValueError
(
f
"
{
sum
(
partitions
)
=
}
does not match
{
num_hidden_layers
=
}
."
)
start_layer
=
sum
(
partitions
[:
pp_rank
])
end_layer
=
start_layer
+
partitions
[
pp_rank
]
else
:
layers_per_partition
=
num_hidden_layers
//
pp_size
start_layer
=
pp_rank
*
layers_per_partition
end_layer
=
start_layer
+
layers_per_partition
if
pp_rank
==
pp_size
-
1
:
end_layer
=
num_hidden_layers
if
pp_rank
==
pp_size
-
1
:
end_layer
=
num_hidden_layers
return
(
start_layer
,
end_layer
)
vllm/engine/arg_utils.py
View file @
e661d594
...
...
@@ -632,9 +632,9 @@ class EngineArgs:
'--preemption-mode'
,
type
=
str
,
default
=
None
,
help
=
'If
\'
recompute
\'
, the engine performs preemption by
block
'
'
swapp
ing; If
\'
swap
\'
, the engine performs preemption by
block
'
'swapping.'
)
help
=
'If
\'
recompute
\'
, the engine performs preemption by '
'
recomput
ing; If
\'
swap
\'
, the engine performs preemption by '
'
block
swapping.'
)
parser
.
add_argument
(
"--served-model-name"
,
...
...
@@ -676,8 +676,8 @@ class EngineArgs:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
...
...
@@ -754,10 +754,14 @@ class EngineArgs:
use_sliding_window
=
(
model_config
.
get_sliding_window
()
is
not
None
)
use_spec_decode
=
self
.
speculative_model
is
not
None
has_seqlen_agnostic_layers
=
(
model_config
.
contains_seqlen_agnostic_layers
(
parallel_config
))
if
(
is_gpu
and
not
use_sliding_window
and
not
use_spec_decode
and
not
self
.
enable_lora
and
not
self
.
enable_prompt_adapter
and
not
self
.
enable_prefix_caching
):
and
not
self
.
enable_prefix_caching
and
not
has_seqlen_agnostic_layers
):
self
.
enable_chunked_prefill
=
True
logger
.
warning
(
"Chunked prefill is enabled by default for models with "
...
...
@@ -788,6 +792,7 @@ class EngineArgs:
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
disable_log_stats
=
self
.
disable_log_stats
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
\
...
...
vllm/engine/async_llm_engine.py
View file @
e661d594
...
...
@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
...
...
@@ -407,11 +408,15 @@ class AsyncLLMEngine:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
executor_class
=
TPUExecutorAsync
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
else
:
assert
distributed_executor_backend
is
None
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
executor_class
=
TPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
assert
distributed_executor_backend
is
None
,
(
"Distributed execution is not supported with the CPU backend."
)
from
vllm.executor.cpu_executor
import
CPUExecutorAsync
executor_class
=
CPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"openvino"
:
...
...
@@ -924,6 +929,14 @@ class AsyncLLMEngine:
else
:
return
self
.
engine
.
get_model_config
()
async def get_parallel_config(self) -> ParallelConfig:
    """Get the parallel configuration of the vLLM engine."""
    if not self.engine_use_ray:
        return self.engine.get_parallel_config()
    # When `engine_use_ray` is set, the getter is invoked through
    # `.remote()` and its result is awaited.
    return await self.engine.get_parallel_config.remote(  # type: ignore
    )
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
...
...
@@ -932,6 +945,22 @@ class AsyncLLMEngine:
else
:
return
self
.
engine
.
get_decoding_config
()
async def get_scheduler_config(self) -> SchedulerConfig:
    """Get the scheduling configuration of the vLLM engine."""
    if not self.engine_use_ray:
        return self.engine.get_scheduler_config()
    # When `engine_use_ray` is set, the getter is invoked through
    # `.remote()` and its result is awaited.
    return await self.engine.get_scheduler_config.remote(  # type: ignore
    )
async def get_lora_config(self) -> LoRAConfig:
    """Get the lora configuration of the vLLM engine."""
    if not self.engine_use_ray:
        return self.engine.get_lora_config()
    # When `engine_use_ray` is set, the getter is invoked through
    # `.remote()` and its result is awaited.
    return await self.engine.get_lora_config.remote(  # type: ignore
    )
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
...
...
vllm/engine/llm_engine.py
View file @
e661d594
...
...
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
ge
t_tokenizer_
g
ro
up
)
from
vllm.transformers_utils.tokenizer_group
import
(
AnyTokenizer
,
BaseTokenizerGroup
,
ini
t_tokenizer_
f
ro
m_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
...
...
@@ -408,8 +406,14 @@ class LLMEngine:
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutor
executor_class
=
RayTPUExecutor
else
:
assert
distributed_executor_backend
is
None
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
...
...
@@ -485,29 +489,21 @@ class LLMEngine:
return
self
.
tokenizer
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
"PreTrained
Tokenizer
"
:
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
Any
Tokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
# def get_tokenizer_for_seq(self,
# sequence: Sequence) -> "PreTrainedTokenizer":
# def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
init_kwargs
=
dict
(
tokenizer_id
=
self
.
model_config
.
tokenizer
,
enable_lora
=
bool
(
self
.
lora_config
),
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
max_input_length
=
None
,
tokenizer_mode
=
self
.
model_config
.
tokenizer_mode
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
return
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
def
_init_tokenizer
(
self
)
->
BaseTokenizerGroup
:
return
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
self
.
scheduler_config
,
parallel_config
=
self
.
parallel_config
,
enable_lora
=
bool
(
self
.
lora_config
))
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
@@ -769,10 +765,22 @@ class LLMEngine:
"""Gets the model configuration."""
return
self
.
model_config
def get_parallel_config(self) -> ParallelConfig:
    """Gets the parallel configuration.

    Returns:
        The object stored in this engine's ``parallel_config`` attribute.
    """
    return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
    """Gets the decoding configuration.

    Returns:
        The object stored in this engine's ``decoding_config`` attribute.
    """
    return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
    """Gets the scheduler configuration.

    Returns:
        The object stored in this engine's ``scheduler_config`` attribute.
    """
    return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
    """Gets the LoRA configuration.

    Returns:
        The object stored in this engine's ``lora_config`` attribute.
    """
    return self.lora_config
def
get_num_unfinished_requests
(
self
)
->
int
:
"""Gets the number of unfinished requests."""
return
sum
(
scheduler
.
get_num_unfinished_seq_groups
()
...
...
@@ -963,8 +971,9 @@ class LLMEngine:
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
if
self
.
log_stats
:
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
)
for
logger
in
self
.
stat_loggers
.
values
():
logger
.
log
(
s
elf
.
_get_stats
(
scheduler_outputs
,
model_output
)
)
logger
.
log
(
s
tats
)
def
_get_stats
(
self
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment