Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1342 additions
and
444 deletions
+1342
-444
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+56
-23
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+77
-37
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+1
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+154
-116
vllm/v1/core/specialized_manager.py
vllm/v1/core/specialized_manager.py
+161
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+8
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+22
-12
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+219
-43
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+345
-75
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+24
-11
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+1
-1
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+49
-63
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+21
-17
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+46
-11
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+43
-4
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-0
vllm/v1/request.py
vllm/v1/request.py
+7
-10
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+86
-18
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+12
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+6
-0
No files found.
vllm/v1/core/kv_cache_manager.py
View file @
fcfc474d
...
...
@@ -5,10 +5,12 @@ from collections.abc import Iterable
from
typing
import
Optional
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
,
sha256
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
KVCacheBlock
,
hash_request_tokens
)
from
vllm.v1.core.specialized_manager
import
get_specialized_manager
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
,
RequestStatus
...
...
@@ -19,20 +21,24 @@ class KVCacheManager:
def
__init__
(
self
,
block_size
:
int
,
num_gpu_blocks
:
int
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
sliding_window
:
Optional
[
int
]
=
None
,
enable_caching
:
bool
=
True
,
caching_hash_algo
:
str
=
"builtin"
,
num_preallocate_tokens
:
int
=
64
,
log_stats
:
bool
=
False
,
)
->
None
:
self
.
block_size
=
block_size
self
.
num_gpu_blocks
=
num_gpu_blocks
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group"
)
kv_cache_spec
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
num_gpu_blocks
=
kv_cache_config
.
num_blocks
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
max_model_len
,
block_size
)
self
.
sliding_window
=
sliding_window
self
.
max_num_blocks_per_req
=
cdiv
(
max_model_len
,
self
.
block_size
)
self
.
enable_caching
=
enable_caching
self
.
caching_hash_fn
=
sha256
if
caching_hash_algo
==
"sha256"
else
hash
# FIXME: make prefix cache stats conditional on log_stats
self
.
log_stats
=
log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
...
...
@@ -46,9 +52,15 @@ class KVCacheManager:
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self
.
num_preallocate_tokens
=
num_preallocate_tokens
self
.
num_preallocate_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
self
.
num_preallocate_blocks
=
cdiv
(
num_preallocate_tokens
,
self
.
block_size
)
self
.
block_pool
=
BlockPool
(
self
.
num_gpu_blocks
,
enable_caching
)
self
.
block_pool
=
BlockPool
(
num_gpu_blocks
,
enable_caching
)
self
.
specialized_manager
=
get_specialized_manager
(
kv_cache_spec
=
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
...
...
@@ -109,22 +121,31 @@ class KVCacheManager:
# if the scheduler has tried to schedule the request before.
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
if
not
block_hashes
:
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
)
block_hashes
=
hash_request_tokens
(
self
.
caching_hash_fn
,
self
.
block_size
,
request
)
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
self
.
prefix_cache_stats
.
requests
+=
1
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
# Check for cache hits
computed_blocks
=
[]
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash
# is not in the cached_block_hash_to_id, the following
# block hashes are not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
break
last_block_hash
=
None
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
block_hashes
.
append
(
last_block_hash
)
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
...
...
@@ -173,13 +194,24 @@ class KVCacheManager:
new_computed_blocks
=
new_computed_blocks
or
[]
req_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks
=
self
.
specialized_manager
.
remove_skipped_blocks
(
req_blocks
,
request
.
num_computed_tokens
)
self
.
block_pool
.
free_blocks
(
removed_blocks
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
len
(
new_computed_blocks
)
*
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
,
self
.
block_size
)
req_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
len
(
new_computed_blocks
))
...
...
@@ -247,6 +279,7 @@ class KVCacheManager:
num_cached_blocks
=
num_cached_blocks
,
num_full_blocks
=
num_full_blocks_after_append
,
block_size
=
self
.
block_size
,
hash_fn
=
self
.
caching_hash_fn
,
)
self
.
num_cached_block
[
...
...
vllm/v1/core/kv_cache_utils.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities."""
import
os
from
collections
import
deque
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
typing
import
Any
,
NamedTuple
,
Optional
from
typing
import
Any
,
Callable
,
NamedTuple
,
Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.utils
import
sha256
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -18,9 +21,8 @@ logger = init_logger(__name__)
class
BlockHashType
(
NamedTuple
):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. But please note that
hash collisions can still theoretically occur, albeit with an extremely
low probability.
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
# Hash value of the block in an integer.
hash_value
:
int
...
...
@@ -30,6 +32,20 @@ class BlockHashType(NamedTuple):
extra_keys
:
Optional
[
Any
]
=
None
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done one per process.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH
=
int
.
from_bytes
(
os
.
urandom
(
32
),
byteorder
=
"big"
)
if
os
.
getenv
(
'PYTHONHASHSEED'
)
is
not
None
else
sha256
(
os
.
getenv
(
'PYTHONHASHSEED'
))
class
PrefixCachingMetrics
:
"""Metrics for prefix caching with a hit rate of the most recent N requests.
...
...
@@ -375,6 +391,7 @@ def generate_block_hash_extra_keys(
def
hash_block_tokens
(
hash_function
:
Callable
,
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Sequence
[
int
],
extra_keys
:
Optional
[
tuple
[
Any
,
...]]
=
None
)
->
BlockHashType
:
...
...
@@ -395,21 +412,16 @@ def hash_block_tokens(
The entire tuple is used as the hash key of the block.
"""
if
not
parent_block_hash
:
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash
=
hash
(
'None'
)
parent_block_hash
=
NONE_HASH
curr_block_token_ids_tuple
=
tuple
(
curr_block_token_ids
)
return
BlockHashType
(
hash
((
parent_block_hash
,
curr_block_token_ids_tuple
,
extra_keys
)),
hash_function
(
(
parent_block_hash
,
curr_block_token_ids_tuple
,
extra_keys
)),
curr_block_token_ids_tuple
,
extra_keys
)
def
hash_request_tokens
(
block_size
:
int
,
def
hash_request_tokens
(
hash_function
:
Any
,
block_size
:
int
,
request
:
Request
)
->
list
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
...
...
@@ -441,7 +453,7 @@ def hash_request_tokens(block_size: int,
req_extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
request
,
start
,
end
,
curr_mm_idx
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_hash
=
hash_block_tokens
(
hash_function
,
parent_block_hash_value
,
block_token_ids
,
req_extra_keys
)
ret
.
append
(
block_hash
)
parent_block_hash_value
=
block_hash
.
hash_value
...
...
@@ -472,14 +484,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
max_model_len
=
vllm_config
.
model_config
.
max_model_len
needed_memory
=
0
for
layer_spec
in
kv_cache_spec
.
values
():
needed_memory
+=
layer_spec
.
bytes_for_tokens
(
max_model_len
)
needed_memory
+=
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
if
needed_memory
>
available_memory
:
raise
ValueError
(
f
"To serve at least one request with the models's max seq len "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
1024
/
1024
/
1024
:.
2
f
}
GB KV "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
1024
/
1024
/
1024
:.
2
f
}
G
i
B KV "
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
1024
/
1024
/
1024
:.
2
f
}
GB). Try "
f
"memory (
{
available_memory
/
1024
/
1024
/
1024
:.
2
f
}
G
i
B). Try "
f
"increasing `gpu_memory_utilization` or decreasing "
f
"`max_model_len` when initializing the engine."
)
...
...
@@ -586,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return
kv_cache_config
def
unify_hybrid_kv_cache_specs
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]):
"""
Only models with one type of KV cache are supported yet. This function tries
to convert the KV cache specs to one type if the model is a hybrid model
with multiple type of KV cache. It will convert all SlidingWindowSpec to
FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
has_full_attention
=
any
(
isinstance
(
spec
,
FullAttentionSpec
)
for
spec
in
kv_cache_spec
.
values
())
has_sliding_window
=
any
(
isinstance
(
spec
,
SlidingWindowSpec
)
for
spec
in
kv_cache_spec
.
values
())
if
has_full_attention
and
has_sliding_window
:
for
layer_name
,
spec
in
kv_cache_spec
.
items
():
if
isinstance
(
spec
,
SlidingWindowSpec
):
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
spec
.
block_size
,
num_kv_heads
=
spec
.
num_kv_heads
,
head_size
=
spec
.
head_size
,
dtype
=
spec
.
dtype
,
use_mla
=
spec
.
use_mla
,
)
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
...
...
@@ -602,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
unify_hybrid_kv_cache_specs
(
kv_cache_spec
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
...
...
vllm/v1/core/sched/output.py
View file @
fcfc474d
...
...
@@ -10,8 +10,7 @@ if TYPE_CHECKING:
import
numpy.typing
as
npt
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.base
import
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.request
import
Request
...
...
vllm/v1/core/sched/scheduler.py
View file @
fcfc474d
...
...
@@ -7,9 +7,9 @@ from collections import deque
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
...
...
@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
logger
=
init_logger
(
__name__
)
...
...
@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
log_stats
:
bool
,
kv_cache_config
:
KVCacheConfig
,
structured_output_manager
:
StructuredOutputManager
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
speculativ
e_config
=
speculativ
e_config
self
.
kv_cach
e_config
=
kv_cach
e_config
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self
.
include_finished_set
=
include_finished_set
# Scheduling constraints.
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_scheduled_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_model_len
=
self
.
scheduler_config
.
max_model_len
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
block_size
=
self
.
cache_config
.
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
enable_
caching
=
self
.
cache_config
.
enable_
prefix_caching
,
enable_caching
=
cache_config
.
enable_prefix_caching
,
caching
_hash_algo
=
self
.
cache_config
.
prefix_caching
_hash_algo
,
log_stats
=
self
.
log_stats
)
self
.
block_size
=
self
.
cache_config
.
block_size
...
...
@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
mm_registry
=
mm_registry
,
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
...
...
@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
(
request
.
num_tokens_with_spec
-
request
.
num_computed_tokens
)
if
(
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
num_new_tokens
):
num_new_tokens
=
(
self
.
scheduler_config
.
long_prefill_token_threshold
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
# Schedule encoder inputs.
encoder_inputs_to_schedule
,
num_new_tokens
,
new_encoder_budget
=
(
self
.
_try_schedule_encoder_inputs
(
request
,
request
.
num_computed_tokens
,
num_new_tokens
,
encoder_budget
)
)
if
request
.
has_encoder_inputs
:
(
encoder_inputs_to_schedule
,
num_new_tokens
,
new_encoder_budget
)
=
self
.
_try_schedule_encoder_inputs
(
request
,
request
.
num_computed_tokens
,
num_new_tokens
,
encoder_budget
)
if
num_new_tokens
==
0
:
# The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
# NOTE(woosuk): By using `continue` instead of `break` here,
# we intentionally relax the strict FCFS scheduling policy
# to allow lower-priority requests to be scheduled when a
# higher-priority request is blocked by encoder constraints.
req_index
+=
1
continue
else
:
encoder_inputs_to_schedule
=
None
new_encoder_budget
=
encoder_budget
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
...
...
@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
request
ed_loras
:
set
[
int
]
=
set
()
schedul
ed_loras
:
set
[
int
]
=
set
()
if
self
.
lora_config
:
request
ed_loras
=
set
(
schedul
ed_loras
=
set
(
req
.
lora_request
.
lora_int_id
for
req
in
scheduled_running_reqs
if
req
.
lora_request
and
req
.
lora_request
.
lora_int_id
>
0
)
assert
len
(
request
ed_loras
)
<=
self
.
lora_config
.
max_loras
assert
len
(
schedul
ed_loras
)
<=
self
.
lora_config
.
max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
waiting_for_fsm
:
deque
[
Request
]
=
deque
()
skipped_waiting_requests
:
deque
[
Request
]
=
deque
()
# Next, schedule the WAITING requests.
if
not
preempted_reqs
:
...
...
@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
request
=
self
.
waiting
[
0
]
# Skip request if the structured output request is still waiting
# for FSM compilation.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_FSM
:
structured_output_req
=
request
.
structured_output_request
if
structured_output_req
and
structured_output_req
.
grammar
:
request
.
status
=
RequestStatus
.
WAITING
else
:
waiting_structured_output_req
=
self
.
waiting
.
popleft
()
waiting_for_fsm
.
appendleft
(
waiting_structured_output_req
)
self
.
waiting
.
popleft
()
skipped_waiting_requests
.
appendleft
(
request
)
continue
# Check that adding the request still respects the max_loras
# constraint.
if
self
.
lora_config
and
request
.
lora_request
:
req_lora_id
=
request
.
lora_request
.
lora_int_id
if
len
(
requested_loras
)
==
self
.
lora_config
.
max_loras
and
(
req_lora_id
not
in
requested_loras
):
# Cannot schedule.
# TODO (varun): This means all the other requests in
# the WAITING queue will be blocked by this request,
# even if,
# 1. these other requests do not use LoRA, or,
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
if
self
.
lora_config
and
request
.
lora_request
and
(
len
(
scheduled_loras
)
==
self
.
lora_config
.
max_loras
and
request
.
lora_request
.
lora_int_id
not
in
scheduled_loras
):
# Scheduling would exceed max_loras, skip.
self
.
waiting
.
popleft
()
skipped_waiting_requests
.
appendleft
(
request
)
continue
# Get already-cached tokens.
computed_blocks
,
num_computed_tokens
=
\
...
...
@@ -288,21 +300,15 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
if
num_new_tokens
==
0
:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens
-=
self
.
block_size
num_new_tokens
=
self
.
block_size
computed_blocks
.
pop
()
if
(
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
num_new_tokens
):
num_new_tokens
=
(
self
.
scheduler_config
.
long_prefill_token_threshold
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
# Schedule encoder inputs.
if
request
.
has_encoder_inputs
:
(
encoder_inputs_to_schedule
,
num_new_tokens
,
new_encoder_budget
)
=
self
.
_try_schedule_encoder_inputs
(
request
,
num_computed_tokens
,
num_new_tokens
,
...
...
@@ -310,6 +316,9 @@ class Scheduler(SchedulerInterface):
if
num_new_tokens
==
0
:
# The request cannot be scheduled.
break
else
:
encoder_inputs_to_schedule
=
None
new_encoder_budget
=
encoder_budget
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
computed_blocks
)
...
...
@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
f
"Invalid request status:
{
request
.
status
}
"
)
if
self
.
lora_config
and
request
.
lora_request
:
request
ed_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
schedul
ed_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
req_to_new_block_ids
[
request
.
request_id
]
=
[
b
.
block_id
for
b
in
computed_blocks
+
new_blocks
]
...
...
@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if
waiting_for_fsm
:
self
.
waiting
.
extendleft
(
waiting_for_fsm
)
if
skipped_waiting_requests
:
self
.
waiting
.
extendleft
(
skipped_waiting_requests
)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens
=
sum
(
num_scheduled_tokens
.
values
())
...
...
@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask
=
grammar_bitmask
,
)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for
req_id
,
num_scheduled_token
in
num_scheduled_tokens
.
items
():
self
.
requests
[
req_id
].
num_computed_tokens
+=
num_scheduled_token
self
.
finished_req_ids
=
set
()
return
scheduler_output
...
...
@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input.
"""
if
not
request
.
has_encoder_inputs
():
return
[],
num_new_tokens
,
encoder_budget
encoder_inputs_to_schedule
:
list
[
int
]
=
[]
mm_positions
=
request
.
mm_positions
assert
mm_positions
is
not
None
...
...
@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
new_running
:
list
[
Request
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
...
...
@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
req_id
not
in
scheduler_output
.
scheduled_spec_decode_tokens
:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request
.
num_computed_tokens
+=
num_tokens_scheduled
assert
request
.
num_computed_tokens
<=
request
.
num_tokens
else
:
# num_computed_tokens_step represents the number of tokens
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
if
scheduled_spec_token_ids
:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
])
num_computed_tokens_step
=
num_scheduled_tokens
[
req_id
]
-
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
len
(
generated_token_ids
))
request
.
num_computed_tokens
+=
num_computed_tokens_step
request
.
num_computed_tokens
-=
num_tokens_rejected
spec_decoding_stats
=
self
.
make_spec_decoding_stats
(
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_accepted_tokens
=
len
(
generated_token_ids
)
-
1
)
cached_encoder_input_ids
=
(
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
start_pos
=
request
.
mm_positions
[
input_id
][
"offset"
]
num_tokens
=
request
.
mm_positions
[
input_id
][
"length"
]
mm_positions
=
request
.
mm_positions
[
input_id
]
start_pos
=
mm_positions
[
"offset"
]
num_tokens
=
mm_positions
[
"length"
]
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# in the decoder's KV cache.
...
...
@@ -595,23 +610,24 @@ class Scheduler(SchedulerInterface):
stopped
=
False
new_logprobs
=
None
new_token_ids
:
list
[
int
]
=
[]
new_token_ids
=
generated_token_ids
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
for
output_token_id
in
generated_token_ids
:
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
request
.
append_output_token_ids
(
output_token_id
)
new_token_ids
.
append
(
output_token_id
)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
check_stop
(
request
,
self
.
max_model_len
)
if
stopped
:
self
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
# Extract sample logprobs if needed.
if
request
.
sampling_params
.
logprobs
is
not
None
:
assert
logprobs
is
not
None
if
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
...
...
@@ -621,9 +637,7 @@ class Scheduler(SchedulerInterface):
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
.
request_id
,
new_token_ids
,
)
req_id
,
new_token_ids
)
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
...
...
@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
self
.
scheduled_req_ids
.
remove
(
req
uest
.
request
_id
)
self
.
scheduled_req_ids
.
remove
(
req_id
)
if
not
stopped
:
new_running
.
append
(
request
)
self
.
running
=
new_running
return
EngineCoreOutputs
(
engine_core_outputs
=
EngineCoreOutputs
(
outputs
=
outputs
,
scheduler_stats
=
self
.
make_stats
(),
scheduler_stats
=
self
.
make_stats
(
spec_decoding_stats
),
)
if
self
.
include_finished_set
:
#TODO currently sending duplicates here, improve this
engine_core_outputs
.
finished_requests
=
(
scheduler_output
.
finished_req_ids
|
self
.
finished_req_ids
)
return
engine_core_outputs
def
add_request
(
self
,
request
:
Request
)
->
None
:
self
.
waiting
.
append
(
request
)
...
...
@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
def
make_stats
(
self
)
->
Optional
[
SchedulerStats
]:
def
make_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
)
->
Optional
[
SchedulerStats
]:
if
not
self
.
log_stats
:
return
None
return
SchedulerStats
(
...
...
@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs
=
len
(
self
.
waiting
),
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
(),
spec_decoding_stats
=
spec_decoding_stats
,
)
def
make_spec_decoding_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
],
num_draft_tokens
:
int
,
num_accepted_tokens
:
int
,
)
->
Optional
[
SpecDecodingStats
]:
if
not
self
.
log_stats
:
return
None
if
spec_decoding_stats
is
None
:
spec_decoding_stats
=
SpecDecodingStats
()
spec_decoding_stats
.
observe
(
num_draft_tokens
=
num_draft_tokens
,
num_accepted_tokens
=
num_accepted_tokens
)
return
spec_decoding_stats
vllm/v1/core/specialized_manager.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
vllm.utils
import
cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
SlidingWindowSpec
)
class
SpecializedManager
(
ABC
):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
"""
def
__init__
(
self
,
kv_cache_spec
:
KVCacheSpec
,
block_pool
:
BlockPool
,
)
->
None
:
"""
Initializes the SpecializedManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
"""
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
Args:
block_hashes: The block hashes of the request.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
"""
raise
NotImplementedError
@
abstractmethod
def
remove_skipped_blocks
(
self
,
blocks
:
list
[
KVCacheBlock
],
num_computed_tokens
:
int
)
->
list
[
KVCacheBlock
]:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Args:
blocks: The list of blocks to be updated.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise
NotImplementedError
class
FullAttentionManager
(
SpecializedManager
):
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
computed_blocks
:
list
[
KVCacheBlock
]
=
[]
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
else
:
break
return
computed_blocks
def
remove_skipped_blocks
(
self
,
blocks
:
list
[
KVCacheBlock
],
num_computed_tokens
:
int
)
->
list
[
KVCacheBlock
]:
# No need to remove blocks for full attention.
return
[]
class
SlidingWindowManager
(
SpecializedManager
):
def
__init__
(
self
,
kv_cache_spec
:
SlidingWindowSpec
,
block_pool
:
BlockPool
):
super
().
__init__
(
kv_cache_spec
,
block_pool
)
self
.
sliding_window
=
kv_cache_spec
.
sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self
.
sliding_window_contiguous_blocks
=
cdiv
(
(
kv_cache_spec
.
sliding_window
-
1
),
self
.
block_size
)
self
.
_null_block
=
block_pool
.
null_block
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks
=
[
self
.
_null_block
]
*
len
(
block_hashes
)
num_contiguous_blocks
=
0
# Search from right to left and early stop when a match is found.
for
i
in
range
(
len
(
block_hashes
)
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hashes
[
i
]):
computed_blocks
[
i
]
=
cached_block
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
>=
self
.
sliding_window_contiguous_blocks
):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del
computed_blocks
[
i
+
num_contiguous_blocks
:]
return
computed_blocks
else
:
num_contiguous_blocks
=
0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del
computed_blocks
[
num_contiguous_blocks
:]
return
computed_blocks
def
remove_skipped_blocks
(
self
,
blocks
:
list
[
KVCacheBlock
],
num_computed_tokens
:
int
)
->
list
[
KVCacheBlock
]:
# Remove the blocks that are no longer be in the sliding window and
# skipped during the attention computation.
last_useful_token
=
num_computed_tokens
-
self
.
sliding_window
+
1
last_useful_block
=
last_useful_token
//
self
.
block_size
removed_blocks
:
list
[
KVCacheBlock
]
=
[]
for
i
in
range
(
last_useful_block
-
1
,
-
1
,
-
1
):
if
blocks
[
i
]
==
self
.
_null_block
:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks
.
append
(
blocks
[
i
])
blocks
[
i
]
=
self
.
_null_block
return
removed_blocks
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SpecializedManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
}
def
get_specialized_manager
(
kv_cache_spec
:
KVCacheSpec
,
block_pool
:
BlockPool
)
->
SpecializedManager
:
manager_class
=
spec_manager_map
[
type
(
kv_cache_spec
)]
manager
=
manager_class
(
kv_cache_spec
,
block_pool
)
return
manager
vllm/v1/engine/__init__.py
View file @
fcfc474d
...
...
@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index
:
int
=
0
# [num_reqs]
outputs
:
list
[
EngineCoreOutput
]
=
[]
scheduler_stats
:
Optional
[
SchedulerStats
]
=
None
timestamp
:
float
=
0.0
utility_output
:
Optional
[
UtilityOutput
]
=
None
finished_requests
:
Optional
[
set
[
str
]]
=
None
# In DP case, used to signal that the engine is paused.
engine_paused
:
bool
=
False
def
__post_init__
(
self
):
if
self
.
timestamp
==
0.0
:
...
...
@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
UTILITY
=
b
'
\x02
'
START_DP
=
b
'
\x02
'
UTILITY
=
b
'
\x03
'
vllm/v1/engine/async_llm.py
View file @
fcfc474d
...
...
@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.protocol
import
EngineClient
from
vllm.envs
import
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
input
_registry
:
Input
Registry
=
INPUT
_REGISTRY
,
mm
_registry
:
MultiModal
Registry
=
MULTIMODAL
_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
...
...
@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
self
.
log_requests
=
log_requests
self
.
log_stats
=
log_stats
self
.
stat_loggers
:
list
[
StatLoggerBase
]
=
[]
# Set up stat loggers; independent set for each DP rank.
self
.
stat_loggers
:
list
[
list
[
StatLoggerBase
]]
=
[]
if
self
.
log_stats
:
for
i
in
range
(
vllm_config
.
parallel_config
.
data_parallel_size
):
loggers
:
list
[
StatLoggerBase
]
=
[]
if
logger
.
isEnabledFor
(
logging
.
INFO
):
self
.
stat_loggers
.
append
(
LoggingStatLogger
())
self
.
stat_loggers
.
append
(
PrometheusStatLogger
(
vllm_config
))
loggers
.
append
(
LoggingStatLogger
(
engine_index
=
i
))
loggers
.
append
(
PrometheusStatLogger
(
vllm_config
,
engine_index
=
i
))
self
.
stat_loggers
.
append
(
loggers
)
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
...
...
@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
input
_registry
=
input
_registry
,
mm
_registry
=
mm
_registry
,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
...
...
@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
self
.
_record_stats
(
engine_index
=
outputs
.
engine_index
,
scheduler_stats
=
outputs
.
scheduler_stats
,
iteration_stats
=
iteration_stats
,
)
...
...
@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
engine_index
:
int
=
0
,
):
if
not
self
.
log_stats
:
return
assert
scheduler_stats
is
not
None
for
stat_logger
in
self
.
stat_loggers
:
for
stat_logger
in
self
.
stat_loggers
[
engine_index
]
:
stat_logger
.
record
(
scheduler_stats
=
scheduler_stats
,
iteration_stats
=
iteration_stats
)
...
...
@@ -393,7 +402,8 @@ class AsyncLLM(EngineClient):
scheduler_outputs
=
None
,
model_output
=
None
,
)
->
None
:
for
stat_logger
in
self
.
stat_loggers
:
for
loggers
in
self
.
stat_loggers
:
for
stat_logger
in
loggers
:
stat_logger
.
log
()
async
def
check_health
(
self
)
->
None
:
...
...
@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
await
self
.
engine_core
.
sleep_async
(
level
)
async
def
wake_up
(
self
)
->
None
:
await
self
.
engine_core
.
wake_up_async
()
async
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
await
self
.
engine_core
.
wake_up_async
(
tags
)
async
def
is_sleeping
(
self
)
->
bool
:
return
await
self
.
engine_core
.
is_sleeping_async
()
...
...
vllm/v1/engine/core.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
os
import
queue
import
signal
import
sys
import
threading
import
time
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
multiprocessing.connection
import
Connection
from
typing
import
Any
,
Optional
from
logging
import
DEBUG
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
psutil
import
zmq
import
zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.distributed
import
stateless_destroy_torch_distributed_process_group
from
vllm.executor.multiproc_worker_utils
import
_add_prefix
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
...
...
@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx
)
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
...
@@ -39,6 +44,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S
=
2.5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCore
:
"""Inner loop of vLLM's Engine."""
...
...
@@ -60,8 +67,9 @@ class EngineCore:
self
.
model_executor
=
executor_class
(
vllm_config
)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks
,
num_cpu_blocks
=
self
.
_initialize_kv_caches
(
vllm_config
)
num_gpu_blocks
,
num_cpu_blocks
,
kv_cache_config
=
\
self
.
_initialize_kv_caches
(
vllm_config
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
...
...
@@ -84,14 +92,16 @@ class EngineCore:
"compatibility may not be maintained."
,
vllm_config
.
scheduler_config
.
scheduler_cls
)
self
.
scheduler
=
Scheduler
(
self
.
scheduler
:
SchedulerInterface
=
Scheduler
(
scheduler_config
=
vllm_config
.
scheduler_config
,
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
speculative_config
=
vllm_config
.
speculative_config
,
log_stats
=
self
.
log_stats
,
kv_cache_config
=
kv_cache_config
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
log_stats
=
self
.
log_stats
,
)
# Setup MM Input Mapper.
...
...
@@ -110,8 +120,8 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
]:
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
start
=
time
.
time
()
# Get all kv cache needed by the model
...
...
@@ -136,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs
(
kv_cache_configs
)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to
get the number of blocks
.
# an arbitrary one to
initialize the scheduler
.
assert
all
([
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
for
cfg
in
kv_cache_configs
])
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_cpu_blocks
=
0
scheduler_kv_cache_config
=
kv_cache_configs
[
0
]
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
...
...
@@ -150,7 +161,7 @@ class EngineCore:
elapsed
=
time
.
time
()
-
start
logger
.
info
((
"init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"
),
elapsed
)
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
,
scheduler_kv_cache_config
def
add_request
(
self
,
request
:
EngineCoreRequest
):
"""Add request to the scheduler."""
...
...
@@ -253,8 +264,8 @@ class EngineCore:
def
sleep
(
self
,
level
:
int
=
1
):
self
.
model_executor
.
sleep
(
level
)
def
wake_up
(
self
):
self
.
model_executor
.
wake_up
()
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
):
self
.
model_executor
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
return
self
.
model_executor
.
is_sleeping
...
...
@@ -274,6 +285,24 @@ class EngineCore:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
self
.
model_executor
.
save_sharded_state
(
path
=
path
,
pattern
=
pattern
,
max_size
=
max_size
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
...
...
@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
self
,
input_path
:
str
,
output_path
:
str
,
ready_pipe
:
Connection
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
engine_index
:
int
=
0
,
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
...
...
@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
args
=
(
input_path
,
),
daemon
=
True
).
start
()
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_path
,
),
args
=
(
output_path
,
engine_index
),
daemon
=
True
).
start
()
# Send Readiness signal to EngineClient.
ready_pipe
.
send
({
"status"
:
"READY"
})
self
.
global_unfinished_reqs
=
False
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
@
staticmethod
def
run_engine_core
(
*
args
,
**
kwargs
):
def
run_engine_core
(
*
args
,
dp_rank
:
int
=
0
,
local_dp_rank
:
int
=
0
,
ready_pipe
,
**
kwargs
):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
...
...
@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
parent_process
=
psutil
.
Process
().
parent
()
engine_core
=
None
engine_core
:
Optional
[
EngineCoreProc
]
=
None
try
:
parallel_config
:
ParallelConfig
=
kwargs
[
"vllm_config"
].
parallel_config
if
parallel_config
.
data_parallel_size
>
1
:
# Set data parallel rank for this engine process.
parallel_config
.
data_parallel_rank
=
dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
engine_core
=
DPEngineCoreProc
(
*
args
,
**
kwargs
)
else
:
engine_core
=
EngineCoreProc
(
*
args
,
**
kwargs
)
# Send Readiness signal to EngineClient.
ready_pipe
.
send
({
"status"
:
"READY"
})
engine_core
.
run_busy_loop
()
except
SystemExit
:
...
...
@@ -350,26 +397,42 @@ class EngineCoreProc(EngineCore):
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
# Loop until process is sent a SIGINT or SIGTERM
while
True
:
# 1) Poll the input queue until there is work to do.
while
not
self
.
scheduler
.
has_requests
():
logger
.
debug
(
"EngineCore busy loop waiting."
)
self
.
_process_input_queue
()
# 2) Step the engine core and return the outputs.
self
.
_process_engine_step
()
def
_process_input_queue
(
self
):
"""Exits when an engine step needs to be performed."""
waited
=
False
while
not
self
.
global_unfinished_reqs
and
not
(
self
.
scheduler
.
has_requests
()):
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
logger
.
debug
(
"EngineCore waiting for work."
)
waited
=
True
req
=
self
.
input_queue
.
get
()
self
.
_handle_client_request
(
*
req
)
# 2) Handle any new client requests.
if
waited
:
logger
.
debug
(
"EngineCore loop active - local unfinished: %s, finished: %s."
,
self
.
scheduler
.
has_unfinished_requests
(),
self
.
scheduler
.
has_finished_requests
())
# Handle any more client requests.
while
not
self
.
input_queue
.
empty
():
req
=
self
.
input_queue
.
get_nowait
()
self
.
_handle_client_request
(
*
req
)
# 3) Step the engine core.
outputs
=
step_fn
()
def
_process_engine_step
(
self
):
"""Called only when there are unfinished local requests."""
# 4) Put EngineCoreOutputs into the output queue.
# Step the engine core.
outputs
=
self
.
step_fn
()
# Put EngineCoreOutputs into the output queue.
if
outputs
is
not
None
:
self
.
output_queue
.
put_nowait
(
outputs
)
...
...
@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
self
.
add_request
(
request
)
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
START_DP
:
if
not
self
.
global_unfinished_reqs
:
logger
.
debug
(
"EngineCore starting idle loop."
)
self
.
global_unfinished_reqs
=
True
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
...
...
@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
def
process_output_socket
(
self
,
output_path
:
str
):
def
process_output_socket
(
self
,
output_path
:
str
,
engine_index
:
int
):
"""Output socket IO thread."""
# Msgpack serialization encoding.
...
...
@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
with
zmq_socket_ctx
(
output_path
,
zmq
.
constants
.
PUSH
)
as
socket
:
while
True
:
outputs
=
self
.
output_queue
.
get
()
outputs
.
engine_index
=
engine_index
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send_multipart
((
buffer
,
),
copy
=
False
)
socket
.
send
(
buffer
,
copy
=
False
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def
__init__
(
self
,
input_path
:
str
,
output_path
:
str
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
assert
dp_size
>
1
assert
0
<=
local_dp_rank
<=
dp_rank
<
dp_size
from
vllm.platforms
import
current_platform
if
current_platform
.
is_cuda_alike
():
from
vllm.platforms.cuda
import
device_id_to_physical_device_id
tp_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
","
.
join
(
str
(
device_id_to_physical_device_id
(
i
))
for
i
in
range
(
local_dp_rank
*
tp_size
,
(
local_dp_rank
+
1
)
*
tp_size
))
self
.
dp_group
=
vllm_config
.
parallel_config
.
stateless_init_dp_group
()
# Initialize the engine after setting up environment.
super
().
__init__
(
input_path
,
output_path
,
vllm_config
,
executor_class
,
log_stats
,
dp_rank
)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self
.
counter
=
0
def
shutdown
(
self
):
super
().
shutdown
()
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while
True
:
# 1) Poll the input queue until there is work to do.
self
.
_process_input_queue
()
local_unfinished_reqs
=
self
.
scheduler
.
has_unfinished_requests
()
if
local_unfinished_reqs
:
# 2) Step the engine core.
self
.
_process_engine_step
()
# Check if we have now finished all requests.
local_unfinished_reqs
=
(
self
.
scheduler
.
has_unfinished_requests
())
else
:
if
self
.
scheduler
.
has_finished_requests
():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self
.
_process_engine_step
()
if
not
self
.
global_unfinished_reqs
:
# All engines are idle.
continue
# There must be unfinished requests in DP peers, run a
# dummy forward pass.
self
.
execute_dummy_batch
()
# 3) All-reduce operation to determine global unfinished reqs.
self
.
global_unfinished_reqs
=
self
.
_has_global_unfinished_reqs
(
local_unfinished_reqs
)
if
not
self
.
global_unfinished_reqs
:
# Notify client that we are pausing the loop.
self
.
output_queue
.
put_nowait
(
ENGINE_PAUSED_OUTPUTS
)
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
# Optimization - only perform finish-sync all-reduce every 16 steps.
self
.
counter
+=
1
if
self
.
counter
!=
16
:
return
True
self
.
counter
=
0
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
local_unfinished
)
vllm/v1/engine/core_client.py
View file @
fcfc474d
...
...
@@ -8,10 +8,11 @@ import threading
import
uuid
import
weakref
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
threading
import
Thread
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
zmq
import
zmq.asyncio
...
...
@@ -32,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCoreClient
(
ABC
):
"""
...
...
@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
"is not currently supported."
)
if
multiprocess_mode
and
asyncio_mode
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
return
DPAsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
if
multiprocess_mode
and
not
asyncio_mode
:
...
...
@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
raise
NotImplementedError
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
raise
NotImplementedError
def
is_sleeping
(
self
)
->
bool
:
...
...
@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
raise
NotImplementedError
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
raise
NotImplementedError
...
...
@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
raise
NotImplementedError
async
def
wake_up_async
(
self
)
->
None
:
async
def
wake_up_async
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
raise
NotImplementedError
async
def
is_sleeping_async
(
self
)
->
bool
:
...
...
@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
async
def
save_sharded_state_async
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
raise
NotImplementedError
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
class
InprocClient
(
EngineCoreClient
):
"""
...
...
@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine_core
.
sleep
(
level
)
def
wake_up
(
self
)
->
None
:
self
.
engine_core
.
wake_up
()
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
self
.
engine_core
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
return
self
.
engine_core
.
is_sleeping
()
...
...
@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
self
.
engine_core
.
save_sharded_state
(
path
,
pattern
,
max_size
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
ctx
:
Union
[
zmq
.
Context
,
zmq
.
asyncio
.
Context
],
output_path
:
str
,
index
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
# Paths and sockets for IPC.
input_path
=
get_open_zmq_ipc_path
()
self
.
input_socket
=
make_zmq_socket
(
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
try
:
# Start EngineCore in background process.
self
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
output_path
=
output_path
,
process_name
=
f
"EngineCore_
{
index
}
"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"dp_rank"
:
index
,
"local_dp_rank"
:
local_dp_rank
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
})
self
.
num_reqs_in_flight
=
0
finally
:
if
not
hasattr
(
self
,
"num_reqs_in_flight"
):
# Ensure socket is closed if process fails to start.
self
.
close
()
def
send_multipart
(
self
,
msg_parts
:
Sequence
):
return
self
.
input_socket
.
send_multipart
(
msg_parts
,
copy
=
False
)
def
close
(
self
):
if
proc_handle
:
=
getattr
(
self
,
"proc_handle"
,
None
):
proc_handle
.
shutdown
()
if
socket
:
=
getattr
(
self
,
"input_socket"
,
None
):
socket
.
close
(
linger
=
0
)
@
dataclass
class
BackgroundResources
:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
ctx
:
zmq
.
Context
ctx
:
Union
[
zmq
.
Context
]
core_engines
:
list
[
CoreEngine
]
=
field
(
default_factory
=
list
)
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
proc_handle
:
Optional
[
BackgroundProcHandle
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
def
__call__
(
self
):
"""Clean up background resources."""
if
self
.
proc_handle
is
not
None
:
self
.
proc_handle
.
shutdown
()
for
core_engine
in
self
.
core_engines
:
core_engine
.
close
()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if
self
.
output_socket
is
not
None
:
self
.
output_socket
.
close
(
linger
=
0
)
if
self
.
input_socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
if
self
.
shutdown_path
is
not
None
:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
...
...
@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
# ZMQ setup.
sync_ctx
=
zmq
.
Context
()
sync_ctx
=
zmq
.
Context
(
io_threads
=
2
)
self
.
ctx
=
zmq
.
asyncio
.
Context
(
sync_ctx
)
if
asyncio_mode
else
sync_ctx
# This will ensure resources created so far are closed
...
...
@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
self
.
resources
=
BackgroundResources
(
ctx
=
sync_ctx
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
# Paths for IPC.
# Paths
and sockets
for IPC.
self
.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
# Start EngineCore in background process.
self
.
resources
.
proc_handle
=
BackgroundProcHandle
(
in
put_path
=
input_path
,
output_path
=
self
.
output_path
,
process_name
=
"EngineCore"
,
target_fn
=
EngineCoreProc
.
run
_engine
_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
}
)
new_core_engine
=
lambda
index
,
local_dp_rank
=
None
:
CoreEngine
(
vllm_config
,
executor_class
,
log_stats
,
self
.
ctx
,
self
.
output_path
,
in
dex
,
local_dp_rank
)
# Start engine core process(es).
self
.
_init_core_engines
(
vllm_config
,
new_core
_engine
,
self
.
resources
.
core_engines
)
# Wait for engine core process(es) to start.
for
engine
in
self
.
resources
.
core_engines
:
engine
.
proc_handle
.
wait_for_startup
(
)
# Create input socket.
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
self
.
input_socket
=
self
.
resources
.
input_socket
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Default case - single core engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
core_engine
=
new_core_engine
(
dp_rank
,
local_dp_rank
if
local_dp_rank
is
not
None
else
dp_rank
)
core_engines
.
append
(
core_engine
)
self
.
core_engine
=
core_engine
def
shutdown
(
self
):
self
.
_finalizer
()
...
...
@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
.
bind
(
shutdown_path
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
out_socket
)
...
...
@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
(
frame
,
)
=
out_socket
.
recv
_multipart
(
copy
=
False
)
frame
=
out_socket
.
recv
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
...
...
@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
def
get_output
(
self
)
->
EngineCoreOutputs
:
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
core_engine
.
send_multipart
(
msg
)
def
_
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
:
Future
[
Any
]
=
Future
()
self
.
utility_results
[
call_id
]
=
future
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
...
...
@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_
call_utility
(
"profile"
,
is_start
)
self
.
call_utility
(
"profile"
,
is_start
)
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_
call_utility
(
"reset_prefix_cache"
)
self
.
call_utility
(
"reset_prefix_cache"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
_
call_utility
(
"add_lora"
,
lora_request
)
return
self
.
call_utility
(
"add_lora"
,
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_
call_utility
(
"remove_lora"
,
lora_id
)
return
self
.
call_utility
(
"remove_lora"
,
lora_id
)
def
list_loras
(
self
)
->
set
[
int
]:
return
self
.
_
call_utility
(
"list_loras"
)
return
self
.
call_utility
(
"list_loras"
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_
call_utility
(
"pin_lora"
,
lora_id
)
return
self
.
call_utility
(
"pin_lora"
,
lora_id
)
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
_
call_utility
(
"sleep"
,
level
)
self
.
call_utility
(
"sleep"
,
level
)
def
wake_up
(
self
)
->
None
:
self
.
_
call_utility
(
"wake_up"
)
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
self
.
call_utility
(
"wake_up"
,
tags
)
def
is_sleeping
(
self
)
->
bool
:
return
self
.
_
call_utility
(
"is_sleeping"
)
return
self
.
call_utility
(
"is_sleeping"
)
def
execute_dummy_batch
(
self
)
->
None
:
self
.
_call_utility
(
"execute_dummy_batch"
)
self
.
call_utility
(
"execute_dummy_batch"
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
call_utility
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
self
.
call_utility
(
"save_sharded_state"
,
path
,
pattern
,
max_size
)
class
AsyncMPClient
(
MPClient
):
...
...
@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
self
.
outputs_queue
:
Optional
[
asyncio
.
Queue
[
EngineCoreOutputs
]]
=
None
self
.
queue_task
:
Optional
[
asyncio
.
Task
]
=
None
async
def
_start_output_queue_task
(
self
):
self
.
outputs_handler
:
Optional
[
Callable
[
[
AsyncMPClient
,
EngineCoreOutputs
],
Awaitable
[
None
]]]
=
None
def
_ensure_output_queue_task
(
self
):
if
self
.
outputs_queue
is
not
None
:
return
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self
.
outputs_queue
=
asyncio
.
Queue
()
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
output_handler
=
self
.
outputs_handler
_self_ref
=
weakref
.
ref
(
self
)
if
output_handler
else
None
output_path
=
self
.
output_path
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
...
...
@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
else
:
continue
if
output_handler
is
not
None
:
assert
_self_ref
is
not
None
_self
=
_self_ref
()
if
not
_self
:
# Client has been garbage collected, abort.
return
await
output_handler
(
_self
,
outputs
)
if
outputs
.
outputs
or
outputs
.
scheduler_stats
:
outputs_queue
.
put_nowait
(
outputs
)
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
(),
name
=
"EngineCoreOutputQueueTask"
)
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
if
self
.
outputs_queue
is
None
:
await
self
.
_start_output_queue_task
()
self
.
_ensure_output_queue_task
()
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
await
self
.
core_engine
.
send_multipart
(
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
)))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
_ensure_output_queue_task
()
if
self
.
outputs_queue
is
None
:
await
self
.
_start_output_queue_task
()
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
return
await
self
.
_call_utility_async
(
method
,
*
args
,
engine
=
self
.
core_engine
)
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
,
engine
:
CoreEngine
,
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
await
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
self
.
encoder
.
encode
((
call_id
,
method
,
args
)))
await
engine
.
send_multipart
(
message
)
self
.
_ensure_output_queue_task
()
return
await
future
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
...
...
@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_
call_utility_async
(
"profile"
,
is_start
)
await
self
.
call_utility_async
(
"profile"
,
is_start
)
async
def
reset_prefix_cache_async
(
self
)
->
None
:
await
self
.
_
call_utility_async
(
"reset_prefix_cache"
)
await
self
.
call_utility_async
(
"reset_prefix_cache"
)
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
await
self
.
_
call_utility_async
(
"sleep"
,
level
)
await
self
.
call_utility_async
(
"sleep"
,
level
)
async
def
wake_up_async
(
self
)
->
None
:
await
self
.
_
call_utility_async
(
"wake_up"
)
async
def
wake_up_async
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
await
self
.
call_utility_async
(
"wake_up"
,
tags
)
async
def
is_sleeping_async
(
self
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"is_sleeping"
)
return
await
self
.
call_utility_async
(
"is_sleeping"
)
async
def
execute_dummy_batch_async
(
self
)
->
None
:
await
self
.
_
call_utility_async
(
"execute_dummy_batch"
)
await
self
.
call_utility_async
(
"execute_dummy_batch"
)
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"add_lora"
,
lora_request
)
return
await
self
.
call_utility_async
(
"add_lora"
,
lora_request
)
async
def
remove_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"remove_lora"
,
lora_id
)
return
await
self
.
call_utility_async
(
"remove_lora"
,
lora_id
)
async
def
list_loras_async
(
self
)
->
set
[
int
]:
return
await
self
.
_
call_utility_async
(
"list_loras"
)
return
await
self
.
call_utility_async
(
"list_loras"
)
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
_call_utility_async
(
"pin_lora"
,
lora_id
)
return
await
self
.
call_utility_async
(
"pin_lora"
,
lora_id
)
async
def
save_sharded_state_async
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
await
self
.
call_utility_async
(
"save_sharded_state"
,
path
,
pattern
,
max_size
)
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
await
self
.
call_utility_async
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
DPAsyncMPClient
(
AsyncMPClient
):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
assert
len
(
self
.
core_engines
)
>
1
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
outputs_handler
=
DPAsyncMPClient
.
process_engine_outputs
# type: ignore[assignment]
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Launch a core engine for each data parallel rank.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
for
i
in
range
(
dp_size
):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines
.
append
(
new_core_engine
(
i
,
i
))
self
.
core_engines
=
core_engines
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
self
.
_call_utility_async
(
method
,
*
args
,
engine
=
engine
)
for
engine
in
self
.
core_engines
]))[
0
]
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
if
self
.
num_engines_running
>=
len
(
self
.
core_engines
):
await
chosen_engine
.
send_multipart
(
msg
)
else
:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self
.
num_engines_running
+=
len
(
self
.
core_engines
)
await
asyncio
.
gather
(
*
[
engine
.
send_multipart
(
msg
if
engine
is
chosen_engine
else
self
.
start_dp_msg
)
for
engine
in
self
.
core_engines
])
self
.
_ensure_output_queue_task
()
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
return
min
(
self
.
core_engines
,
key
=
lambda
e
:
e
.
num_reqs_in_flight
)
@
staticmethod
async
def
process_engine_outputs
(
self
:
"DPAsyncMPClient"
,
outputs
:
EngineCoreOutputs
):
if
self
.
reqs_in_flight
:
for
req_id
in
outputs
.
finished_requests
or
():
if
engine
:
=
self
.
reqs_in_flight
.
pop
(
req_id
,
None
):
engine
.
num_reqs_in_flight
-=
1
if
outputs
.
engine_paused
:
assert
self
.
num_engines_running
>=
1
self
.
num_engines_running
-=
1
if
not
self
.
num_engines_running
and
self
.
reqs_in_flight
:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self
.
num_engines_running
=
len
(
self
.
core_engines
)
coros
=
[
engine
.
send_multipart
(
self
.
start_dp_msg
)
for
engine
in
self
.
core_engines
if
not
engine
.
num_reqs_in_flight
]
if
coros
:
await
asyncio
.
gather
(
*
coros
)
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
not
request_ids
:
return
if
len
(
request_ids
)
==
1
:
# Fast-path common case.
if
engine
:
=
self
.
reqs_in_flight
.
get
(
request_ids
[
0
]):
await
self
.
_abort_requests
(
request_ids
,
engine
)
return
by_engine
:
dict
[
CoreEngine
,
list
[
str
]]
=
{}
for
req_id
in
request_ids
:
if
engine
:
=
self
.
reqs_in_flight
.
get
(
req_id
):
by_engine
.
setdefault
(
engine
,
[]).
append
(
req_id
)
for
engine
,
req_ids
in
by_engine
.
items
():
await
self
.
_abort_requests
(
req_ids
,
engine
)
async
def
_abort_requests
(
self
,
request_ids
:
list
[
str
],
engine
:
CoreEngine
)
->
None
:
await
engine
.
send_multipart
((
EngineCoreRequestType
.
ABORT
.
value
,
self
.
encoder
.
encode
(
request_ids
)))
vllm/v1/engine/llm_engine.py
View file @
fcfc474d
...
...
@@ -2,15 +2,16 @@
from
collections.abc
import
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.distributed
import
stateless_destroy_torch_distributed_process_group
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
...
...
@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger
=
init_logger
(
__name__
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLMEngine
:
...
...
@@ -43,7 +45,6 @@ class LLMEngine:
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
...
...
@@ -60,11 +61,13 @@ class LLMEngine:
self
.
cache_config
=
vllm_config
.
cache_config
# important: init dp group before init the engine_core
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_enabled
=
self
.
parallel_config
.
data_parallel_size
>
1
# noqa
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config
=
vllm_config
.
parallel_config
if
not
multiprocess_mode
and
parallel_config
.
data_parallel_size
>
1
:
self
.
dp_group
=
parallel_config
.
stateless_init_dp_group
()
else
:
self
.
dp_group
=
None
self
.
should_execute_dummy_batch
=
False
if
self
.
dp_enabled
:
self
.
dp_group
=
self
.
parallel_config
.
stateless_init_dp_group
()
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
...
...
@@ -77,7 +80,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
...
...
@@ -148,7 +150,7 @@ class LLMEngine:
def
has_unfinished_requests
(
self
)
->
bool
:
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
if
not
self
.
dp_
enabled
:
if
self
.
dp_
group
is
None
:
return
has_unfinished
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
...
...
@@ -243,8 +245,8 @@ class LLMEngine:
def
sleep
(
self
,
level
:
int
=
1
):
self
.
engine_core
.
sleep
(
level
)
def
wake_up
(
self
):
self
.
engine_core
.
wake_up
()
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
):
self
.
engine_core
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
return
self
.
engine_core
.
is_sleeping
()
...
...
@@ -280,3 +282,14 @@ class LLMEngine:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
"""Prevent an adapter from being evicted."""
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/v1/engine/output_processor.py
View file @
fcfc474d
...
...
@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks.
stop_string
=
req_state
.
detokenizer
.
update
(
new_token_ids
,
finish_reason
==
FinishReason
.
STOP
)
if
stop_string
and
finish_reason
!=
FinishReason
.
STOP
:
if
stop_string
:
finish_reason
=
FinishReason
.
STOP
stop_reason
=
stop_string
...
...
vllm/v1/engine/processor.py
View file @
fcfc474d
...
...
@@ -5,9 +5,8 @@ from collections.abc import Mapping
from
typing
import
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
...
@@ -31,7 +30,6 @@ class Processor:
self
,
vllm_config
:
VllmConfig
,
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
...
...
@@ -123,7 +121,8 @@ class Processor:
return
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"guidance:disable-any-whitespace"
,
"auto"
]
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
engine_level_backend
not
in
supported_backends
:
...
...
@@ -137,13 +136,15 @@ class Processor:
f
" !=
{
engine_level_backend
}
"
)
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
import
vllm.platforms
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
if
engine_level_backend
==
"xgrammar"
:
if
engine_level_backend
.
startswith
(
"xgrammar"
):
# xgrammar with no fallback
validate_structured_output_request_xgrammar
(
params
)
params
.
guided_decoding
.
backend
=
"xgrammar"
params
.
guided_decoding
.
backend
=
engine_level_backend
elif
engine_level_backend
==
"auto"
:
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
...
...
@@ -157,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance.
params
.
guided_decoding
.
backend
=
"guidance"
if
params
.
guided_decoding
.
backend
==
"guidance"
:
if
engine_level_backend
.
startswith
(
"guidance"
)
:
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
params
.
guided_decoding
.
backend
=
engine_level_backend
def
process_inputs
(
self
,
...
...
@@ -206,14 +208,7 @@ class Processor:
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
if
is_encoder_decoder_inputs
(
processed_inputs
):
decoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
[
"decoder"
])
encoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
[
"encoder"
])
else
:
decoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
)
encoder_inputs
=
None
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
# TODO: Impl encoder-decoder
if
encoder_inputs
is
not
None
:
...
...
@@ -224,8 +219,9 @@ class Processor:
sampling_params
=
params
.
clone
()
# If unset max tokens, then generate up to the max_model_len.
if
sampling_params
.
max_tokens
is
None
:
sampling_params
.
max_tokens
=
(
self
.
model_config
.
max_model_len
-
len
(
decoder_inputs
.
prompt_token_ids
))
sampling_params
.
max_tokens
=
(
self
.
model_config
.
max_model_len
-
len
(
decoder_inputs
[
"prompt_token_ids"
]))
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
,
eos_token_id
)
sampling_params
.
update_from_tokenizer
(
...
...
@@ -235,57 +231,46 @@ class Processor:
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
(
decoder_mm_inputs
:
=
decoder_inputs
.
multi_modal_data
):
assert
isinstance
(
decoder_mm_inputs
,
MultiModalKwargs
)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
modality
in
decoder_mm_inputs
.
modalities
for
item
in
decoder_mm_inputs
.
get_items
(
modality
)
]
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
sorted_modalities
,
sorted_
item_
modalities
,
sorted_mm_positions
,
sorted_mm_hashes
,
)
=
merge_and_sort_multimodal_metadata
(
decoder_inputs
.
multi_modal
_placeholders
,
decoder_inputs
.
multi_modal
_hashes
if
self
.
use_hash
else
None
,
decoder_inputs
[
"mm
_placeholders
"
]
,
decoder_inputs
[
"mm
_hashes
"
]
if
self
.
use_hash
else
None
,
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved.
if
len
(
sorted_modalities
)
>
1
:
modality_order_dict
=
{
modality
:
order
for
order
,
modality
in
enumerate
(
sorted_modalities
)
}
# Sanity check to make sure each multimodal input has only one
# modality key.
for
mm_input
in
individual_mm_inputs
:
assert
len
(
mm_input
.
modalities
)
==
1
# Sort MultiModalKwargs to match sorted_mm_positions
sorted_mm_inputs
=
sorted
(
individual_mm_inputs
,
key
=
lambda
mm_input
:
modality_order_dict
[
list
(
mm_input
.
modalities
)[
0
]])
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities
=
set
(
sorted_item_modalities
)
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
for
modality
in
sorted_item_modalities
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
used_indices
[
modality
]
+=
1
else
:
sorted_mm_inputs
=
individual_mm_inputs
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
return
EngineCoreRequest
(
request_id
=
request_id
,
prompt
=
decoder_inputs
.
prompt
,
prompt_token_ids
=
decoder_inputs
.
prompt_token_ids
,
prompt
=
decoder_inputs
.
get
(
"
prompt
"
)
,
prompt_token_ids
=
decoder_inputs
[
"
prompt_token_ids
"
]
,
mm_inputs
=
sorted_mm_inputs
,
mm_hashes
=
sorted_mm_hashes
,
mm_placeholders
=
sorted_mm_positions
,
...
...
@@ -298,15 +283,16 @@ class Processor:
def
_validate_model_inputs
(
self
,
inputs
:
ProcessorInputs
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
if
is_encoder_decoder_inputs
(
inputs
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs
=
inputs
[
"decoder"
if
self
.
model_config
.
is_multimodal_model
else
"encoder"
]
if
self
.
model_config
.
is_multimodal_model
:
prompt_inputs
=
decoder_inputs
else
:
prompt_inputs
=
inputs
prompt_inputs
=
encoder_inputs
or
decoder_
inputs
prompt_ids
=
SingletonInputsAdapter
(
prompt_inputs
).
prompt_token_ids
prompt_ids
=
prompt_inputs
[
"
prompt_token_ids
"
]
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
...
...
vllm/v1/executor/multiproc_executor.py
View file @
fcfc474d
...
...
@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle
=
self
.
worker_response_mq
.
export_handle
()
# Send Readiness signal to EngineCore process.
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PUSH
)
as
ready_socket
:
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PUSH
,
linger
=
10000
)
as
ready_socket
:
payload
=
pickle
.
dumps
(
worker_response_mq_handle
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
...
...
@@ -270,11 +273,13 @@ class WorkerProc:
proc
=
context
.
Process
(
target
=
WorkerProc
.
worker_main
,
kwargs
=
process_kwargs
,
daemon
=
True
)
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PULL
)
as
ready_socket
:
proc
.
start
()
# Wait for startup
worker_response_mq_handle
=
WorkerProc
.
wait_for_startup
(
proc
,
ready_
path
)
proc
,
ready_
socket
)
worker_response_mq
=
MessageQueue
.
create_from_handle
(
worker_response_mq_handle
,
0
)
...
...
@@ -337,21 +342,20 @@ class WorkerProc:
@
staticmethod
def
wait_for_startup
(
proc
:
BaseProcess
,
ready_
path
:
str
,
ready_
socket
:
zmq
.
Socket
,
)
->
Optional
[
Handle
]:
"""Wait until the Worker is ready."""
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PULL
)
as
socket
:
# Wait for Worker to send READY.
while
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
while
ready_
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
logger
.
debug
(
"Waiting for WorkerProc to startup."
)
if
not
proc
.
is_alive
():
raise
RuntimeError
(
"WorkerProc failed to start."
)
message
=
socket
.
recv_string
()
message
=
ready_
socket
.
recv_string
()
assert
message
==
WorkerProc
.
READY_STR
handle_frame
=
socket
.
recv
(
copy
=
False
)
handle_frame
=
ready_
socket
.
recv
(
copy
=
False
)
handle
=
pickle
.
loads
(
handle_frame
.
buffer
)
return
handle
...
...
vllm/v1/kv_cache_interface.py
View file @
fcfc474d
...
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
,
get_dtype_size
...
...
@@ -43,28 +44,23 @@ class KVCacheSpec:
"""
raise
NotImplementedError
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
memory size after padding `num_tokens` to full blocks.
The maximum possible memory usage of this KV cache in bytes.
Returns:
The KV cache size
The KV cache size
in bytes
"""
raise
NotImplementedError
@
dataclass
class
Full
AttentionSpec
(
KVCacheSpec
):
class
AttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
head_size
:
int
dtype
:
torch
.
dtype
use_mla
:
bool
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
@
property
def
page_size_bytes
(
self
)
->
int
:
# For MLA we only store a single latent vector
...
...
@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
*
get_dtype_size
(
self
.
dtype
)
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
class
FullAttentionSpec
(
AttentionSpec
):
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
class
SlidingWindowSpec
(
AttentionSpec
):
sliding_window
:
int
def
__post_init__
(
self
):
assert
not
self
.
use_mla
,
"MLA is not supported for sliding window"
@
property
def
type_id
(
self
)
->
str
:
return
f
"sliding_window_
{
self
.
sliding_window
}
_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
# noqa
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_num_batched_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
)
# During chunked prefill, we allocate KV cache for the last
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens
=
min
(
self
.
sliding_window
-
1
+
max_num_batched_tokens
,
max_model_len
)
# +1 here because the sliding window may not start from the beginning
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
@
dataclass
...
...
vllm/v1/metrics/loggers.py
View file @
fcfc474d
...
...
@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingMetrics
logger
=
init_logger
(
__name__
)
...
...
@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
class
LoggingStatLogger
(
StatLoggerBase
):
def
__init__
(
self
):
def
__init__
(
self
,
engine_index
:
int
=
0
):
self
.
engine_index
=
engine_index
self
.
_reset
(
time
.
monotonic
())
self
.
last_scheduler_stats
=
SchedulerStats
()
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_metrics
=
SpecDecodingMetrics
()
def
_reset
(
self
,
now
):
self
.
last_log_time
=
now
...
...
@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
...
...
@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output.
logger
.
info
(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%"
,
self
.
engine_index
,
prompt_throughput
,
generation_throughput
,
scheduler_stats
.
num_running_reqs
,
...
...
@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
log
()
class
PrometheusStatLogger
(
StatLoggerBase
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
# Use this flag to hide metrics that were deprecated in
...
...
@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
show_hidden_metrics
=
\
vllm_config
.
observability_config
.
show_hidden_metrics
labelnames
=
[
"model_name"
]
labelvalues
=
[
vllm_config
.
model_config
.
served_model_name
]
labelnames
=
[
"model_name"
,
"engine"
]
labelvalues
=
[
vllm_config
.
model_config
.
served_model_name
,
str
(
engine_index
)
]
max_model_len
=
vllm_config
.
model_config
.
max_model_len
...
...
@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
labelname_running_lora_adapters
,
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self
.
counter_spec_decode_num_draft_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_accepted_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
# Cache config info metric
#
...
...
@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
counter_spec_decode_num_draft_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_draft_tokens
)
self
.
counter_spec_decode_num_accepted_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_accepted_tokens
)
if
iteration_stats
is
None
:
return
...
...
vllm/v1/metrics/stats.py
View file @
fcfc474d
...
...
@@ -4,6 +4,8 @@ import time
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine.output_processor
import
RequestState
...
...
@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats
:
PrefixCacheStats
=
field
(
default_factory
=
PrefixCacheStats
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
@
dataclass
class
LoRAStats
:
...
...
vllm/v1/request.py
View file @
fcfc474d
...
...
@@ -59,6 +59,8 @@ class Request:
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_inputs
=
multi_modal_inputs
or
[]
self
.
mm_hashes
:
list
[
str
]
=
multi_modal_hashes
or
[]
self
.
num_encoder_inputs
=
len
(
self
.
mm_inputs
)
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
# Sanity check
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
...
...
@@ -93,7 +95,9 @@ class Request:
token_ids
:
Union
[
int
,
list
[
int
]],
)
->
None
:
if
isinstance
(
token_ids
,
int
):
token_ids
=
[
token_ids
]
self
.
_output_token_ids
.
append
(
token_ids
)
self
.
_all_token_ids
.
append
(
token_ids
)
else
:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
...
...
@@ -115,13 +119,6 @@ class Request:
def
get_finished_reason
(
self
)
->
Union
[
FinishReason
,
None
]:
return
RequestStatus
.
get_finished_reason
(
self
.
status
)
def
has_encoder_inputs
(
self
)
->
bool
:
return
len
(
self
.
mm_inputs
)
>
0
@
property
def
num_encoder_inputs
(
self
)
->
int
:
return
len
(
self
.
mm_positions
)
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
fcfc474d
...
...
@@ -19,6 +19,12 @@ except ImportError:
class
TopKTopPSampler
(
nn
.
Module
):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def
__init__
(
self
):
super
().
__init__
()
...
...
@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
...
...
@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# If only top-k is specified, use pytorch's builtin topk op. This leads
# to significant speed up on TPU compared to using apply_top_k_top_p.
if
k
is
not
None
and
p
is
None
:
topk_values
,
topk_indices
=
torch
.
topk
(
logits
,
k
,
dim
=-
1
)
mask
=
torch
.
ones_like
(
logits
,
dtype
=
torch
.
bool
)
mask
.
scatter_
(
-
1
,
topk_indices
,
False
)
logits
.
masked_fill_
(
mask
,
float
(
'-inf'
))
else
:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
logits
=
apply_top_k_top_p_tpu
(
logits
,
k
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
def
apply_top_k_top_p_tpu
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multipple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
if
k
is
not
None
:
logits
=
apply_top_k_only
(
logits
,
k
)
if
p
is
not
None
:
probs
=
logits
.
softmax
(
dim
=-
1
)
probs_sort
,
_
=
probs
.
sort
(
dim
=-
1
,
descending
=
False
)
cumprob
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
)
top_p_mask
=
cumprob
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
top_p_mask
[:,
-
1
]
=
False
# at least one
top_p_count
=
top_p_mask
.
sum
(
dim
=-
1
).
unsqueeze
(
1
)
top_p_cutoff
=
probs_sort
.
gather
(
-
1
,
top_p_count
)
elements_to_discard
=
probs
<
top_p_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
return
logits
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
...
...
@@ -136,10 +171,18 @@ def apply_top_k_top_p(
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if
k
is
None
and
p
is
None
:
if
p
is
None
:
if
k
is
None
:
return
logits
# Avoid sorting vocab for top-k only case.
return
apply_top_k_only
(
logits
,
k
)
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
k
is
not
None
:
...
...
@@ -153,7 +196,7 @@ def apply_top_k_top_p(
if
p
is
not
None
:
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
probs_sum
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
,
out
=
probs_sort
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
top_p_mask
[:,
-
1
]
=
False
...
...
@@ -164,6 +207,31 @@ def apply_top_k_top_p(
return
logits
def
apply_top_k_only
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask
=
k
==
logits
.
shape
[
1
]
# Set non-top-k rows to 1 so that we can gather.
k
=
k
.
masked_fill
(
no_top_k_mask
,
1
)
max_top_k
=
k
.
max
()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index
=
k
.
sub_
(
1
).
unsqueeze
(
1
).
expand
(
logits
.
shape
[
0
],
1
)
top_k_mask
=
logits
.
topk
(
max_top_k
,
dim
=
1
).
values
.
gather
(
1
,
k_index
.
long
())
# Handle non-topk rows.
top_k_mask
.
masked_fill_
(
no_top_k_mask
.
unsqueeze
(
1
),
-
float
(
"inf"
))
logits
.
masked_fill_
(
logits
<
top_k_mask
,
-
float
(
"inf"
))
return
logits
def
random_sample
(
probs
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
...
...
vllm/v1/sample/rejection_sampler.py
View file @
fcfc474d
...
...
@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
output_token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
)
->
list
[
list
[
int
]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np
=
output_token_ids
.
cpu
().
numpy
()
# Create mask for valid tokens.
valid_mask
=
((
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
...
...
vllm/v1/sample/sampler.py
View file @
fcfc474d
...
...
@@ -87,6 +87,12 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_random
:
...
...
Prev
1
…
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment