Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
675ba75f
Commit
675ba75f
authored
Apr 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-ori
parents
5cc98918
296c6572
Changes
501
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1282 additions
and
488 deletions
+1282
-488
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+1
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+154
-116
vllm/v1/core/specialized_manager.py
vllm/v1/core/specialized_manager.py
+161
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+8
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+22
-12
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+219
-43
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+345
-75
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+24
-11
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+1
-1
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+49
-63
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+21
-17
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+46
-11
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+43
-4
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-0
vllm/v1/request.py
vllm/v1/request.py
+7
-10
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+86
-18
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+12
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+6
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+65
-104
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+8
-0
No files found.
vllm/v1/core/sched/output.py
View file @
675ba75f
...
@@ -10,8 +10,7 @@ if TYPE_CHECKING:
...
@@ -10,8 +10,7 @@ if TYPE_CHECKING:
import
numpy.typing
as
npt
import
numpy.typing
as
npt
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.base
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
...
vllm/v1/core/sched/scheduler.py
View file @
675ba75f
...
@@ -7,9 +7,9 @@ from collections import deque
...
@@ -7,9 +7,9 @@ from collections import deque
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
...
@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
...
@@ -19,9 +19,11 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
)
EngineCoreOutputs
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
...
@@ -35,32 +37,37 @@ class Scheduler(SchedulerInterface):
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
kv_cache_config
:
KVCacheConfig
,
log_stats
:
bool
,
structured_output_manager
:
StructuredOutputManager
,
structured_output_manager
:
StructuredOutputManager
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
speculativ
e_config
=
speculativ
e_config
self
.
kv_cach
e_config
=
kv_cach
e_config
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
self
.
structured_output_manager
=
structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self
.
include_finished_set
=
include_finished_set
# Scheduling constraints.
# Scheduling constraints.
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_scheduled_tokens
=
\
self
.
max_num_scheduled_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_model_len
=
self
.
scheduler_config
.
max_model_len
self
.
max_model_len
=
self
.
scheduler_config
.
max_model_len
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
# Create the KV cache manager.
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
self
.
kv_cache_manager
=
KVCacheManager
(
block_size
=
self
.
cache_config
.
block_size
,
kv_cache_config
=
kv_cache_config
,
num_gpu_blocks
=
num_gpu_blocks
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
enable_caching
=
cache_config
.
enable_prefix_caching
,
enable_
caching
=
self
.
cache_config
.
enable_
prefix_caching
,
caching
_hash_algo
=
self
.
cache_config
.
prefix_caching
_hash_algo
,
log_stats
=
self
.
log_stats
)
log_stats
=
self
.
log_stats
)
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
block_size
=
self
.
cache_config
.
block_size
...
@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
...
@@ -92,6 +99,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
mm_registry
=
mm_registry
,
)
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
...
@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
...
@@ -152,23 +160,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
(
request
.
num_tokens_with_spec
-
num_new_tokens
=
(
request
.
num_tokens_with_spec
-
request
.
num_computed_tokens
)
request
.
num_computed_tokens
)
if
(
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
num_new_tokens
):
num_new_tokens
=
(
self
.
scheduler_config
.
long_prefill_token_threshold
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
assert
num_new_tokens
>
0
# Schedule encoder inputs.
# Schedule encoder inputs.
encoder_inputs_to_schedule
,
num_new_tokens
,
new_encoder_budget
=
(
if
request
.
has_encoder_inputs
:
self
.
_try_schedule_encoder_inputs
(
request
,
(
encoder_inputs_to_schedule
,
num_new_tokens
,
request
.
num_computed_tokens
,
new_encoder_budget
)
=
self
.
_try_schedule_encoder_inputs
(
num_new_tokens
,
request
,
request
.
num_computed_tokens
,
num_new_tokens
,
encoder_budget
))
encoder_budget
)
if
num_new_tokens
==
0
:
if
num_new_tokens
==
0
:
# The request cannot be scheduled because the encoder budget
# The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted.
# or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# NOTE(woosuk): By using `continue` instead of `break` here,
# we do not strictly follow the FCFS scheduling policy and
# we intentionally relax the strict FCFS scheduling policy
# allow the lower-priority requests to be scheduled.
# to allow lower-priority requests to be scheduled when a
req_index
+=
1
# higher-priority request is blocked by encoder constraints.
continue
req_index
+=
1
continue
else
:
encoder_inputs_to_schedule
=
None
new_encoder_budget
=
encoder_budget
while
True
:
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
...
@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
...
@@ -235,16 +251,16 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
new_encoder_budget
encoder_budget
=
new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
# Record the LoRAs in scheduled_running_reqs
request
ed_loras
:
set
[
int
]
=
set
()
schedul
ed_loras
:
set
[
int
]
=
set
()
if
self
.
lora_config
:
if
self
.
lora_config
:
request
ed_loras
=
set
(
schedul
ed_loras
=
set
(
req
.
lora_request
.
lora_int_id
for
req
in
scheduled_running_reqs
req
.
lora_request
.
lora_int_id
for
req
in
scheduled_running_reqs
if
req
.
lora_request
and
req
.
lora_request
.
lora_int_id
>
0
)
if
req
.
lora_request
and
req
.
lora_request
.
lora_int_id
>
0
)
assert
len
(
request
ed_loras
)
<=
self
.
lora_config
.
max_loras
assert
len
(
schedul
ed_loras
)
<=
self
.
lora_config
.
max_loras
# Use a temporary deque to collect requests that need to be skipped
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
# and put back at the head of the waiting queue later
waiting_for_fsm
:
deque
[
Request
]
=
deque
()
skipped_waiting_requests
:
deque
[
Request
]
=
deque
()
# Next, schedule the WAITING requests.
# Next, schedule the WAITING requests.
if
not
preempted_reqs
:
if
not
preempted_reqs
:
...
@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
...
@@ -254,31 +270,27 @@ class Scheduler(SchedulerInterface):
request
=
self
.
waiting
[
0
]
request
=
self
.
waiting
[
0
]
# Skip request if the structured output request is still waiting
# for FSM compilation.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_FSM
:
if
request
.
status
==
RequestStatus
.
WAITING_FOR_FSM
:
structured_output_req
=
request
.
structured_output_request
structured_output_req
=
request
.
structured_output_request
if
structured_output_req
and
structured_output_req
.
grammar
:
if
structured_output_req
and
structured_output_req
.
grammar
:
request
.
status
=
RequestStatus
.
WAITING
request
.
status
=
RequestStatus
.
WAITING
else
:
else
:
waiting_structured_output_req
=
self
.
waiting
.
popleft
()
self
.
waiting
.
popleft
()
waiting_for_fsm
.
appendleft
(
skipped_waiting_requests
.
appendleft
(
request
)
waiting_structured_output_req
)
continue
continue
# Check that adding the request still respects the max_loras
# Check that adding the request still respects the max_loras
# constraint.
# constraint.
if
self
.
lora_config
and
request
.
lora_request
:
if
self
.
lora_config
and
request
.
lora_request
and
(
req_lora_id
=
request
.
lora_request
.
lora_int_id
len
(
scheduled_loras
)
==
self
.
lora_config
.
max_loras
if
len
(
requested_loras
)
==
self
.
lora_config
.
max_loras
and
(
and
request
.
lora_request
.
lora_int_id
req_lora_id
not
in
requested_loras
):
not
in
scheduled_loras
):
# Cannot schedule.
# Scheduling would exceed max_loras, skip.
# TODO (varun): This means all the other requests in
self
.
waiting
.
popleft
()
# the WAITING queue will be blocked by this request,
skipped_waiting_requests
.
appendleft
(
request
)
# even if,
continue
# 1. these other requests do not use LoRA, or,
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
# Get already-cached tokens.
# Get already-cached tokens.
computed_blocks
,
num_computed_tokens
=
\
computed_blocks
,
num_computed_tokens
=
\
...
@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface):
...
@@ -288,28 +300,25 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests,
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
# which have output tokens.
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
if
num_new_tokens
==
0
:
if
(
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
# This happens when prompt length is divisible by the block
num_new_tokens
):
# size and all blocks are cached. Now we force to recompute
num_new_tokens
=
(
# the last block. Note that we have to re-compute an entire
self
.
scheduler_config
.
long_prefill_token_threshold
)
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens
-=
self
.
block_size
num_new_tokens
=
self
.
block_size
computed_blocks
.
pop
()
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
assert
num_new_tokens
>
0
# Schedule encoder inputs.
# Schedule encoder inputs.
(
encoder_inputs_to_schedule
,
num_new_tokens
,
if
request
.
has_encoder_inputs
:
new_encoder_budget
)
=
self
.
_try_schedule_encoder_inputs
(
(
encoder_inputs_to_schedule
,
num_new_tokens
,
request
,
num_computed_tokens
,
num_new_tokens
,
new_encoder_budget
)
=
self
.
_try_schedule_encoder_inputs
(
encoder_budget
)
request
,
num_computed_tokens
,
num_new_tokens
,
if
num_new_tokens
==
0
:
encoder_budget
)
# The request cannot be scheduled.
if
num_new_tokens
==
0
:
break
# The request cannot be scheduled.
break
else
:
encoder_inputs_to_schedule
=
None
new_encoder_budget
=
encoder_budget
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
computed_blocks
)
request
,
num_new_tokens
,
computed_blocks
)
...
@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
...
@@ -336,7 +345,7 @@ class Scheduler(SchedulerInterface):
f
"Invalid request status:
{
request
.
status
}
"
)
f
"Invalid request status:
{
request
.
status
}
"
)
if
self
.
lora_config
and
request
.
lora_request
:
if
self
.
lora_config
and
request
.
lora_request
:
request
ed_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
schedul
ed_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
req_to_new_block_ids
[
request
.
request_id
]
=
[
req_to_new_block_ids
[
request
.
request_id
]
=
[
b
.
block_id
for
b
in
computed_blocks
+
new_blocks
b
.
block_id
for
b
in
computed_blocks
+
new_blocks
]
]
...
@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
...
@@ -355,8 +364,8 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
new_encoder_budget
encoder_budget
=
new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
# Put back any skipped requests at the head of the waiting queue
if
waiting_for_fsm
:
if
skipped_waiting_requests
:
self
.
waiting
.
extendleft
(
waiting_for_fsm
)
self
.
waiting
.
extendleft
(
skipped_waiting_requests
)
# Check if the scheduling constraints are satisfied.
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens
=
sum
(
num_scheduled_tokens
.
values
())
total_num_scheduled_tokens
=
sum
(
num_scheduled_tokens
.
values
())
...
@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
...
@@ -425,6 +434,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask
=
grammar_bitmask
,
grammar_bitmask
=
grammar_bitmask
,
)
)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for
req_id
,
num_scheduled_token
in
num_scheduled_tokens
.
items
():
self
.
requests
[
req_id
].
num_computed_tokens
+=
num_scheduled_token
self
.
finished_req_ids
=
set
()
self
.
finished_req_ids
=
set
()
return
scheduler_output
return
scheduler_output
...
@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
...
@@ -479,9 +500,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the
limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input.
decoder tokens up to just before the unschedulable encoder input.
"""
"""
if
not
request
.
has_encoder_inputs
():
return
[],
num_new_tokens
,
encoder_budget
encoder_inputs_to_schedule
:
list
[
int
]
=
[]
encoder_inputs_to_schedule
:
list
[
int
]
=
[]
mm_positions
=
request
.
mm_positions
mm_positions
=
request
.
mm_positions
assert
mm_positions
is
not
None
assert
mm_positions
is
not
None
...
@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
...
@@ -539,6 +557,7 @@ class Scheduler(SchedulerInterface):
new_running
:
list
[
Request
]
=
[]
new_running
:
list
[
Request
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# loop can be a performance bottleneck. We should do our best to avoid
...
@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
...
@@ -553,36 +572,32 @@ class Scheduler(SchedulerInterface):
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
req_id
not
in
scheduler_output
.
scheduled_spec_decode_tokens
:
# When the request's num_computed_tokens catches up
scheduled_spec_token_ids
=
(
# its num_tokens, the request generates output tokens.
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
))
# Otherwise, we ignore the sampler output for the request.
if
scheduled_spec_token_ids
:
request
.
num_computed_tokens
+=
num_tokens_scheduled
# num_computed_tokens represents the number of tokens
assert
request
.
num_computed_tokens
<=
request
.
num_tokens
else
:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# processed in the current step, considering scheduled
# tokens and rejections.
# tokens and rejections. If some tokens are rejected,
# It is calculated as:
# num_computed_tokens is decreased by the number of rejected
# num_computed_tokens_step = num_scheduled_tokens -
# tokens, where is given by:
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids
=
(
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
scheduler_output
.
scheduled_spec_decode
_token
s
[
req
_id
]
)
len
(
generated
_token_id
s
)
)
request
.
num_computed_tokens
-=
num_tokens_rejected
num_computed_tokens_step
=
num_scheduled_tokens
[
req_id
]
-
(
spec_decoding_stats
=
self
.
make_spec_decoding_stats
(
len
(
scheduled_spec_token_ids
)
+
1
-
spec_decoding_stats
,
len
(
generated
_token_ids
)
)
num_draft_tokens
=
len
(
scheduled_spec
_token_ids
)
,
request
.
num_computed_tokens
+=
num_computed_tokens_step
num_accepted_tokens
=
len
(
generated_token_ids
)
-
1
)
cached_encoder_input_ids
=
(
cached_encoder_input_ids
=
(
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
# OPTIMIZATION: Avoid list(set) if the set is empty.
# OPTIMIZATION: Avoid list(set) if the set is empty.
if
cached_encoder_input_ids
:
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
for
input_id
in
list
(
cached_encoder_input_ids
):
start_pos
=
request
.
mm_positions
[
input_id
][
"offset"
]
mm_positions
=
request
.
mm_positions
[
input_id
]
num_tokens
=
request
.
mm_positions
[
input_id
][
"length"
]
start_pos
=
mm_positions
[
"offset"
]
num_tokens
=
mm_positions
[
"length"
]
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# The encoder output is already processed and stored
# in the decoder's KV cache.
# in the decoder's KV cache.
...
@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface):
...
@@ -595,35 +610,34 @@ class Scheduler(SchedulerInterface):
stopped
=
False
stopped
=
False
new_logprobs
=
None
new_logprobs
=
None
new_token_ids
:
list
[
int
]
=
[]
new_token_ids
=
generated_token_ids
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
# Append generated tokens and check for stop. Note that if
for
output_token_id
in
generated_token_ids
:
# a request is still being prefilled, we expect the model runner
request
.
append_output_token_ids
(
output_token_id
)
# to return empty token ids for the request.
new_token_ids
.
append
(
output_token_id
)
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
request
.
append_output_token_ids
(
output_token_id
)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
# Check for stop and update request state.
stopped
=
check_stop
(
request
,
self
.
max_model_len
)
# This must be called before we make the EngineCoreOutput.
if
stopped
:
stopped
=
check_stop
(
request
,
self
.
max_model_len
)
self
.
_free_request
(
request
)
if
stopped
:
break
self
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
# Extract sample logprobs if needed.
# Extract sample logprobs if needed.
if
request
.
sampling_params
.
logprobs
is
not
None
:
if
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
assert
logprobs
is
not
None
# NOTE: once we support N tokens per step (spec decode),
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# the outer lists can be of length > 1.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
new_token_ids
and
request
.
use_structured_output
:
if
new_token_ids
and
request
.
use_structured_output
:
# NOTE: structured_output_request
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
# check above, so safe to ignore type warning
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
.
request_id
,
req_id
,
new_token_ids
)
new_token_ids
,
)
# Get prompt logprobs for this request.
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
...
@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
...
@@ -642,15 +656,21 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs.
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
assert
not
prompt_logprobs_tensors
self
.
scheduled_req_ids
.
remove
(
req
uest
.
request
_id
)
self
.
scheduled_req_ids
.
remove
(
req_id
)
if
not
stopped
:
if
not
stopped
:
new_running
.
append
(
request
)
new_running
.
append
(
request
)
self
.
running
=
new_running
self
.
running
=
new_running
return
EngineCoreOutputs
(
engine_core_outputs
=
EngineCoreOutputs
(
outputs
=
outputs
,
outputs
=
outputs
,
scheduler_stats
=
self
.
make_stats
(),
scheduler_stats
=
self
.
make_stats
(
spec_decoding_stats
),
)
)
if
self
.
include_finished_set
:
#TODO currently sending duplicates here, improve this
engine_core_outputs
.
finished_requests
=
(
scheduler_output
.
finished_req_ids
|
self
.
finished_req_ids
)
return
engine_core_outputs
def
add_request
(
self
,
request
:
Request
)
->
None
:
def
add_request
(
self
,
request
:
Request
)
->
None
:
self
.
waiting
.
append
(
request
)
self
.
waiting
.
append
(
request
)
...
@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
...
@@ -710,7 +730,10 @@ class Scheduler(SchedulerInterface):
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
def
make_stats
(
self
)
->
Optional
[
SchedulerStats
]:
def
make_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
)
->
Optional
[
SchedulerStats
]:
if
not
self
.
log_stats
:
if
not
self
.
log_stats
:
return
None
return
None
return
SchedulerStats
(
return
SchedulerStats
(
...
@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
...
@@ -718,4 +741,19 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs
=
len
(
self
.
waiting
),
num_waiting_reqs
=
len
(
self
.
waiting
),
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
(),
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
(),
spec_decoding_stats
=
spec_decoding_stats
,
)
)
def
make_spec_decoding_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
],
num_draft_tokens
:
int
,
num_accepted_tokens
:
int
,
)
->
Optional
[
SpecDecodingStats
]:
if
not
self
.
log_stats
:
return
None
if
spec_decoding_stats
is
None
:
spec_decoding_stats
=
SpecDecodingStats
()
spec_decoding_stats
.
observe
(
num_draft_tokens
=
num_draft_tokens
,
num_accepted_tokens
=
num_accepted_tokens
)
return
spec_decoding_stats
vllm/v1/core/specialized_manager.py
0 → 100644
View file @
675ba75f
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
vllm.utils
import
cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
SlidingWindowSpec
)
class
SpecializedManager
(
ABC
):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
"""
def
__init__
(
self
,
kv_cache_spec
:
KVCacheSpec
,
block_pool
:
BlockPool
,
)
->
None
:
"""
Initializes the SpecializedManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
"""
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
Args:
block_hashes: The block hashes of the request.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
"""
raise
NotImplementedError
@
abstractmethod
def
remove_skipped_blocks
(
self
,
blocks
:
list
[
KVCacheBlock
],
num_computed_tokens
:
int
)
->
list
[
KVCacheBlock
]:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Args:
blocks: The list of blocks to be updated.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise
NotImplementedError
class
FullAttentionManager
(
SpecializedManager
):
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
computed_blocks
:
list
[
KVCacheBlock
]
=
[]
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
else
:
break
return
computed_blocks
def
remove_skipped_blocks
(
self
,
blocks
:
list
[
KVCacheBlock
],
num_computed_tokens
:
int
)
->
list
[
KVCacheBlock
]:
# No need to remove blocks for full attention.
return
[]
class
SlidingWindowManager
(
SpecializedManager
):
def
__init__
(
self
,
kv_cache_spec
:
SlidingWindowSpec
,
block_pool
:
BlockPool
):
super
().
__init__
(
kv_cache_spec
,
block_pool
)
self
.
sliding_window
=
kv_cache_spec
.
sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self
.
sliding_window_contiguous_blocks
=
cdiv
(
(
kv_cache_spec
.
sliding_window
-
1
),
self
.
block_size
)
self
.
_null_block
=
block_pool
.
null_block
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHashType
])
->
list
[
KVCacheBlock
]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks
=
[
self
.
_null_block
]
*
len
(
block_hashes
)
num_contiguous_blocks
=
0
# Search from right to left and early stop when a match is found.
for
i
in
range
(
len
(
block_hashes
)
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hashes
[
i
]):
computed_blocks
[
i
]
=
cached_block
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
>=
self
.
sliding_window_contiguous_blocks
):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del
computed_blocks
[
i
+
num_contiguous_blocks
:]
return
computed_blocks
else
:
num_contiguous_blocks
=
0
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del
computed_blocks
[
num_contiguous_blocks
:]
return
computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                          num_computed_tokens: int) -> list[KVCacheBlock]:
    """Replace blocks that fell out of the sliding window with null blocks.

    Args:
        blocks: The request's blocks; modified in place.
        num_computed_tokens: Number of tokens computed so far.

    Returns:
        The blocks that were swapped out, ordered from the highest index
        to the lowest.
    """
    # Blocks lying entirely before this token are no longer in the sliding
    # window and are skipped during the attention computation.
    first_token_in_window = num_computed_tokens - self.sliding_window + 1
    first_block_in_window = first_token_in_window // self.block_size

    null_block = self._null_block
    freed: list[KVCacheBlock] = []
    for idx in reversed(range(first_block_in_window)):
        if blocks[idx] == null_block:
            # Earlier calls already nulled this block and, by construction,
            # every block in front of it, so the scan can stop here.
            break
        freed.append(blocks[idx])
        blocks[idx] = null_block
    return freed
# Registry mapping each KV-cache spec class to the manager class that
# get_specialized_manager() instantiates for it. The lookup there uses the
# exact type of the spec, so every supported spec class needs its own entry.
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
                            block_pool: BlockPool) -> SpecializedManager:
    """Build the manager registered for the given KV-cache spec.

    Args:
        kv_cache_spec: Spec whose exact type selects the manager class.
        block_pool: Block pool handed to the manager's constructor.

    Returns:
        A new manager instance for `kv_cache_spec`.

    Raises:
        KeyError: If the spec's type has no entry in `spec_manager_map`.
    """
    spec_type = type(kv_cache_spec)
    return spec_manager_map[spec_type](kv_cache_spec, block_pool)
vllm/v1/engine/__init__.py
View file @
675ba75f
...
@@ -128,12 +128,18 @@ class EngineCoreOutputs(
...
@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact,
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
# e.g. columnwise layout
engine_index
:
int
=
0
# [num_reqs]
# [num_reqs]
outputs
:
list
[
EngineCoreOutput
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
scheduler_stats
:
Optional
[
SchedulerStats
]
=
None
scheduler_stats
:
Optional
[
SchedulerStats
]
=
None
timestamp
:
float
=
0.0
timestamp
:
float
=
0.0
utility_output
:
Optional
[
UtilityOutput
]
=
None
utility_output
:
Optional
[
UtilityOutput
]
=
None
finished_requests
:
Optional
[
set
[
str
]]
=
None
# In DP case, used to signal that the engine is paused.
engine_paused
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
timestamp
==
0.0
:
if
self
.
timestamp
==
0.0
:
...
@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
...
@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
"""
"""
ADD
=
b
'
\x00
'
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
ABORT
=
b
'
\x01
'
UTILITY
=
b
'
\x02
'
START_DP
=
b
'
\x02
'
UTILITY
=
b
'
\x03
'
vllm/v1/engine/async_llm.py
View file @
675ba75f
...
@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
...
@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.envs
import
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from
vllm.envs
import
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
...
@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
input
_registry
:
Input
Registry
=
INPUT
_REGISTRY
,
mm
_registry
:
MultiModal
Registry
=
MULTIMODAL
_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
use_cached_outputs
:
bool
=
False
,
log_requests
:
bool
=
True
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
...
@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
...
@@ -66,11 +67,17 @@ class AsyncLLM(EngineClient):
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
stat_loggers
:
list
[
StatLoggerBase
]
=
[]
# Set up stat loggers; independent set for each DP rank.
self
.
stat_loggers
:
list
[
list
[
StatLoggerBase
]]
=
[]
if
self
.
log_stats
:
if
self
.
log_stats
:
if
logger
.
isEnabledFor
(
logging
.
INFO
):
for
i
in
range
(
vllm_config
.
parallel_config
.
data_parallel_size
):
self
.
stat_loggers
.
append
(
LoggingStatLogger
())
loggers
:
list
[
StatLoggerBase
]
=
[]
self
.
stat_loggers
.
append
(
PrometheusStatLogger
(
vllm_config
))
if
logger
.
isEnabledFor
(
logging
.
INFO
):
loggers
.
append
(
LoggingStatLogger
(
engine_index
=
i
))
loggers
.
append
(
PrometheusStatLogger
(
vllm_config
,
engine_index
=
i
))
self
.
stat_loggers
.
append
(
loggers
)
# Tokenizer (+ ensure liveness if running in another process).
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
self
.
tokenizer
=
init_tokenizer_from_configs
(
...
@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
...
@@ -84,7 +91,7 @@ class AsyncLLM(EngineClient):
self
.
processor
=
Processor
(
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
tokenizer
=
self
.
tokenizer
,
input
_registry
=
input
_registry
,
mm
_registry
=
mm
_registry
,
)
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
...
@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
...
@@ -329,6 +336,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
# background thread once Prometheus overhead is non-trivial.
self
.
_record_stats
(
self
.
_record_stats
(
engine_index
=
outputs
.
engine_index
,
scheduler_stats
=
outputs
.
scheduler_stats
,
scheduler_stats
=
outputs
.
scheduler_stats
,
iteration_stats
=
iteration_stats
,
iteration_stats
=
iteration_stats
,
)
)
...
@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
...
@@ -350,12 +358,13 @@ class AsyncLLM(EngineClient):
self
,
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
iteration_stats
:
Optional
[
IterationStats
],
engine_index
:
int
=
0
,
):
):
if
not
self
.
log_stats
:
if
not
self
.
log_stats
:
return
return
assert
scheduler_stats
is
not
None
assert
scheduler_stats
is
not
None
for
stat_logger
in
self
.
stat_loggers
:
for
stat_logger
in
self
.
stat_loggers
[
engine_index
]
:
stat_logger
.
record
(
scheduler_stats
=
scheduler_stats
,
stat_logger
.
record
(
scheduler_stats
=
scheduler_stats
,
iteration_stats
=
iteration_stats
)
iteration_stats
=
iteration_stats
)
...
@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient):
...
@@ -393,8 +402,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs
=
None
,
scheduler_outputs
=
None
,
model_output
=
None
,
model_output
=
None
,
)
->
None
:
)
->
None
:
for
stat_logger
in
self
.
stat_loggers
:
for
loggers
in
self
.
stat_loggers
:
stat_logger
.
log
()
for
stat_logger
in
loggers
:
stat_logger
.
log
()
async
def
check_health
(
self
)
->
None
:
async
def
check_health
(
self
)
->
None
:
logger
.
debug
(
"Called check_health."
)
logger
.
debug
(
"Called check_health."
)
...
@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
...
@@ -414,8 +424,8 @@ class AsyncLLM(EngineClient):
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
await
self
.
engine_core
.
sleep_async
(
level
)
await
self
.
engine_core
.
sleep_async
(
level
)
async
def
wake_up
(
self
)
->
None
:
async
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
await
self
.
engine_core
.
wake_up_async
()
await
self
.
engine_core
.
wake_up_async
(
tags
)
async
def
is_sleeping
(
self
)
->
bool
:
async
def
is_sleeping
(
self
)
->
bool
:
return
await
self
.
engine_core
.
is_sleeping_async
()
return
await
self
.
engine_core
.
is_sleeping_async
()
...
...
vllm/v1/engine/core.py
View file @
675ba75f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
queue
import
queue
import
signal
import
signal
import
sys
import
threading
import
threading
import
time
import
time
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
multiprocessing.connection
import
Connection
from
logging
import
DEBUG
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
msgspec
import
msgspec
import
psutil
import
psutil
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.distributed
import
stateless_destroy_torch_distributed_process_group
from
vllm.executor.multiproc_worker_utils
import
_add_prefix
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
...
@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
...
@@ -23,12 +26,14 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx
)
zmq_socket_ctx
)
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
@@ -39,6 +44,8 @@ logger = init_logger(__name__)
...
@@ -39,6 +44,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S
=
2.5
POLLING_TIMEOUT_S
=
2.5
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCore
:
class
EngineCore
:
"""Inner loop of vLLM's Engine."""
"""Inner loop of vLLM's Engine."""
...
@@ -60,8 +67,9 @@ class EngineCore:
...
@@ -60,8 +67,9 @@ class EngineCore:
self
.
model_executor
=
executor_class
(
vllm_config
)
self
.
model_executor
=
executor_class
(
vllm_config
)
# Setup KV Caches and update CacheConfig after profiling.
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks
,
num_cpu_blocks
=
self
.
_initialize_kv_caches
(
num_gpu_blocks
,
num_cpu_blocks
,
kv_cache_config
=
\
vllm_config
)
self
.
_initialize_kv_caches
(
vllm_config
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
...
@@ -84,14 +92,16 @@ class EngineCore:
...
@@ -84,14 +92,16 @@ class EngineCore:
"compatibility may not be maintained."
,
"compatibility may not be maintained."
,
vllm_config
.
scheduler_config
.
scheduler_cls
)
vllm_config
.
scheduler_config
.
scheduler_cls
)
self
.
scheduler
=
Scheduler
(
self
.
scheduler
:
SchedulerInterface
=
Scheduler
(
scheduler_config
=
vllm_config
.
scheduler_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
model_config
=
vllm_config
.
model_config
,
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
lora_config
=
vllm_config
.
lora_config
,
speculative_config
=
vllm_config
.
speculative_config
,
kv_cache_config
=
kv_cache_config
,
log_stats
=
self
.
log_stats
,
structured_output_manager
=
self
.
structured_output_manager
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
log_stats
=
self
.
log_stats
,
)
)
# Setup MM Input Mapper.
# Setup MM Input Mapper.
...
@@ -110,8 +120,8 @@ class EngineCore:
...
@@ -110,8 +120,8 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
]:
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
start
=
time
.
time
()
start
=
time
.
time
()
# Get all kv cache needed by the model
# Get all kv cache needed by the model
...
@@ -136,13 +146,14 @@ class EngineCore:
...
@@ -136,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs
(
kv_cache_configs
)
unify_kv_cache_configs
(
kv_cache_configs
)
# All workers have the same kv_cache_config except layer names, so use
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to
get the number of blocks
.
# an arbitrary one to
initialize the scheduler
.
assert
all
([
assert
all
([
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
for
cfg
in
kv_cache_configs
for
cfg
in
kv_cache_configs
])
])
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_cpu_blocks
=
0
num_cpu_blocks
=
0
scheduler_kv_cache_config
=
kv_cache_configs
[
0
]
# Initialize kv cache and warmup the execution
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
...
@@ -150,7 +161,7 @@ class EngineCore:
...
@@ -150,7 +161,7 @@ class EngineCore:
elapsed
=
time
.
time
()
-
start
elapsed
=
time
.
time
()
-
start
logger
.
info
((
"init engine (profile, create kv cache, "
logger
.
info
((
"init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"
),
elapsed
)
"warmup model) took %.2f seconds"
),
elapsed
)
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
,
scheduler_kv_cache_config
def
add_request
(
self
,
request
:
EngineCoreRequest
):
def
add_request
(
self
,
request
:
EngineCoreRequest
):
"""Add request to the scheduler."""
"""Add request to the scheduler."""
...
@@ -253,8 +264,8 @@ class EngineCore:
...
@@ -253,8 +264,8 @@ class EngineCore:
def
sleep
(
self
,
level
:
int
=
1
):
def
sleep
(
self
,
level
:
int
=
1
):
self
.
model_executor
.
sleep
(
level
)
self
.
model_executor
.
sleep
(
level
)
def
wake_up
(
self
):
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
):
self
.
model_executor
.
wake_up
()
self
.
model_executor
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
return
self
.
model_executor
.
is_sleeping
return
self
.
model_executor
.
is_sleeping
...
@@ -274,6 +285,24 @@ class EngineCore:
...
@@ -274,6 +285,24 @@ class EngineCore:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Forward a sharded-state save request to the model executor.

    Args:
        path: Passed through as the executor's `path` keyword.
        pattern: Passed through as the executor's `pattern` keyword.
        max_size: Passed through as the executor's `max_size` keyword.
    """
    executor = self.model_executor
    executor.save_sharded_state(
        path=path,
        pattern=pattern,
        max_size=max_size,
    )
def collective_rpc(
    self,
    method: Union[str, Callable[..., _R]],
    timeout: Optional[float] = None,
    args: tuple = (),
    kwargs: Optional[dict[str, Any]] = None,
) -> list[_R]:
    """Forward a collective RPC to the model executor.

    All four arguments are handed to the executor's `collective_rpc`
    positionally and unchanged.

    Returns:
        Whatever the executor's `collective_rpc` returns.
    """
    executor = self.model_executor
    results = executor.collective_rpc(method, timeout, args, kwargs)
    return results
class
EngineCoreProc
(
EngineCore
):
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
"""ZMQ-wrapper for running EngineCore in background process."""
...
@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
...
@@ -282,10 +311,10 @@ class EngineCoreProc(EngineCore):
self
,
self
,
input_path
:
str
,
input_path
:
str
,
output_path
:
str
,
output_path
:
str
,
ready_pipe
:
Connection
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
engine_index
:
int
=
0
,
):
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
...
@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
...
@@ -301,14 +330,20 @@ class EngineCoreProc(EngineCore):
args
=
(
input_path
,
),
args
=
(
input_path
,
),
daemon
=
True
).
start
()
daemon
=
True
).
start
()
threading
.
Thread
(
target
=
self
.
process_output_socket
,
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_path
,
),
args
=
(
output_path
,
engine_index
),
daemon
=
True
).
start
()
daemon
=
True
).
start
()
# Send Readiness signal to EngineClient.
self
.
global_unfinished_reqs
=
False
ready_pipe
.
send
({
"status"
:
"READY"
})
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
@
staticmethod
@
staticmethod
def
run_engine_core
(
*
args
,
**
kwargs
):
def
run_engine_core
(
*
args
,
dp_rank
:
int
=
0
,
local_dp_rank
:
int
=
0
,
ready_pipe
,
**
kwargs
):
"""Launch EngineCore busy loop in background process."""
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
# Signal handler used for graceful termination.
...
@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
...
@@ -330,9 +365,21 @@ class EngineCoreProc(EngineCore):
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
parent_process
=
psutil
.
Process
().
parent
()
parent_process
=
psutil
.
Process
().
parent
()
engine_core
=
None
engine_core
:
Optional
[
EngineCoreProc
]
=
None
try
:
try
:
engine_core
=
EngineCoreProc
(
*
args
,
**
kwargs
)
parallel_config
:
ParallelConfig
=
kwargs
[
"vllm_config"
].
parallel_config
if
parallel_config
.
data_parallel_size
>
1
:
# Set data parallel rank for this engine process.
parallel_config
.
data_parallel_rank
=
dp_rank
parallel_config
.
data_parallel_rank_local
=
local_dp_rank
engine_core
=
DPEngineCoreProc
(
*
args
,
**
kwargs
)
else
:
engine_core
=
EngineCoreProc
(
*
args
,
**
kwargs
)
# Send Readiness signal to EngineClient.
ready_pipe
.
send
({
"status"
:
"READY"
})
engine_core
.
run_busy_loop
()
engine_core
.
run_busy_loop
()
except
SystemExit
:
except
SystemExit
:
...
@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore):
...
@@ -350,28 +397,44 @@ class EngineCoreProc(EngineCore):
def
run_busy_loop
(
self
):
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
"""Core busy loop of the EngineCore."""
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
# Loop until process is sent a SIGINT or SIGTERM
# Loop until process is sent a SIGINT or SIGTERM
while
True
:
while
True
:
# 1) Poll the input queue until there is work to do.
# 1) Poll the input queue until there is work to do.
while
not
self
.
scheduler
.
has_requests
():
self
.
_process_input_queue
()
logger
.
debug
(
"EngineCore busy loop waiting."
)
# 2) Step the engine core and return the outputs.
req
=
self
.
input_queue
.
get
()
self
.
_process_engine_step
()
self
.
_handle_client_request
(
*
req
)
def
_process_input_queue
(
self
):
# 2) Handle any new client requests.
"""Exits when an engine step needs to be performed."""
while
not
self
.
input_queue
.
empty
():
req
=
self
.
input_queue
.
get_nowait
()
waited
=
False
self
.
_handle_client_request
(
*
req
)
while
not
self
.
global_unfinished_reqs
and
not
(
self
.
scheduler
.
has_requests
()):
# 3) Step the engine core.
if
logger
.
isEnabledFor
(
DEBUG
)
and
self
.
input_queue
.
empty
():
outputs
=
step_fn
()
logger
.
debug
(
"EngineCore waiting for work."
)
waited
=
True
# 4) Put EngineCoreOutputs into the output queue.
req
=
self
.
input_queue
.
get
()
if
outputs
is
not
None
:
self
.
_handle_client_request
(
*
req
)
self
.
output_queue
.
put_nowait
(
outputs
)
if
waited
:
logger
.
debug
(
"EngineCore loop active - local unfinished: %s, finished: %s."
,
self
.
scheduler
.
has_unfinished_requests
(),
self
.
scheduler
.
has_finished_requests
())
# Handle any more client requests.
while
not
self
.
input_queue
.
empty
():
req
=
self
.
input_queue
.
get_nowait
()
self
.
_handle_client_request
(
*
req
)
def
_process_engine_step
(
self
):
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs
=
self
.
step_fn
()
# Put EngineCoreOutputs into the output queue.
if
outputs
is
not
None
:
self
.
output_queue
.
put_nowait
(
outputs
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
request
:
Any
)
->
None
:
...
@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
...
@@ -381,6 +444,10 @@ class EngineCoreProc(EngineCore):
self
.
add_request
(
request
)
self
.
add_request
(
request
)
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
START_DP
:
if
not
self
.
global_unfinished_reqs
:
logger
.
debug
(
"EngineCore starting idle loop."
)
self
.
global_unfinished_reqs
=
True
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
output
=
UtilityOutput
(
call_id
)
...
@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
...
@@ -431,7 +498,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
def
process_output_socket
(
self
,
output_path
:
str
):
def
process_output_socket
(
self
,
output_path
:
str
,
engine_index
:
int
):
"""Output socket IO thread."""
"""Output socket IO thread."""
# Msgpack serialization encoding.
# Msgpack serialization encoding.
...
@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
...
@@ -442,5 +509,114 @@ class EngineCoreProc(EngineCore):
with
zmq_socket_ctx
(
output_path
,
zmq
.
constants
.
PUSH
)
as
socket
:
with
zmq_socket_ctx
(
output_path
,
zmq
.
constants
.
PUSH
)
as
socket
:
while
True
:
while
True
:
outputs
=
self
.
output_queue
.
get
()
outputs
=
self
.
output_queue
.
get
()
outputs
.
engine_index
=
engine_index
encoder
.
encode_into
(
outputs
,
buffer
)
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send_multipart
((
buffer
,
),
copy
=
False
)
socket
.
send
(
buffer
,
copy
=
False
)
# Shared EngineCoreOutputs instance carrying only engine_paused=True.
# DPEngineCoreProc.run_busy_loop() puts it on the output queue when the
# global "unfinished requests" flag turns false, to tell the client that
# the engine loop is pausing.
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Besides what EngineCoreProc does, this class creates a data parallel
    process group, restricts the visible CUDA devices for its rank, and
    periodically synchronizes the "any unfinished requests" state with
    its DP peers inside the busy loop.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Set up the DP environment, then initialize the engine core.

        The data parallel size and ranks are read from
        `vllm_config.parallel_config`. The data parallel rank is forwarded
        to `EngineCoreProc.__init__` as the positional argument that
        follows `log_stats`.
        """
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        # This class is only meant for real data parallel setups, and the
        # local rank can never exceed the global rank.
        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = vllm_config.parallel_config.tensor_parallel_size
            # Expose only this local rank's slice of tp_size devices:
            # indices [local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size).
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(i))
                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
                               tp_size))

        # Created before super().__init__() runs; shutdown() looks it up
        # with getattr() and destroys it.
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

    def shutdown(self):
        """Shut down the engine core and destroy the DP process group."""
        super().shutdown()
        # getattr() tolerates a missing attribute - presumably for the case
        # where __init__ failed before dp_group was assigned.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()

            if local_unfinished_reqs:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                local_unfinished_reqs = (
                    self.scheduler.has_unfinished_requests())
            else:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.global_unfinished_reqs:
                # Notify client that we are pausing the loop.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether any DP rank still has unfinished requests.

        Only every 16th call actually asks the DP group; all other calls
        return True without synchronizing.

        Args:
            local_unfinished: Whether this rank has unfinished requests.

        Returns:
            True on non-synchronizing calls, otherwise the result of
            `ParallelConfig.has_unfinished_dp` for the DP group.
        """
        # Optimization - only perform finish-sync all-reduce every 16 steps.
        self.counter += 1
        if self.counter != 16:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
vllm/v1/engine/core_client.py
View file @
675ba75f
...
@@ -8,10 +8,11 @@ import threading
...
@@ -8,10 +8,11 @@ import threading
import
uuid
import
uuid
import
weakref
import
weakref
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
...
@@ -32,6 +33,8 @@ logger = init_logger(__name__)
...
@@ -32,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
class
EngineCoreClient
(
ABC
):
class
EngineCoreClient
(
ABC
):
"""
"""
...
@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
...
@@ -60,6 +63,9 @@ class EngineCoreClient(ABC):
"is not currently supported."
)
"is not currently supported."
)
if
multiprocess_mode
and
asyncio_mode
:
if
multiprocess_mode
and
asyncio_mode
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
return
DPAsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stats
)
if
multiprocess_mode
and
not
asyncio_mode
:
if
multiprocess_mode
and
not
asyncio_mode
:
...
@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
...
@@ -86,7 +92,7 @@ class EngineCoreClient(ABC):
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
...
@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
...
@@ -113,6 +119,19 @@ class EngineCoreClient(ABC):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
def save_sharded_state(self,
                       path: str,
                       pattern: Optional[str] = None,
                       max_size: Optional[int] = None) -> None:
    """Interface stub: always raises NotImplementedError here."""
    raise NotImplementedError
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Interface stub: always raises NotImplementedError here."""
    raise NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
...
@@ -128,7 +147,7 @@ class EngineCoreClient(ABC):
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
wake_up_async
(
self
)
->
None
:
async
def
wake_up_async
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
is_sleeping_async
(
self
)
->
bool
:
async
def
is_sleeping_async
(
self
)
->
bool
:
...
@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
...
@@ -149,6 +168,20 @@ class EngineCoreClient(ABC):
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
async def save_sharded_state_async(self,
                                   path: str,
                                   pattern: Optional[str] = None,
                                   max_size: Optional[int] = None) -> None:
    """Async interface stub: always raises NotImplementedError here."""
    raise NotImplementedError
async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Async interface stub: always raises NotImplementedError here."""
    raise NotImplementedError
class
InprocClient
(
EngineCoreClient
):
class
InprocClient
(
EngineCoreClient
):
"""
"""
...
@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
...
@@ -185,8 +218,8 @@ class InprocClient(EngineCoreClient):
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine_core
.
sleep
(
level
)
self
.
engine_core
.
sleep
(
level
)
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
self
.
engine_core
.
wake_up
()
self
.
engine_core
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
return
self
.
engine_core
.
is_sleeping
()
return
self
.
engine_core
.
is_sleeping
()
...
@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
...
@@ -206,29 +239,88 @@ class InprocClient(EngineCoreClient):
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
self
.
engine_core
.
save_sharded_state
(
path
,
pattern
,
max_size
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
class
CoreEngine
:
"""One per data parallel rank."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
ctx
:
Union
[
zmq
.
Context
,
zmq
.
asyncio
.
Context
],
output_path
:
str
,
index
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
# Paths and sockets for IPC.
input_path
=
get_open_zmq_ipc_path
()
self
.
input_socket
=
make_zmq_socket
(
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
try
:
# Start EngineCore in background process.
self
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
output_path
=
output_path
,
process_name
=
f
"EngineCore_
{
index
}
"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"dp_rank"
:
index
,
"local_dp_rank"
:
local_dp_rank
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
})
self
.
num_reqs_in_flight
=
0
finally
:
if
not
hasattr
(
self
,
"num_reqs_in_flight"
):
# Ensure socket is closed if process fails to start.
self
.
close
()
def
send_multipart
(
self
,
msg_parts
:
Sequence
):
return
self
.
input_socket
.
send_multipart
(
msg_parts
,
copy
=
False
)
def
close
(
self
):
if
proc_handle
:
=
getattr
(
self
,
"proc_handle"
,
None
):
proc_handle
.
shutdown
()
if
socket
:
=
getattr
(
self
,
"input_socket"
,
None
):
socket
.
close
(
linger
=
0
)
@
dataclass
@
dataclass
class
BackgroundResources
:
class
BackgroundResources
:
"""Used as a finalizer for clean shutdown, avoiding
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
circular reference back to the client object."""
ctx
:
zmq
.
Context
ctx
:
Union
[
zmq
.
Context
]
core_engines
:
list
[
CoreEngine
]
=
field
(
default_factory
=
list
)
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
proc_handle
:
Optional
[
BackgroundProcHandle
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
def
__call__
(
self
):
def
__call__
(
self
):
"""Clean up background resources."""
"""Clean up background resources."""
if
self
.
proc_handle
is
not
None
:
for
core_engine
in
self
.
core_engines
:
self
.
proc_handle
.
shutdown
()
core_engine
.
close
()
# ZMQ context termination can hang if the sockets
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
# aren't explicitly closed first.
if
self
.
output_socket
is
not
None
:
if
self
.
output_socket
is
not
None
:
self
.
output_socket
.
close
(
linger
=
0
)
self
.
output_socket
.
close
(
linger
=
0
)
if
self
.
input_socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
if
self
.
shutdown_path
is
not
None
:
if
self
.
shutdown_path
is
not
None
:
# We must ensure that the sync output socket is
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
# closed cleanly in its own thread.
...
@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
...
@@ -284,7 +376,7 @@ class MPClient(EngineCoreClient):
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
# ZMQ setup.
# ZMQ setup.
sync_ctx
=
zmq
.
Context
()
sync_ctx
=
zmq
.
Context
(
io_threads
=
2
)
self
.
ctx
=
zmq
.
asyncio
.
Context
(
sync_ctx
)
if
asyncio_mode
else
sync_ctx
self
.
ctx
=
zmq
.
asyncio
.
Context
(
sync_ctx
)
if
asyncio_mode
else
sync_ctx
# This will ensure resources created so far are closed
# This will ensure resources created so far are closed
...
@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
...
@@ -293,28 +385,38 @@ class MPClient(EngineCoreClient):
self
.
resources
=
BackgroundResources
(
ctx
=
sync_ctx
)
self
.
resources
=
BackgroundResources
(
ctx
=
sync_ctx
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
# Paths for IPC.
# Paths
and sockets
for IPC.
self
.
output_path
=
get_open_zmq_ipc_path
()
self
.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
# Start EngineCore in background process.
new_core_engine
=
lambda
index
,
local_dp_rank
=
None
:
CoreEngine
(
self
.
resources
.
proc_handle
=
BackgroundProcHandle
(
vllm_config
,
executor_class
,
log_stats
,
self
.
ctx
,
self
.
output_path
,
input_path
=
input_path
,
index
,
local_dp_rank
)
output_path
=
self
.
output_path
,
process_name
=
"EngineCore"
,
# Start engine core process(es).
target_fn
=
EngineCoreProc
.
run_engine_core
,
self
.
_init_core_engines
(
vllm_config
,
new_core_engine
,
process_kwargs
=
{
self
.
resources
.
core_engines
)
"vllm_config"
:
vllm_config
,
"executor_class"
:
executor_class
,
# Wait for engine core process(es) to start.
"log_stats"
:
log_stats
,
for
engine
in
self
.
resources
.
core_engines
:
})
engine
.
proc_handle
.
wait_for_startup
()
# Create input socket.
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
self
.
input_socket
=
self
.
resources
.
input_socket
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Default case - single core engine.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
core_engine
=
new_core_engine
(
dp_rank
,
local_dp_rank
if
local_dp_rank
is
not
None
else
dp_rank
)
core_engines
.
append
(
core_engine
)
self
.
core_engine
=
core_engine
def
shutdown
(
self
):
def
shutdown
(
self
):
self
.
_finalizer
()
self
.
_finalizer
()
...
@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
...
@@ -356,9 +458,9 @@ class SyncMPClient(MPClient):
def
process_outputs_socket
():
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
.
bind
(
shutdown_path
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
try
:
shutdown_socket
.
bind
(
shutdown_path
)
poller
=
zmq
.
Poller
()
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
shutdown_socket
)
poller
.
register
(
out_socket
)
poller
.
register
(
out_socket
)
...
@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
...
@@ -370,7 +472,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
# shutdown signal, exit thread.
break
break
(
frame
,
)
=
out_socket
.
recv
_multipart
(
copy
=
False
)
frame
=
out_socket
.
recv
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
...
@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
...
@@ -391,18 +493,15 @@ class SyncMPClient(MPClient):
def
get_output
(
self
)
->
EngineCoreOutputs
:
def
get_output
(
self
)
->
EngineCoreOutputs
:
return
self
.
outputs_queue
.
get
()
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
request
:
Any
)
->
None
:
# (RequestType, SerializedRequest)
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
core_engine
.
send_multipart
(
msg
)
def
_
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
future
:
Future
[
Any
]
=
Future
()
future
:
Future
[
Any
]
=
Future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
(
call_id
,
method
,
args
))
...
@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
...
@@ -419,34 +518,48 @@ class SyncMPClient(MPClient):
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_
call_utility
(
"profile"
,
is_start
)
self
.
call_utility
(
"profile"
,
is_start
)
def
reset_prefix_cache
(
self
)
->
None
:
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_
call_utility
(
"reset_prefix_cache"
)
self
.
call_utility
(
"reset_prefix_cache"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
_
call_utility
(
"add_lora"
,
lora_request
)
return
self
.
call_utility
(
"add_lora"
,
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_
call_utility
(
"remove_lora"
,
lora_id
)
return
self
.
call_utility
(
"remove_lora"
,
lora_id
)
def
list_loras
(
self
)
->
set
[
int
]:
def
list_loras
(
self
)
->
set
[
int
]:
return
self
.
_
call_utility
(
"list_loras"
)
return
self
.
call_utility
(
"list_loras"
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_
call_utility
(
"pin_lora"
,
lora_id
)
return
self
.
call_utility
(
"pin_lora"
,
lora_id
)
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
_
call_utility
(
"sleep"
,
level
)
self
.
call_utility
(
"sleep"
,
level
)
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
self
.
_
call_utility
(
"wake_up"
)
self
.
call_utility
(
"wake_up"
,
tags
)
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
return
self
.
_
call_utility
(
"is_sleeping"
)
return
self
.
call_utility
(
"is_sleeping"
)
def
execute_dummy_batch
(
self
)
->
None
:
def
execute_dummy_batch
(
self
)
->
None
:
self
.
_call_utility
(
"execute_dummy_batch"
)
self
.
call_utility
(
"execute_dummy_batch"
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
call_utility
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
def
save_sharded_state
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
self
.
call_utility
(
"save_sharded_state"
,
path
,
pattern
,
max_size
)
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
...
@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
...
@@ -464,13 +577,21 @@ class AsyncMPClient(MPClient):
self
.
outputs_queue
:
Optional
[
asyncio
.
Queue
[
EngineCoreOutputs
]]
=
None
self
.
outputs_queue
:
Optional
[
asyncio
.
Queue
[
EngineCoreOutputs
]]
=
None
self
.
queue_task
:
Optional
[
asyncio
.
Task
]
=
None
self
.
queue_task
:
Optional
[
asyncio
.
Task
]
=
None
async
def
_start_output_queue_task
(
self
):
self
.
outputs_handler
:
Optional
[
Callable
[
[
AsyncMPClient
,
EngineCoreOutputs
],
Awaitable
[
None
]]]
=
None
def
_ensure_output_queue_task
(
self
):
if
self
.
outputs_queue
is
not
None
:
return
# Perform IO in separate task to parallelize as much as possible.
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
# Avoid task having direct reference back to the client.
self
.
outputs_queue
=
asyncio
.
Queue
()
self
.
outputs_queue
=
asyncio
.
Queue
()
decoder
=
self
.
decoder
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
outputs_queue
=
self
.
outputs_queue
output_handler
=
self
.
outputs_handler
_self_ref
=
weakref
.
ref
(
self
)
if
output_handler
else
None
output_path
=
self
.
output_path
output_path
=
self
.
output_path
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
zmq
.
constants
.
PULL
)
...
@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
...
@@ -483,34 +604,52 @@ class AsyncMPClient(MPClient):
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
else
:
continue
if
output_handler
is
not
None
:
assert
_self_ref
is
not
None
_self
=
_self_ref
()
if
not
_self
:
# Client has been garbage collected, abort.
return
await
output_handler
(
_self
,
outputs
)
if
outputs
.
outputs
or
outputs
.
scheduler_stats
:
outputs_queue
.
put_nowait
(
outputs
)
outputs_queue
.
put_nowait
(
outputs
)
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
(),
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
(),
name
=
"EngineCoreOutputQueueTask"
)
name
=
"EngineCoreOutputQueueTask"
)
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
if
self
.
outputs_queue
is
None
:
self
.
_ensure_output_queue_task
()
await
self
.
_start_output_queue_task
()
assert
self
.
outputs_queue
is
not
None
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
request
:
Any
)
->
None
:
await
self
.
core_engine
.
send_multipart
(
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
)))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
_ensure_output_queue_task
()
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
if
self
.
outputs_queue
is
None
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
await
self
.
_start_output_queue_task
()
return
await
self
.
_call_utility_async
(
method
,
*
args
,
engine
=
self
.
core_engine
)
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
,
engine
:
CoreEngine
,
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
await
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
(
call_id
,
method
,
args
))
self
.
encoder
.
encode
((
call_id
,
method
,
args
)))
await
engine
.
send_multipart
(
message
)
self
.
_ensure_output_queue_task
()
return
await
future
return
await
future
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
...
@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
...
@@ -524,31 +663,162 @@ class AsyncMPClient(MPClient):
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_
call_utility_async
(
"profile"
,
is_start
)
await
self
.
call_utility_async
(
"profile"
,
is_start
)
async
def
reset_prefix_cache_async
(
self
)
->
None
:
async
def
reset_prefix_cache_async
(
self
)
->
None
:
await
self
.
_
call_utility_async
(
"reset_prefix_cache"
)
await
self
.
call_utility_async
(
"reset_prefix_cache"
)
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep_async
(
self
,
level
:
int
=
1
)
->
None
:
await
self
.
_
call_utility_async
(
"sleep"
,
level
)
await
self
.
call_utility_async
(
"sleep"
,
level
)
async
def
wake_up_async
(
self
)
->
None
:
async
def
wake_up_async
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
await
self
.
_
call_utility_async
(
"wake_up"
)
await
self
.
call_utility_async
(
"wake_up"
,
tags
)
async
def
is_sleeping_async
(
self
)
->
bool
:
async
def
is_sleeping_async
(
self
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"is_sleeping"
)
return
await
self
.
call_utility_async
(
"is_sleeping"
)
async
def
execute_dummy_batch_async
(
self
)
->
None
:
async
def
execute_dummy_batch_async
(
self
)
->
None
:
await
self
.
_
call_utility_async
(
"execute_dummy_batch"
)
await
self
.
call_utility_async
(
"execute_dummy_batch"
)
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"add_lora"
,
lora_request
)
return
await
self
.
call_utility_async
(
"add_lora"
,
lora_request
)
async
def
remove_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
async
def
remove_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
_
call_utility_async
(
"remove_lora"
,
lora_id
)
return
await
self
.
call_utility_async
(
"remove_lora"
,
lora_id
)
async
def
list_loras_async
(
self
)
->
set
[
int
]:
async
def
list_loras_async
(
self
)
->
set
[
int
]:
return
await
self
.
_
call_utility_async
(
"list_loras"
)
return
await
self
.
call_utility_async
(
"list_loras"
)
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
async
def
pin_lora_async
(
self
,
lora_id
:
int
)
->
bool
:
return
await
self
.
_call_utility_async
(
"pin_lora"
,
lora_id
)
return
await
self
.
call_utility_async
(
"pin_lora"
,
lora_id
)
async
def
save_sharded_state_async
(
self
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
)
->
None
:
await
self
.
call_utility_async
(
"save_sharded_state"
,
path
,
pattern
,
max_size
)
async
def
collective_rpc_async
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
await
self
.
call_utility_async
(
"collective_rpc"
,
method
,
timeout
,
args
,
kwargs
)
class
DPAsyncMPClient
(
AsyncMPClient
):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
assert
len
(
self
.
core_engines
)
>
1
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
outputs_handler
=
DPAsyncMPClient
.
process_engine_outputs
# type: ignore[assignment]
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
new_core_engine
:
Callable
[[
int
,
Optional
[
int
]],
CoreEngine
],
core_engines
:
list
[
CoreEngine
],
)
->
None
:
# Launch a core engine for each data parallel rank.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
for
i
in
range
(
dp_size
):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines
.
append
(
new_core_engine
(
i
,
i
))
self
.
core_engines
=
core_engines
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
# Only the result from the first engine is returned.
return
(
await
asyncio
.
gather
(
*
[
self
.
_call_utility_async
(
method
,
*
args
,
engine
=
engine
)
for
engine
in
self
.
core_engines
]))[
0
]
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
if
self
.
num_engines_running
>=
len
(
self
.
core_engines
):
await
chosen_engine
.
send_multipart
(
msg
)
else
:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self
.
num_engines_running
+=
len
(
self
.
core_engines
)
await
asyncio
.
gather
(
*
[
engine
.
send_multipart
(
msg
if
engine
is
chosen_engine
else
self
.
start_dp_msg
)
for
engine
in
self
.
core_engines
])
self
.
_ensure_output_queue_task
()
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
return
min
(
self
.
core_engines
,
key
=
lambda
e
:
e
.
num_reqs_in_flight
)
@
staticmethod
async
def
process_engine_outputs
(
self
:
"DPAsyncMPClient"
,
outputs
:
EngineCoreOutputs
):
if
self
.
reqs_in_flight
:
for
req_id
in
outputs
.
finished_requests
or
():
if
engine
:
=
self
.
reqs_in_flight
.
pop
(
req_id
,
None
):
engine
.
num_reqs_in_flight
-=
1
if
outputs
.
engine_paused
:
assert
self
.
num_engines_running
>=
1
self
.
num_engines_running
-=
1
if
not
self
.
num_engines_running
and
self
.
reqs_in_flight
:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self
.
num_engines_running
=
len
(
self
.
core_engines
)
coros
=
[
engine
.
send_multipart
(
self
.
start_dp_msg
)
for
engine
in
self
.
core_engines
if
not
engine
.
num_reqs_in_flight
]
if
coros
:
await
asyncio
.
gather
(
*
coros
)
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
not
request_ids
:
return
if
len
(
request_ids
)
==
1
:
# Fast-path common case.
if
engine
:
=
self
.
reqs_in_flight
.
get
(
request_ids
[
0
]):
await
self
.
_abort_requests
(
request_ids
,
engine
)
return
by_engine
:
dict
[
CoreEngine
,
list
[
str
]]
=
{}
for
req_id
in
request_ids
:
if
engine
:
=
self
.
reqs_in_flight
.
get
(
req_id
):
by_engine
.
setdefault
(
engine
,
[]).
append
(
req_id
)
for
engine
,
req_ids
in
by_engine
.
items
():
await
self
.
_abort_requests
(
req_ids
,
engine
)
async
def
_abort_requests
(
self
,
request_ids
:
list
[
str
],
engine
:
CoreEngine
)
->
None
:
await
engine
.
send_multipart
((
EngineCoreRequestType
.
ABORT
.
value
,
self
.
encoder
.
encode
(
request_ids
)))
vllm/v1/engine/llm_engine.py
View file @
675ba75f
...
@@ -2,15 +2,16 @@
...
@@ -2,15 +2,16 @@
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
copy
import
copy
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.distributed
import
stateless_destroy_torch_distributed_process_group
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
...
@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
...
@@ -31,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLMEngine
:
class
LLMEngine
:
...
@@ -43,7 +45,6 @@ class LLMEngine:
...
@@ -43,7 +45,6 @@ class LLMEngine:
log_stats
:
bool
,
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
use_cached_outputs
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
...
@@ -60,11 +61,13 @@ class LLMEngine:
...
@@ -60,11 +61,13 @@ class LLMEngine:
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
# important: init dp group before init the engine_core
# important: init dp group before init the engine_core
self
.
parallel_config
=
vllm_config
.
parallel_config
# In the decoupled engine case this is handled in EngineCoreProc.
self
.
dp_enabled
=
self
.
parallel_config
.
data_parallel_size
>
1
# noqa
parallel_config
=
vllm_config
.
parallel_config
if
not
multiprocess_mode
and
parallel_config
.
data_parallel_size
>
1
:
self
.
dp_group
=
parallel_config
.
stateless_init_dp_group
()
else
:
self
.
dp_group
=
None
self
.
should_execute_dummy_batch
=
False
self
.
should_execute_dummy_batch
=
False
if
self
.
dp_enabled
:
self
.
dp_group
=
self
.
parallel_config
.
stateless_init_dp_group
()
# Tokenizer (+ ensure liveness if running in another process).
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
self
.
tokenizer
=
init_tokenizer_from_configs
(
...
@@ -77,7 +80,6 @@ class LLMEngine:
...
@@ -77,7 +80,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests)
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
mm_registry
=
mm_registry
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
...
@@ -148,7 +150,7 @@ class LLMEngine:
...
@@ -148,7 +150,7 @@ class LLMEngine:
def
has_unfinished_requests
(
self
)
->
bool
:
def
has_unfinished_requests
(
self
)
->
bool
:
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
if
not
self
.
dp_
enabled
:
if
self
.
dp_
group
is
None
:
return
has_unfinished
return
has_unfinished
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
...
@@ -243,8 +245,8 @@ class LLMEngine:
...
@@ -243,8 +245,8 @@ class LLMEngine:
def
sleep
(
self
,
level
:
int
=
1
):
def
sleep
(
self
,
level
:
int
=
1
):
self
.
engine_core
.
sleep
(
level
)
self
.
engine_core
.
sleep
(
level
)
def
wake_up
(
self
):
def
wake_up
(
self
,
tags
:
Optional
[
list
[
str
]]
=
None
):
self
.
engine_core
.
wake_up
()
self
.
engine_core
.
wake_up
(
tags
)
def
is_sleeping
(
self
)
->
bool
:
def
is_sleeping
(
self
)
->
bool
:
return
self
.
engine_core
.
is_sleeping
()
return
self
.
engine_core
.
is_sleeping
()
...
@@ -280,3 +282,14 @@ class LLMEngine:
...
@@ -280,3 +282,14 @@ class LLMEngine:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
"""Prevent an adapter from being evicted."""
"""Prevent an adapter from being evicted."""
return
self
.
engine_core
.
pin_lora
(
lora_id
)
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/v1/engine/output_processor.py
View file @
675ba75f
...
@@ -328,7 +328,7 @@ class OutputProcessor:
...
@@ -328,7 +328,7 @@ class OutputProcessor:
# 2) Detokenize the token ids into text and perform stop checks.
# 2) Detokenize the token ids into text and perform stop checks.
stop_string
=
req_state
.
detokenizer
.
update
(
stop_string
=
req_state
.
detokenizer
.
update
(
new_token_ids
,
finish_reason
==
FinishReason
.
STOP
)
new_token_ids
,
finish_reason
==
FinishReason
.
STOP
)
if
stop_string
and
finish_reason
!=
FinishReason
.
STOP
:
if
stop_string
:
finish_reason
=
FinishReason
.
STOP
finish_reason
=
FinishReason
.
STOP
stop_reason
=
stop_string
stop_reason
=
stop_string
...
...
vllm/v1/engine/processor.py
View file @
675ba75f
...
@@ -5,9 +5,8 @@ from collections.abc import Mapping
...
@@ -5,9 +5,8 @@ from collections.abc import Mapping
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
from
vllm.inputs
import
ProcessorInputs
,
PromptType
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
@@ -31,7 +30,6 @@ class Processor:
...
@@ -31,7 +30,6 @@ class Processor:
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
tokenizer
:
BaseTokenizerGroup
,
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
):
...
@@ -123,7 +121,8 @@ class Processor:
...
@@ -123,7 +121,8 @@ class Processor:
return
return
supported_backends
=
[
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"guidance:disable-any-whitespace"
,
"auto"
]
]
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
engine_level_backend
not
in
supported_backends
:
if
engine_level_backend
not
in
supported_backends
:
...
@@ -137,13 +136,15 @@ class Processor:
...
@@ -137,13 +136,15 @@ class Processor:
f
" !=
{
engine_level_backend
}
"
)
f
" !=
{
engine_level_backend
}
"
)
else
:
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
params
.
guided_decoding
.
backend
=
engine_level_backend
import
vllm.platforms
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
if
engine_level_backend
==
"xgrammar"
:
# xgrammar with no fallback
# xgrammar with no fallback
validate_structured_output_request_xgrammar
(
params
)
validate_structured_output_request_xgrammar
(
params
)
params
.
guided_decoding
.
backend
=
"xgrammar"
params
.
guided_decoding
.
backend
=
engine_level_backend
elif
engine_level_backend
==
"auto"
:
elif
engine_level_backend
==
"auto"
:
# "auto" is an opt-in to opinionated behavior where we try to
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# choose a backend based on request contents. This is not the
...
@@ -157,12 +158,13 @@ class Processor:
...
@@ -157,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance.
# are not supported in xgrammar. Fall back to guidance.
params
.
guided_decoding
.
backend
=
"guidance"
params
.
guided_decoding
.
backend
=
"guidance"
if
params
.
guided_decoding
.
backend
==
"guidance"
:
if
engine_level_backend
.
startswith
(
"guidance"
)
:
# TODO ideally we would have the LLTokenizer here as Lark syntax
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
params
.
guided_decoding
.
backend
=
engine_level_backend
def
process_inputs
(
def
process_inputs
(
self
,
self
,
...
@@ -206,14 +208,7 @@ class Processor:
...
@@ -206,14 +208,7 @@ class Processor:
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
if
is_encoder_decoder_inputs
(
processed_inputs
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
decoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
[
"decoder"
])
encoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
[
"encoder"
])
else
:
decoder_inputs
=
SingletonInputsAdapter
(
processed_inputs
)
encoder_inputs
=
None
# TODO: Impl encoder-decoder
# TODO: Impl encoder-decoder
if
encoder_inputs
is
not
None
:
if
encoder_inputs
is
not
None
:
...
@@ -224,8 +219,9 @@ class Processor:
...
@@ -224,8 +219,9 @@ class Processor:
sampling_params
=
params
.
clone
()
sampling_params
=
params
.
clone
()
# If unset max tokens, then generate up to the max_model_len.
# If unset max tokens, then generate up to the max_model_len.
if
sampling_params
.
max_tokens
is
None
:
if
sampling_params
.
max_tokens
is
None
:
sampling_params
.
max_tokens
=
(
self
.
model_config
.
max_model_len
-
sampling_params
.
max_tokens
=
(
len
(
decoder_inputs
.
prompt_token_ids
))
self
.
model_config
.
max_model_len
-
len
(
decoder_inputs
[
"prompt_token_ids"
]))
sampling_params
.
update_from_generation_config
(
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
,
eos_token_id
)
self
.
generation_config_fields
,
eos_token_id
)
sampling_params
.
update_from_tokenizer
(
sampling_params
.
update_from_tokenizer
(
...
@@ -235,57 +231,46 @@ class Processor:
...
@@ -235,57 +231,46 @@ class Processor:
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
(
decoder_mm_inputs
:
=
decoder_inputs
.
multi_modal_data
):
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
assert
isinstance
(
decoder_mm_inputs
,
MultiModalKwargs
)
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
modality
in
decoder_mm_inputs
.
modalities
for
item
in
decoder_mm_inputs
.
get_items
(
modality
)
]
# Merge and flatten multimodal placeholders, hashes and inputs
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
(
sorted_modalities
,
sorted_
item_
modalities
,
sorted_mm_positions
,
sorted_mm_positions
,
sorted_mm_hashes
,
sorted_mm_hashes
,
)
=
merge_and_sort_multimodal_metadata
(
)
=
merge_and_sort_multimodal_metadata
(
decoder_inputs
.
multi_modal
_placeholders
,
decoder_inputs
[
"mm
_placeholders
"
]
,
decoder_inputs
.
multi_modal
_hashes
if
self
.
use_hash
else
None
,
decoder_inputs
[
"mm
_hashes
"
]
if
self
.
use_hash
else
None
,
)
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# modalities involved.
# is a single MultiModalKwargs for all items from all modalities.
if
len
(
sorted_modalities
)
>
1
:
# This code flattens kwargs for individual items in a list and
modality_order_dict
=
{
# sorts them by each item's position in the input sequence if there
modality
:
order
# are multiple modalities.
for
order
,
modality
in
enumerate
(
sorted_modalities
)
unique_modalities
=
set
(
sorted_item_modalities
)
}
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
# Sanity check to make sure each multimodal input has only one
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
# modality key.
for
modality
in
sorted_item_modalities
:
for
mm_input
in
individual_mm_inputs
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
assert
len
(
mm_input
.
modalities
)
==
1
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
# Sort MultiModalKwargs to match sorted_mm_positions
]))
sorted_mm_inputs
=
sorted
(
used_indices
[
modality
]
+=
1
individual_mm_inputs
,
key
=
lambda
mm_input
:
modality_order_dict
[
list
(
mm_input
.
modalities
)[
0
]])
else
:
else
:
sorted_mm_inputs
=
individual_mm_inputs
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
decoder_inputs
.
prompt
,
prompt
=
decoder_inputs
.
get
(
"
prompt
"
)
,
prompt_token_ids
=
decoder_inputs
.
prompt_token_ids
,
prompt_token_ids
=
decoder_inputs
[
"
prompt_token_ids
"
]
,
mm_inputs
=
sorted_mm_inputs
,
mm_inputs
=
sorted_mm_inputs
,
mm_hashes
=
sorted_mm_hashes
,
mm_hashes
=
sorted_mm_hashes
,
mm_placeholders
=
sorted_mm_positions
,
mm_placeholders
=
sorted_mm_positions
,
...
@@ -298,15 +283,16 @@ class Processor:
...
@@ -298,15 +283,16 @@ class Processor:
def
_validate_model_inputs
(
self
,
def
_validate_model_inputs
(
self
,
inputs
:
ProcessorInputs
,
inputs
:
ProcessorInputs
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
lora_request
:
Optional
[
LoRARequest
]
=
None
):
if
is_encoder_decoder_inputs
(
inputs
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
# For encoder-decoder multimodal models, the max_prompt_len
prompt_inputs
=
inputs
[
"decoder"
if
self
.
model_config
.
# restricts the decoder prompt length
is_multimodal_model
else
"encoder"
]
if
self
.
model_config
.
is_multimodal_model
:
prompt_inputs
=
decoder_inputs
else
:
else
:
prompt_inputs
=
inputs
prompt_inputs
=
encoder_inputs
or
decoder_
inputs
prompt_ids
=
SingletonInputsAdapter
(
prompt_inputs
).
prompt_token_ids
prompt_ids
=
prompt_inputs
[
"
prompt_token_ids
"
]
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
raise
ValueError
(
"Prompt cannot be empty"
)
...
...
vllm/v1/executor/multiproc_executor.py
View file @
675ba75f
...
@@ -235,7 +235,10 @@ class WorkerProc:
...
@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle
=
self
.
worker_response_mq
.
export_handle
()
worker_response_mq_handle
=
self
.
worker_response_mq
.
export_handle
()
# Send Readiness signal to EngineCore process.
# Send Readiness signal to EngineCore process.
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PUSH
)
as
ready_socket
:
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PUSH
,
linger
=
10000
)
as
ready_socket
:
payload
=
pickle
.
dumps
(
worker_response_mq_handle
,
payload
=
pickle
.
dumps
(
worker_response_mq_handle
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
...
@@ -270,11 +273,13 @@ class WorkerProc:
...
@@ -270,11 +273,13 @@ class WorkerProc:
proc
=
context
.
Process
(
target
=
WorkerProc
.
worker_main
,
proc
=
context
.
Process
(
target
=
WorkerProc
.
worker_main
,
kwargs
=
process_kwargs
,
kwargs
=
process_kwargs
,
daemon
=
True
)
daemon
=
True
)
proc
.
start
()
# Wait for startup
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PULL
)
as
ready_socket
:
worker_response_mq_handle
=
WorkerProc
.
wait_for_startup
(
proc
.
start
()
proc
,
ready_path
)
# Wait for startup
worker_response_mq_handle
=
WorkerProc
.
wait_for_startup
(
proc
,
ready_socket
)
worker_response_mq
=
MessageQueue
.
create_from_handle
(
worker_response_mq
=
MessageQueue
.
create_from_handle
(
worker_response_mq_handle
,
0
)
worker_response_mq_handle
,
0
)
...
@@ -337,23 +342,22 @@ class WorkerProc:
...
@@ -337,23 +342,22 @@ class WorkerProc:
@
staticmethod
@
staticmethod
def
wait_for_startup
(
def
wait_for_startup
(
proc
:
BaseProcess
,
proc
:
BaseProcess
,
ready_
path
:
str
,
ready_
socket
:
zmq
.
Socket
,
)
->
Optional
[
Handle
]:
)
->
Optional
[
Handle
]:
"""Wait until the Worker is ready."""
"""Wait until the Worker is ready."""
with
zmq_socket_ctx
(
ready_path
,
zmq
.
constants
.
PULL
)
as
socket
:
# Wait for Worker to send READY.
# Wait for Worker to send READY.
while
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
while
ready_
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
logger
.
debug
(
"Waiting for WorkerProc to startup."
)
logger
.
debug
(
"Waiting for WorkerProc to startup."
)
if
not
proc
.
is_alive
():
if
not
proc
.
is_alive
():
raise
RuntimeError
(
"WorkerProc failed to start."
)
raise
RuntimeError
(
"WorkerProc failed to start."
)
message
=
socket
.
recv_string
()
message
=
ready_
socket
.
recv_string
()
assert
message
==
WorkerProc
.
READY_STR
assert
message
==
WorkerProc
.
READY_STR
handle_frame
=
socket
.
recv
(
copy
=
False
)
handle_frame
=
ready_
socket
.
recv
(
copy
=
False
)
handle
=
pickle
.
loads
(
handle_frame
.
buffer
)
handle
=
pickle
.
loads
(
handle_frame
.
buffer
)
return
handle
return
handle
class
ResponseStatus
(
Enum
):
class
ResponseStatus
(
Enum
):
SUCCESS
=
auto
()
SUCCESS
=
auto
()
...
...
vllm/v1/kv_cache_interface.py
View file @
675ba75f
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
,
get_dtype_size
from
vllm.utils
import
cdiv
,
get_dtype_size
...
@@ -43,28 +44,23 @@ class KVCacheSpec:
...
@@ -43,28 +44,23 @@ class KVCacheSpec:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
"""
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
The maximum possible memory usage of this KV cache in bytes.
memory size after padding `num_tokens` to full blocks.
Returns:
Returns:
The KV cache size
The KV cache size
in bytes
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
dataclass
@
dataclass
class
Full
AttentionSpec
(
KVCacheSpec
):
class
AttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
num_kv_heads
:
int
head_size
:
int
head_size
:
int
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
use_mla
:
bool
use_mla
:
bool
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
@
property
@
property
def
page_size_bytes
(
self
)
->
int
:
def
page_size_bytes
(
self
)
->
int
:
# For MLA we only store a single latent vector
# For MLA we only store a single latent vector
...
@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
...
@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
*
get_dtype_size
(
self
.
dtype
)
*
get_dtype_size
(
self
.
dtype
)
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
class
FullAttentionSpec
(
AttentionSpec
):
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
class
SlidingWindowSpec
(
AttentionSpec
):
sliding_window
:
int
def
__post_init__
(
self
):
assert
not
self
.
use_mla
,
"MLA is not supported for sliding window"
@
property
def
type_id
(
self
)
->
str
:
return
f
"sliding_window_
{
self
.
sliding_window
}
_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
# noqa
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_num_batched_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
)
# During chunked prefill, we allocate KV cache for the last
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens
=
min
(
self
.
sliding_window
-
1
+
max_num_batched_tokens
,
max_model_len
)
# +1 here because the sliding window may not start from the beginning
# of the block. For example, if the block size is 4 and num_tokens
# is 4, we need two blocks [XXCD] [EF] to store the 4-token sliding
# window [CDEF] of a 6-token sequence.
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
@
dataclass
@
dataclass
...
...
vllm/v1/metrics/loggers.py
View file @
675ba75f
...
@@ -12,6 +12,7 @@ from vllm.logger import init_logger
...
@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingMetrics
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
...
@@ -31,12 +32,14 @@ class StatLoggerBase(ABC):
class
LoggingStatLogger
(
StatLoggerBase
):
class
LoggingStatLogger
(
StatLoggerBase
):
def
__init__
(
self
):
def
__init__
(
self
,
engine_index
:
int
=
0
):
self
.
engine_index
=
engine_index
self
.
_reset
(
time
.
monotonic
())
self
.
_reset
(
time
.
monotonic
())
self
.
last_scheduler_stats
=
SchedulerStats
()
self
.
last_scheduler_stats
=
SchedulerStats
()
# Prefix cache metrics. This cannot be reset.
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_metrics
=
SpecDecodingMetrics
()
def
_reset
(
self
,
now
):
def
_reset
(
self
,
now
):
self
.
last_log_time
=
now
self
.
last_log_time
=
now
...
@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -64,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
def
log
(
self
):
...
@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -78,11 +85,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output.
# Format and print output.
logger
.
info
(
logger
.
info
(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%"
,
"Prefix cache hit rate: %.1f%%"
,
self
.
engine_index
,
prompt_throughput
,
prompt_throughput
,
generation_throughput
,
generation_throughput
,
scheduler_stats
.
num_running_reqs
,
scheduler_stats
.
num_running_reqs
,
...
@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -91,10 +100,13 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
log
()
class
PrometheusStatLogger
(
StatLoggerBase
):
class
PrometheusStatLogger
(
StatLoggerBase
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
self
.
_unregister_vllm_metrics
()
# Use this flag to hide metrics that were deprecated in
# Use this flag to hide metrics that were deprecated in
...
@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -102,8 +114,11 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
show_hidden_metrics
=
\
self
.
show_hidden_metrics
=
\
vllm_config
.
observability_config
.
show_hidden_metrics
vllm_config
.
observability_config
.
show_hidden_metrics
labelnames
=
[
"model_name"
]
labelnames
=
[
"model_name"
,
"engine"
]
labelvalues
=
[
vllm_config
.
model_config
.
served_model_name
]
labelvalues
=
[
vllm_config
.
model_config
.
served_model_name
,
str
(
engine_index
)
]
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
vllm_config
.
model_config
.
max_model_len
...
@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -296,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
labelname_running_lora_adapters
,
self
.
labelname_running_lora_adapters
,
])
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self
.
counter_spec_decode_num_draft_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_accepted_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
#
# Cache config info metric
# Cache config info metric
#
#
...
@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -332,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
counter_gpu_prefix_cache_hits
.
inc
(
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
counter_spec_decode_num_draft_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_draft_tokens
)
self
.
counter_spec_decode_num_accepted_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_accepted_tokens
)
if
iteration_stats
is
None
:
if
iteration_stats
is
None
:
return
return
...
...
vllm/v1/metrics/stats.py
View file @
675ba75f
...
@@ -4,6 +4,8 @@ import time
...
@@ -4,6 +4,8 @@ import time
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine.output_processor
import
RequestState
from
vllm.v1.engine.output_processor
import
RequestState
...
@@ -35,6 +37,8 @@ class SchedulerStats:
...
@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats
:
PrefixCacheStats
=
field
(
prefix_cache_stats
:
PrefixCacheStats
=
field
(
default_factory
=
PrefixCacheStats
)
default_factory
=
PrefixCacheStats
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
@
dataclass
@
dataclass
class
LoRAStats
:
class
LoRAStats
:
...
...
vllm/v1/request.py
View file @
675ba75f
...
@@ -59,6 +59,8 @@ class Request:
...
@@ -59,6 +59,8 @@ class Request:
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_inputs
=
multi_modal_inputs
or
[]
self
.
mm_inputs
=
multi_modal_inputs
or
[]
self
.
mm_hashes
:
list
[
str
]
=
multi_modal_hashes
or
[]
self
.
mm_hashes
:
list
[
str
]
=
multi_modal_hashes
or
[]
self
.
num_encoder_inputs
=
len
(
self
.
mm_inputs
)
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
# Sanity check
# Sanity check
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
...
@@ -93,9 +95,11 @@ class Request:
...
@@ -93,9 +95,11 @@ class Request:
token_ids
:
Union
[
int
,
list
[
int
]],
token_ids
:
Union
[
int
,
list
[
int
]],
)
->
None
:
)
->
None
:
if
isinstance
(
token_ids
,
int
):
if
isinstance
(
token_ids
,
int
):
token_ids
=
[
token_ids
]
self
.
_output_token_ids
.
append
(
token_ids
)
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
append
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
else
:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
@
property
@
property
def
num_tokens
(
self
)
->
int
:
def
num_tokens
(
self
)
->
int
:
...
@@ -115,13 +119,6 @@ class Request:
...
@@ -115,13 +119,6 @@ class Request:
def
get_finished_reason
(
self
)
->
Union
[
FinishReason
,
None
]:
def
get_finished_reason
(
self
)
->
Union
[
FinishReason
,
None
]:
return
RequestStatus
.
get_finished_reason
(
self
.
status
)
return
RequestStatus
.
get_finished_reason
(
self
.
status
)
def
has_encoder_inputs
(
self
)
->
bool
:
return
len
(
self
.
mm_inputs
)
>
0
@
property
def
num_encoder_inputs
(
self
)
->
int
:
return
len
(
self
.
mm_positions
)
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
675ba75f
...
@@ -19,6 +19,12 @@ except ImportError:
...
@@ -19,6 +19,12 @@ except ImportError:
class
TopKTopPSampler
(
nn
.
Module
):
class
TopKTopPSampler
(
nn
.
Module
):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
...
@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""PyTorch-native implementation of top-k and top-p sampling."""
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
...
@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
...
@@ -112,23 +122,48 @@ class TopKTopPSampler(nn.Module):
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# If only top-k is specified, use pytorch's builtin topk op. This leads
logits
=
apply_top_k_top_p_tpu
(
logits
,
k
,
p
)
# to significant speed up on TPU compared to using apply_top_k_top_p.
if
k
is
not
None
and
p
is
None
:
topk_values
,
topk_indices
=
torch
.
topk
(
logits
,
k
,
dim
=-
1
)
mask
=
torch
.
ones_like
(
logits
,
dtype
=
torch
.
bool
)
mask
.
scatter_
(
-
1
,
topk_indices
,
False
)
logits
.
masked_fill_
(
mask
,
float
(
'-inf'
))
else
:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
def
apply_top_k_top_p_tpu
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of a tie (i.e. multiple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
if
k
is
not
None
:
logits
=
apply_top_k_only
(
logits
,
k
)
if
p
is
not
None
:
probs
=
logits
.
softmax
(
dim
=-
1
)
probs_sort
,
_
=
probs
.
sort
(
dim
=-
1
,
descending
=
False
)
cumprob
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
)
top_p_mask
=
cumprob
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
top_p_mask
[:,
-
1
]
=
False
# at least one
top_p_count
=
top_p_mask
.
sum
(
dim
=-
1
).
unsqueeze
(
1
)
top_p_cutoff
=
probs_sort
.
gather
(
-
1
,
top_p_count
)
elements_to_discard
=
probs
<
top_p_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
return
logits
def
apply_top_k_top_p
(
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
...
@@ -136,10 +171,18 @@ def apply_top_k_top_p(
...
@@ -136,10 +171,18 @@ def apply_top_k_top_p(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Apply top-k and top-p masks to the logits.
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
"""
if
k
is
None
and
p
is
None
:
if
p
is
None
:
return
logits
if
k
is
None
:
return
logits
# Avoid sorting vocab for top-k only case.
return
apply_top_k_only
(
logits
,
k
)
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
logits_sort
,
logits_idx
=
logits
.
sort
(
dim
=-
1
,
descending
=
False
)
if
k
is
not
None
:
if
k
is
not
None
:
...
@@ -153,7 +196,7 @@ def apply_top_k_top_p(
...
@@ -153,7 +196,7 @@ def apply_top_k_top_p(
if
p
is
not
None
:
if
p
is
not
None
:
# Apply top-p.
# Apply top-p.
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sort
=
logits_sort
.
softmax
(
dim
=-
1
)
probs_sum
=
probs_sort
.
cumsum
(
dim
=-
1
)
probs_sum
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
,
out
=
probs_sort
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
top_p_mask
=
probs_sum
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
# at least one
# at least one
top_p_mask
[:,
-
1
]
=
False
top_p_mask
[:,
-
1
]
=
False
...
@@ -164,6 +207,31 @@ def apply_top_k_top_p(
...
@@ -164,6 +207,31 @@ def apply_top_k_top_p(
return
logits
return
logits
def
apply_top_k_only
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask
=
k
==
logits
.
shape
[
1
]
# Set non-top-k rows to 1 so that we can gather.
k
=
k
.
masked_fill
(
no_top_k_mask
,
1
)
max_top_k
=
k
.
max
()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index
=
k
.
sub_
(
1
).
unsqueeze
(
1
).
expand
(
logits
.
shape
[
0
],
1
)
top_k_mask
=
logits
.
topk
(
max_top_k
,
dim
=
1
).
values
.
gather
(
1
,
k_index
.
long
())
# Handle non-topk rows.
top_k_mask
.
masked_fill_
(
no_top_k_mask
.
unsqueeze
(
1
),
-
float
(
"inf"
))
logits
.
masked_fill_
(
logits
<
top_k_mask
,
-
float
(
"inf"
))
return
logits
def
random_sample
(
def
random_sample
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
generators
:
dict
[
int
,
torch
.
Generator
],
...
...
vllm/v1/sample/rejection_sampler.py
View file @
675ba75f
...
@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
...
@@ -109,6 +109,18 @@ class RejectionSampler(nn.Module):
output_token_ids
:
torch
.
Tensor
,
output_token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
vocab_size
:
int
,
)
->
list
[
list
[
int
]]:
)
->
list
[
list
[
int
]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np
=
output_token_ids
.
cpu
().
numpy
()
output_token_ids_np
=
output_token_ids
.
cpu
().
numpy
()
# Create mask for valid tokens.
# Create mask for valid tokens.
valid_mask
=
((
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
valid_mask
=
((
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
...
...
vllm/v1/sample/sampler.py
View file @
675ba75f
...
@@ -87,6 +87,12 @@ class Sampler(nn.Module):
...
@@ -87,6 +87,12 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert
not
(
sampling_metadata
.
all_greedy
assert
not
(
sampling_metadata
.
all_greedy
and
sampling_metadata
.
all_random
)
and
sampling_metadata
.
all_random
)
if
sampling_metadata
.
all_random
:
if
sampling_metadata
.
all_random
:
...
...
vllm/v1/sample/tpu/metadata.py
View file @
675ba75f
...
@@ -5,7 +5,18 @@ from typing import Optional
...
@@ -5,7 +5,18 @@ from typing import Optional
import
torch
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
DEFAULT_SAMPLING_PARAMS
=
dict
(
temperature
=-
1.0
,
min_p
=
0.0
,
# strictly disabled for now
# top_k=-1,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
# repetition_penalties=0.0,
)
@
dataclass
@
dataclass
...
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
...
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k
:
torch
.
Tensor
=
None
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# XLA-unfriendly control flow in Sampler
all_greedy
:
bool
=
False
all_random
:
bool
=
False
# Greedy sampling flag for compiling single xla graph.
# Greedy sampling flag for compiling single xla graph.
do_argmax
:
torch
.
Tensor
=
None
all_greedy
:
torch
.
Tensor
=
None
# speculation not supported
spec_token_ids
=
None
# Generator not supported by xla
# Generator not supported by xla
generators
:
dict
[
int
,
generators
:
dict
[
int
,
...
@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata:
...
@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids
=
None
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
indices_do_sample
:
torch
.
Tensor
=
None
def
__post_init__
(
self
):
temp
=
self
.
temperature
if
self
.
indices_do_sample
is
None
:
self
.
indices_do_sample
=
torch
.
zeros
(
temp
.
shape
[
0
],
device
=
temp
.
device
,
dtype
=
torch
.
int32
)
if
self
.
do_argmax
is
None
:
self
.
do_argmax
=
torch
.
tensor
(
0
,
dtype
=
torch
.
bool
,
device
=
temp
.
device
)
@
classmethod
@
classmethod
def
from_sampling_metadata
(
def
from_input_batch
(
cls
,
metadata
:
SamplingMetadata
,
cls
,
input_batch
:
InputBatch
,
padded_do_sample_indices
:
torch
.
Tensor
,
num_do_sample
:
int
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
device
:
torch
.
device
)
->
"TPUSupportedSamplingMetadata"
:
"""
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
Copy sampling tensors slices from `input_batch` to on device tensors.
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
recompilation is not triggered for optional values (None/torch.Tensor).
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
In order to handle different sizes for the params that range from 1 up
also reuses the on-device persistent tensors managed in `input_batch`
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
to reduce waste.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape.
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
We expect sampling params tensors to be padded to the same fixed shape.
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
"""
"""
metadata
=
cls
.
_validate_sampling_metadata
(
metadata
)
num_reqs
=
input_batch
.
num_reqs
# NOTE we have to initialize default tensor-based params first and
padded_num_reqs
=
len
(
indices_do_sample
)
# skip None values altogether to produce the same xla graph.
num_samples
=
len
(
padded_do_sample_indices
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
do_argmax
=
torch
.
t
ensor
(
metadata
.
all_greedy
,
fill_val
)
->
torch
.
T
ensor
:
dtype
=
torch
.
bool
,
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
device
=
device
)
# Pad value is the default one.
new_metadata
=
cls
.
get_default_sampling_params
(
num_samples
,
device
,
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
indices_do_sample
=
\
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
padded_do_sample_indices
,
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
do_argmax
=
do_argmax
)
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
supported_params
=
\
# consistent. We can't have flags to skip copies or we'll end up
TPUSupportedSa
mpling
Metadata
.
_get_default_params_values
()
# recompiling.
# Copy input non-None values into `new_metadata` fixed-sized tensors.
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
for
p_name
in
supported_params
:
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
old_val
=
getattr
(
metadata
,
p_name
)
# TODO Temporarily disabled until sampling options are enabled
new_val
=
getattr
(
new_metadata
,
p_name
)
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
if
isinstance
(
old_val
,
torch
.
Tensor
):
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
new_val
[:
num_do_sample
]
=
old_val
copy_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
setattr
(
new_metadata
,
p_name
,
new_val
)
DEFAULT_SAMPLING_PARAMS
[
"min_p"
]
)
xm
.
mark_step
()
xm
.
mark_step
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
return
new_metadata
@
classmethod
# Slice persistent device tensors to a fixed pre-compiled padded shape.
def
get_default_sampling_params
(
return
cls
(
cls
,
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
num_samples
:
int
,
# Scalar tensor for xla-friendly tracing.
device
:
torch
.
device
,
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
indices_do_sample
=
None
,
dtype
=
torch
.
bool
,
do_argmax
=
None
)
->
"TPUSupportedSamplingMetadata"
:
device
=
input_batch
.
device
),
# As sampling happens on a single traced graph, options
# TODO enable more and avoid returning None values
# are "disabled" by having them evaluate to an Identity op.
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
# Note that initialization is dependent on num_samples.
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
sampling_metadata_disable_value
=
\
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
TPUSupportedSamplingMetadata
.
_get_default_params_values
()
generators
=
input_batch
.
generators
,
init_kwargs
=
dict
()
indices_do_sample
=
indices_do_sample
)
for
p_name
,
(
default_val
,
dtype
)
in
sampling_metadata_disable_value
.
items
():
default_tensor
=
torch
.
full
((
num_samples
,
),
default_val
,
dtype
=
dtype
,
device
=
device
)
init_kwargs
[
p_name
]
=
default_tensor
return
cls
(
**
init_kwargs
,
indices_do_sample
=
indices_do_sample
,
do_argmax
=
do_argmax
)
@
staticmethod
def
_validate_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
)
->
SamplingMetadata
:
if
sampling_metadata
.
all_greedy
:
# Set to None since #13587. Make sure default isn't overruled.
assert
sampling_metadata
.
temperature
is
None
return
sampling_metadata
@
staticmethod
def
_get_default_params_values
():
return
dict
(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature
=
(
1.0
,
torch
.
float32
),
min_p
=
(
0.0
,
torch
.
float32
),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
\ No newline at end of file
vllm/v1/serial_utils.py
View file @
675ba75f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
pickle
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
cloudpickle
import
torch
import
torch
from
msgspec
import
msgpack
from
msgspec
import
msgpack
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
3
class
MsgpackEncoder
:
class
MsgpackEncoder
:
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
Prev
1
…
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment