Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
819 additions
and
457 deletions
+819
-457
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+64
-8
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+26
-6
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+2
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+25
-20
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+85
-46
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+34
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+97
-39
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+8
-7
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+4
-3
vllm/v1/request.py
vllm/v1/request.py
+10
-6
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+10
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+43
-46
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+125
-41
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+36
-25
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+3
-3
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+7
-1
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+1
-2
vllm/v1/utils.py
vllm/v1/utils.py
+9
-15
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+50
-32
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+180
-149
No files found.
vllm/v1/core/kv_cache_utils.py
View file @
31330101
...
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
from
vllm.utils
import
GiB_bytes
,
sha256
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
...
...
@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
if
mm_positions
[
-
1
].
offset
+
mm_positions
[
-
1
].
length
<
start_token_idx
:
return
extra_keys
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
...
...
@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
]
[
"
offset
"
]
length
=
mm_positions
[
curr_mm_idx
]
[
"
length
"
]
offset
=
mm_positions
[
curr_mm_idx
]
.
offset
length
=
mm_positions
[
curr_mm_idx
]
.
length
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
...
...
@@ -460,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return
ret
def estimate_max_model_len(vllm_config: VllmConfig,
                           kv_cache_spec: dict[str, KVCacheSpec],
                           available_memory: int) -> int:
    """
    Estimates the maximum model length that can fit in the available memory
    using binary search.

    Each probe temporarily overwrites
    `vllm_config.model_config.max_model_len`, because the per-layer
    `max_memory_usage_bytes(vllm_config)` reads the length from the config.
    The original value is restored before returning, so the caller's config
    is left unchanged.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The estimated maximum model length that can fit in the available
        memory, or 0 if even a model length of 1 does not fit.
    """
    # Remember the configured value: it is both the upper bound of the search
    # and what must be put back once the search is done.
    original_max_model_len = vllm_config.model_config.max_model_len

    # Define a function to check if a given model length fits in memory
    def fits_in_memory(model_len: int) -> bool:
        # Modify the max_model_len for this calculation
        vllm_config.model_config.max_model_len = model_len
        # Calculate memory needed for the given model length
        memory_needed = sum(
            (layer_spec.max_memory_usage_bytes(vllm_config)
             for layer_spec in kv_cache_spec.values()),
            start=0,
        )
        return memory_needed <= available_memory

    try:
        # Binary search for the maximum model length
        left, right = 1, original_max_model_len

        # If even the smallest model length doesn't fit, return 0
        if not fits_in_memory(left):
            return 0

        # Binary search for the maximum model length that fits.
        # NOTE(review): this assumes memory usage does not decrease as the
        # model length grows -- confirm for every KVCacheSpec implementation.
        result = 1
        while left <= right:
            mid = (left + right) // 2
            if fits_in_memory(mid):
                result = mid
                left = mid + 1
            else:
                right = mid - 1
        return result
    finally:
        # Undo the probing so the shared config is not left at an arbitrary
        # intermediate length.
        vllm_config.model_config.max_model_len = original_max_model_len
def
check_enough_kv_cache_memory
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
):
...
...
@@ -487,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory
+=
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
if
needed_memory
>
available_memory
:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len
=
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
available_memory
)
estimated_msg
=
""
if
estimated_max_len
>
0
:
estimated_msg
=
" Based on the available memory,"
f
" the estimated maximum model length is
{
estimated_max_len
}
."
raise
ValueError
(
f
"To serve at least one request with the models's max seq len "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB KV "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
GiB_bytes
:.
2
f
}
GiB KV "
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB). Try "
f
"increasing `gpu_memory_utilization` or decreasing "
f
"memory (
{
available_memory
/
GiB_bytes
:.
2
f
}
GiB)."
f
"
{
estimated_msg
}
"
f
" Try increasing `gpu_memory_utilization` or decreasing "
f
"`max_model_len` when initializing the engine."
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
31330101
...
...
@@ -7,7 +7,8 @@ from collections import deque
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
...
@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_config
:
KVCacheConfig
,
structured_output_manager
:
StructuredOutputManager
,
speculative_config
:
SpeculativeConfig
=
None
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
...
...
@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
self
.
num_lookahead_tokens
=
0
if
speculative_config
and
speculative_config
.
method
==
"eagle"
:
self
.
num_lookahead_tokens
=
\
speculative_config
.
num_speculative_tokens
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
...
@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
)
request
,
num_new_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
if
new_blocks
is
None
:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
...
...
@@ -505,8 +514,8 @@ class Scheduler(SchedulerInterface):
assert
mm_positions
is
not
None
assert
len
(
mm_positions
)
>
0
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
...
@@ -522,6 +531,17 @@ class Scheduler(SchedulerInterface):
if
self
.
encoder_cache_manager
.
has_cache
(
request
,
i
):
# The encoder input is already computed and cached.
continue
# If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would
# only cover part of the mm input, roll back to before the mm item.
if
(
self
.
scheduler_config
.
disable_chunked_mm_input
and
num_computed_tokens
<
start_pos
and
(
num_computed_tokens
+
num_new_tokens
)
<
(
start_pos
+
num_encoder_tokens
)):
num_new_tokens
=
start_pos
-
num_computed_tokens
break
if
(
not
self
.
encoder_cache_manager
.
can_allocate
(
request
,
i
)
or
num_encoder_tokens
>
encoder_budget
):
# The encoder cache is full or the encoder budget is exhausted.
...
...
@@ -596,8 +616,8 @@ class Scheduler(SchedulerInterface):
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
mm_positions
=
request
.
mm_positions
[
input_id
]
start_pos
=
mm_positions
[
"
offset
"
]
num_tokens
=
mm_positions
[
"
length
"
]
start_pos
=
mm_positions
.
offset
num_tokens
=
mm_positions
.
length
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# in the decoder's KV cache.
...
...
vllm/v1/engine/__init__.py
View file @
31330101
...
...
@@ -2,6 +2,7 @@
import
enum
import
time
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
import
msgspec
...
...
@@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt
:
Optional
[
str
]
prompt_token_ids
:
list
[
int
]
mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]
]
mm_hashes
:
Optional
[
list
[
str
]]
mm_placeholders
:
Optional
[
list
[
PlaceholderRange
]]
sampling_params
:
SamplingParams
...
...
vllm/v1/engine/core.py
View file @
31330101
...
...
@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
M
MInputCacheServer
from
vllm.v1.engine.mm_input_cache
import
M
irroredProcessingCache
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -98,6 +98,7 @@ class EngineCore:
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
kv_cache_config
=
kv_cache_config
,
speculative_config
=
vllm_config
.
speculative_config
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
...
...
@@ -105,7 +106,7 @@ class EngineCore:
)
# Setup MM Input Mapper.
self
.
mm_input_cache_server
=
M
MInputCacheServer
(
self
.
mm_input_cache_server
=
M
irroredProcessingCache
(
vllm_config
.
model_config
)
# Setup batch queue for pipeline parallelism.
...
...
@@ -173,7 +174,7 @@ class EngineCore:
# anything that has a hash must have a HIT cache entry here
# as well.
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
=
self
.
mm_input_cache_server
.
get_and_update
(
request
.
mm_inputs
=
self
.
mm_input_cache_server
.
get_and_update
_p1
(
request
.
mm_inputs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
@@ -318,6 +319,11 @@ class EngineCoreProc(EngineCore):
):
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
self
.
global_unfinished_reqs
=
False
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
...
...
@@ -327,22 +333,16 @@ class EngineCoreProc(EngineCore):
Any
]]
=
queue
.
Queue
()
self
.
output_queue
:
queue
.
Queue
[
EngineCoreOutputs
]
=
queue
.
Queue
()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_path
,
),
args
=
(
input_path
,
engine_index
),
daemon
=
True
).
start
()
threading
.
Thread
(
target
=
self
.
process_output_socket
,
args
=
(
output_path
,
engine_index
),
daemon
=
True
).
start
()
self
.
global_unfinished_reqs
=
False
self
.
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
@
staticmethod
def
run_engine_core
(
*
args
,
dp_rank
:
int
=
0
,
local_dp_rank
:
int
=
0
,
ready_pipe
,
**
kwargs
):
"""Launch EngineCore busy loop in background process."""
...
...
@@ -377,9 +377,6 @@ class EngineCoreProc(EngineCore):
else
:
engine_core
=
EngineCoreProc
(
*
args
,
**
kwargs
)
# Send Readiness signal to EngineClient.
ready_pipe
.
send
({
"status"
:
"READY"
})
engine_core
.
run_busy_loop
()
except
SystemExit
:
...
...
@@ -476,24 +473,32 @@ class EngineCoreProc(EngineCore):
and
not
isinstance
(
v
,
p
.
annotation
)
else
v
for
v
,
p
in
zip
(
args
,
arg_types
))
def
process_input_socket
(
self
,
input_path
:
str
):
def
process_input_socket
(
self
,
input_path
:
str
,
engine_index
:
int
):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
generic_decoder
=
MsgpackDecoder
()
identity
=
engine_index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
with
zmq_socket_ctx
(
input_path
,
zmq
.
DEALER
,
identity
=
identity
,
bind
=
False
)
as
socket
:
# Send ready message to front-end once input socket is connected.
socket
.
send
(
b
'READY'
)
with
zmq_socket_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
while
True
:
# (RequestType, RequestData)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
type_frame
,
*
data_frame
s
=
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frame
.
buffer
)
request
=
decoder
.
decode
(
data_frame
s
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
...
...
@@ -510,8 +515,8 @@ class EngineCoreProc(EngineCore):
while
True
:
outputs
=
self
.
output_queue
.
get
()
outputs
.
engine_index
=
engine_index
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send
(
buffer
,
copy
=
False
)
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send
_multipart
(
buffer
s
,
copy
=
False
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
...
...
@@ -619,4 +624,4 @@ class DPEngineCoreProc(EngineCoreProc):
self
.
counter
=
0
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
local_unfinished
)
local_unfinished
)
\ No newline at end of file
vllm/v1/engine/core_client.py
View file @
31330101
...
...
@@ -8,7 +8,7 @@ import threading
import
uuid
import
weakref
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Awaitable
,
Sequence
from
collections.abc
import
Awaitable
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
threading
import
Thread
...
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
...
...
@@ -35,6 +35,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R
=
TypeVar
(
'_R'
)
# Return type for collective_rpc
STARTUP_POLL_PERIOD_MS
=
10000
class
EngineCoreClient
(
ABC
):
"""
...
...
@@ -261,15 +263,13 @@ class CoreEngine:
vllm_config
:
VllmConfig
,
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
ctx
:
Union
[
zmq
.
Context
,
zmq
.
asyncio
.
Context
]
,
input_path
:
str
,
output_path
:
str
,
index
:
int
=
0
,
local_dp_rank
:
int
=
0
,
):
# Paths and sockets for IPC.
input_path
=
get_open_zmq_ipc_path
()
self
.
input_socket
=
make_zmq_socket
(
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
self
.
index
=
index
self
.
identity
=
index
.
to_bytes
(
length
=
2
,
byteorder
=
"little"
)
try
:
# Start EngineCore in background process.
self
.
proc_handle
=
BackgroundProcHandle
(
...
...
@@ -291,14 +291,9 @@ class CoreEngine:
# Ensure socket is closed if process fails to start.
self
.
close
()
def
send_multipart
(
self
,
msg_parts
:
Sequence
):
return
self
.
input_socket
.
send_multipart
(
msg_parts
,
copy
=
False
)
def
close
(
self
):
if
proc_handle
:
=
getattr
(
self
,
"proc_handle"
,
None
):
proc_handle
.
shutdown
()
if
socket
:
=
getattr
(
self
,
"input_socket"
,
None
):
socket
.
close
(
linger
=
0
)
@
dataclass
...
...
@@ -309,6 +304,7 @@ class BackgroundResources:
ctx
:
Union
[
zmq
.
Context
]
core_engines
:
list
[
CoreEngine
]
=
field
(
default_factory
=
list
)
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]]
=
None
shutdown_path
:
Optional
[
str
]
=
None
def
__call__
(
self
):
...
...
@@ -321,6 +317,8 @@ class BackgroundResources:
# aren't explicitly closed first.
if
self
.
output_socket
is
not
None
:
self
.
output_socket
.
close
(
linger
=
0
)
if
self
.
input_socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
if
self
.
shutdown_path
is
not
None
:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
...
...
@@ -387,21 +385,56 @@ class MPClient(EngineCoreClient):
# Paths and sockets for IPC.
self
.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
self
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
ROUTER
,
bind
=
True
)
self
.
resources
.
input_socket
=
self
.
input_socket
new_core_engine
=
lambda
index
,
local_dp_rank
=
None
:
CoreEngine
(
vllm_config
,
executor_class
,
log_stats
,
self
.
ctx
,
self
.
out
put_path
,
index
,
local_dp_rank
)
vllm_config
,
executor_class
,
log_stats
,
in
put_path
,
self
.
output_path
,
index
,
local_dp_rank
)
# Start engine core process(es).
self
.
_init_core_engines
(
vllm_config
,
new_core_engine
,
self
.
resources
.
core_engines
)
# Wait for engine core process(es) to start.
for
engine
in
self
.
resources
.
core_engines
:
engine
.
proc_handle
.
wait_for_startup
()
self
.
_wait_for_engine_startup
()
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def _wait_for_engine_startup(self):
    """Block until every core engine has reported READY on the input socket.

    Polls the input socket together with each engine's `proc_handle`. Each
    message received on the input socket is expected to have two frames:
    the engine index encoded as little-endian bytes, and the payload
    `b'READY'`.

    Raises:
        RuntimeError: if any registered object other than the input socket
            becomes readable, if a message arrives from an index that is
            not (or is no longer) being waited on, or if the payload is not
            `b'READY'`.
    """
    # Get a sync handle to the socket which can be sync or async.
    sync_input_socket = zmq.Socket.shadow(self.input_socket)

    # Wait for engine core process(es) to send ready messages.
    # `identities` holds the indices of engines that have not reported yet.
    identities = set(eng.index for eng in self.resources.core_engines)
    poller = zmq.Poller()
    poller.register(sync_input_socket, zmq.POLLIN)
    # Also watch each engine's process handle so a failure during startup
    # wakes the poll instead of waiting forever for a READY message.
    for eng in self.resources.core_engines:
        poller.register(eng.proc_handle, zmq.POLLIN)
    while identities:
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            # Poll timed out with nothing readable; log and keep waiting.
            logger.debug("Waiting for %d core engine proc(s) to start: %s",
                         len(identities), identities)
            continue
        if len(events) > 1 or events[0][0] != sync_input_socket:
            # One of the core processes exited.
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above.")

        # First frame is the sender's identity, second is the payload.
        eng_id_bytes, msg = sync_input_socket.recv_multipart()
        eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
        if eng_id not in identities:
            raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
        if msg != b'READY':
            raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
        logger.info("Core engine process %d ready.", eng_id)
        # Mark this engine as started; the loop ends once the set is empty.
        identities.discard(eng_id)
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
...
...
@@ -472,8 +505,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
frame
=
out_socket
.
recv
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
frame
s
=
out_socket
.
recv
_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
...
...
@@ -494,9 +527,10 @@ class SyncMPClient(MPClient):
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
core_engine
.
send_multipart
(
msg
)
# (Identity, RequestType, SerializedRequest)
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
...
...
@@ -599,8 +633,8 @@ class AsyncMPClient(MPClient):
async
def
process_outputs_socket
():
while
True
:
(
frame
,
)
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
.
buffer
)
frame
s
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
...
...
@@ -625,30 +659,34 @@ class AsyncMPClient(MPClient):
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
await
self
.
core_engine
.
send_multipart
(
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
)))
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
,
engine
:
Optional
[
CoreEngine
]
=
None
)
->
Awaitable
[
None
]:
if
engine
is
None
:
engine
=
self
.
core_engine
self
.
_ensure_output_queue_task
()
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
return
self
.
_send_input_message
(
message
,
engine
)
def
_send_input_message
(
self
,
message
:
tuple
[
bytestr
,
...],
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
message
=
(
engine
.
identity
,
)
+
message
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
return
await
self
.
_call_utility_async
(
method
,
*
args
,
engine
=
self
.
core_engine
)
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
,
engine
:
CoreEngine
,
)
->
Any
:
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
,
engine
:
CoreEngine
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
await
engine
.
send_multipart
(
messag
e
)
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engin
e
)
self
.
_ensure_output_queue_task
()
return
await
future
...
...
@@ -657,6 +695,7 @@ class AsyncMPClient(MPClient):
# tokenized.
request
.
prompt
=
None
await
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
self
.
_ensure_output_queue_task
()
async
def
abort_requests_async
(
self
,
request_ids
:
list
[
str
])
->
None
:
if
len
(
request_ids
)
>
0
:
...
...
@@ -721,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
*
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
...
...
@@ -755,21 +794,21 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
*
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
chosen_engine
.
num_reqs_in_flight
+=
1
if
self
.
num_engines_running
>=
len
(
self
.
core_engines
):
await
chosen_engine
.
send_multipart
(
msg
)
await
self
.
_send_input_message
(
msg
,
chosen_engine
)
else
:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self
.
num_engines_running
+=
len
(
self
.
core_engines
)
await
asyncio
.
gather
(
*
[
engine
.
send_multipart
(
msg
if
engine
is
chosen_engine
else
self
.
start_dp_msg
)
for
engine
in
self
.
core_engines
self
.
_send_input_message
(
msg
if
engine
is
chosen_engine
else
self
.
start_dp_msg
,
engine
)
for
engine
in
self
.
core_engines
])
self
.
_ensure_output_queue_task
()
...
...
@@ -794,7 +833,7 @@ class DPAsyncMPClient(AsyncMPClient):
# sure to start the other engines:
self
.
num_engines_running
=
len
(
self
.
core_engines
)
coros
=
[
engine
.
send_multipart
(
self
.
start_dp_msg
)
self
.
_send_input_message
(
self
.
start_dp_msg
,
engine
)
for
engine
in
self
.
core_engines
if
not
engine
.
num_reqs_in_flight
]
...
...
@@ -820,5 +859,5 @@ class DPAsyncMPClient(AsyncMPClient):
async
def
_abort_requests
(
self
,
request_ids
:
list
[
str
],
engine
:
CoreEngine
)
->
None
:
await
engine
.
send_multipar
t
(
(
EngineCoreRequestType
.
ABORT
.
value
,
self
.
encoder
.
encode
(
request_ids
))
)
await
self
.
_send_inpu
t
(
EngineCoreRequestType
.
ABORT
,
request_ids
,
engine
)
\ No newline at end of file
vllm/v1/engine/mm_input_cache.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
typing
import
Optional
from
vllm.envs
import
VLLM_MM_INPUT_CACHE_GIB
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.processing
import
ProcessingCache
from
vllm.utils
import
is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
...
...
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - M
MInputCacheServer to perform caching of the received
MultiModalKwargs.
# - M
irroredProcessingCache to store the
MultiModalKwargs
from P0
.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
...
...
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
class
M
MInputCacheServer
:
class
M
irroredProcessingCache
:
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
ProcessingCache
.
get_lru_cache
(
VLLM_MM_INPUT_CACHE_GIB
,
MultiModalKwargs
)
def
get_and_update
(
def
get_and_update
_p0
(
self
,
mm_inputs
:
list
[
MultiModalKwargs
],
mm_inputs
:
Sequence
[
MultiModalKwargs
],
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargs
]:
)
->
Sequence
[
Optional
[
MultiModalKwargs
]
]
:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
full_mm_inputs
=
[]
full_mm_inputs
=
list
[
Optional
[
MultiModalKwargs
]]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
if
mm_hash
in
self
.
mm_cache
:
mm_input
=
None
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_input
full_mm_inputs
.
append
(
mm_input
)
return
full_mm_inputs
def
get_and_update_p1
(
self
,
mm_inputs
:
Sequence
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
list
[
str
],
)
->
Sequence
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
full_mm_inputs
=
list
[
MultiModalKwargs
]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
if
mm_input
is
None
:
mm_input
=
self
.
mm_cache
[
mm_hash
]
else
:
...
...
vllm/v1/engine/processor.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
time
from
collections.abc
import
Mapping
from
typing
import
Optional
,
Union
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs
import
ProcessorInputs
,
PromptType
,
SingletonInputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
merge_and_sort_multimodal_metadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.structured_output.backend_guidance
import
(
validate_guidance_grammar
)
from
vllm.v1.structured_output.utils
import
(
...
...
@@ -46,6 +48,8 @@ class Processor:
self
.
tokenizer
,
mm_registry
)
self
.
mm_input_cache_client
=
MirroredProcessingCache
(
self
.
model_config
)
# Multi-modal hasher (for images)
self
.
use_hash
=
(
not
self
.
model_config
.
disable_mm_preprocessor_cache
)
or
\
...
...
@@ -73,6 +77,7 @@ class Processor:
params
:
SamplingParams
,
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
if
params
.
allowed_token_ids
is
None
:
return
...
...
@@ -83,6 +88,26 @@ class Processor:
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
def
_validate_logit_bias
(
self
,
params
:
SamplingParams
,
)
->
None
:
"""Validate logit_bias token IDs are within vocabulary range."""
if
not
params
.
logit_bias
:
return
vocab_size
=
self
.
model_config
.
get_vocab_size
()
invalid_token_ids
=
[]
for
token_id
in
params
.
logit_bias
:
if
token_id
<
0
or
token_id
>=
vocab_size
:
invalid_token_ids
.
append
(
token_id
)
if
invalid_token_ids
:
raise
ValueError
(
f
"token_id(s)
{
invalid_token_ids
}
in logit_bias contain "
f
"out-of-vocab token ids. Vocabulary size:
{
vocab_size
}
"
)
def
_validate_supported_sampling_params
(
self
,
params
:
SamplingParams
,
...
...
@@ -136,9 +161,6 @@ class Processor:
f
" !=
{
engine_level_backend
}
"
)
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
import
vllm.platforms
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
...
...
@@ -181,6 +203,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from
vllm.platforms
import
current_platform
current_platform
.
validate_request
(
prompt
=
prompt
,
params
=
params
,
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
if
priority
!=
0
:
...
...
@@ -228,7 +255,7 @@ class Processor:
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
# Multimodal related.
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]
]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
...
...
@@ -253,20 +280,28 @@ class Processor:
# are multiple modalities.
unique_modalities
=
set
(
sorted_item_modalities
)
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
orig_
sorted_mm_inputs
=
[]
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
for
modality
in
sorted_item_modalities
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
orig_sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
used_indices
[
modality
]
+=
1
else
:
sorted_mm_inputs
=
[
orig_
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
if
sorted_mm_hashes
is
not
None
:
sorted_mm_inputs
=
self
.
mm_input_cache_client
.
get_and_update_p0
(
orig_sorted_mm_inputs
,
sorted_mm_hashes
)
else
:
sorted_mm_inputs
=
orig_sorted_mm_inputs
return
EngineCoreRequest
(
request_id
=
request_id
,
prompt
=
decoder_inputs
.
get
(
"prompt"
),
...
...
@@ -285,41 +320,64 @@ class Processor:
lora_request
:
Optional
[
LoRARequest
]
=
None
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if
self
.
model_config
.
is_multimodal_model
:
prompt_inputs
=
decoder_inputs
else
:
prompt_inputs
=
encoder_inputs
or
decoder_inputs
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
)
max_allowed
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
max_token_id
if
max_input_id
>
max_allowed
:
raise
ValueError
(
"Token id {} is out of vocabulary"
.
format
(
max_input_id
))
if
encoder_inputs
is
not
None
:
self
.
_validate_model_input
(
encoder_inputs
,
lora_request
,
prompt_type
=
"encoder"
)
if
len
(
prompt_ids
)
>=
self
.
model_config
.
max_model_len
:
raise
ValueError
(
f
"Prompt length of
{
len
(
prompt_ids
)
}
is longer than the "
f
"maximum model length of
{
self
.
model_config
.
max_model_len
}
."
)
self
.
_validate_model_input
(
decoder_inputs
,
lora_request
,
prompt_type
=
"decoder"
)
if
self
.
model_config
.
is_multimodal_model
:
max_prompt_len
=
self
.
model_config
.
max_model_len
def
_validate_model_input
(
self
,
prompt_inputs
:
SingletonInputs
,
lora_request
:
Optional
[
LoRARequest
],
*
,
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
model_config
=
self
.
model_config
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
if
len
(
prompt_ids
)
>
max_prompt_len
:
raise
ValueError
(
f
"The prompt (total length
{
len
(
prompt_ids
)
}
) is too long "
f
"to fit into the model (context length
{
max_prompt_len
}
). "
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
not
prompt_ids
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
pass
# Mllama may have empty encoder inputs for text-only data
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>=
max_prompt_len
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
mm_registry
=
self
.
input_preprocessor
.
mm_registry
mm_processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
=
tokenizer
,
)
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
if
mm_processor
.
pad_dummy_encoder_prompt
:
return
# Skip encoder length check for Whisper
if
model_config
.
is_multimodal_model
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well."
)
else
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens."
)
raise
ValueError
(
f
"The
{
prompt_type
}
prompt (length
{
len
(
prompt_ids
)
}
) is "
f
"longer than the maximum model length of
{
max_prompt_len
}
. "
f
"
{
suggestion
}
"
)
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
...
...
vllm/v1/executor/multiproc_executor.py
View file @
31330101
...
...
@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
timeout
=
dequeue_timeout
)
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
if
isinstance
(
result
,
Exception
):
raise
result
else
:
raise
RuntimeError
(
"Worker failed"
)
raise
RuntimeError
(
"Worker failed with error %s, please check the"
" stack trace above for the root cause"
,
result
)
responses
[
w
.
rank
]
=
result
...
...
@@ -327,7 +326,7 @@ class WorkerProc:
logger
.
debug
(
"Worker interrupted."
)
except
Exception
:
# worker_busy_loop sends exceptions
exceptons
to Executor
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil
.
Process
().
parent
().
send_signal
(
signal
.
SIGUSR1
)
...
...
@@ -378,9 +377,11 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if
hasattr
(
e
,
"add_note"
):
e
.
add_note
(
traceback
.
format_exc
())
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
str
(
e
)))
continue
self
.
worker_response_mq
.
enqueue
(
...
...
vllm/v1/metrics/loggers.py
View file @
31330101
...
...
@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time to first token in seconds."
,
buckets
=
[
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
,
160.0
,
640.0
,
2560.0
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
...
@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time per output token in seconds."
,
buckets
=
[
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
1.0
,
2.5
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
request_latency_buckets
=
[
0.3
,
0.5
,
0.8
,
1.0
,
1.5
,
2.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
40.0
,
50.0
,
60.0
,
120.0
,
240.0
,
480.0
,
960.0
,
1920.0
,
7680.0
]
self
.
histogram_e2e_time_request
=
\
prometheus_client
.
Histogram
(
...
...
vllm/v1/request.py
View file @
31330101
...
...
@@ -3,17 +3,16 @@
import
enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_list_of
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
EngineCoreRequest
,
FinishReason
)
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
class
Request
:
...
...
@@ -23,9 +22,9 @@ class Request:
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
list
[
int
],
multi_modal_inputs
:
Optional
[
list
[
"
MultiModalKwargs
"
]],
multi_modal_inputs
:
Optional
[
list
[
MultiModalKwargs
]],
multi_modal_hashes
:
Optional
[
list
[
str
]],
multi_modal_placeholders
:
Optional
[
list
[
"
PlaceholderRange
"
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
...
...
@@ -75,6 +74,11 @@ class Request:
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
if
request
.
mm_inputs
is
not
None
:
assert
isinstance
(
request
.
mm_inputs
,
list
)
assert
is_list_of
(
request
.
mm_inputs
,
MultiModalKwargs
),
(
"mm_inputs was not updated in EngineCore.add_request"
)
return
cls
(
request_id
=
request
.
request_id
,
prompt
=
request
.
prompt
,
...
...
@@ -121,7 +125,7 @@ class Request:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
]
[
"
length
"
]
num_tokens
=
self
.
mm_positions
[
input_id
]
.
length
return
num_tokens
@
property
...
...
vllm/v1/sample/sampler.py
View file @
31330101
...
...
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# Get vocabulary size from logits
vocab_size
=
logits
.
shape
[
-
1
]
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
if
logit_bias
:
for
token_id
,
bias
in
logit_bias
.
items
():
# Check token_id bounds to ensure within vocabulary
if
token_id
<
0
or
token_id
>=
vocab_size
:
raise
ValueError
(
f
"token_id
{
token_id
}
in logit_bias contains "
f
"out-of-vocab token id. Vocabulary size: "
f
"
{
vocab_size
}
"
)
logits
[
i
,
token_id
]
+=
bias
return
logits
...
...
vllm/v1/sample/tpu/metadata.py
View file @
31330101
...
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from
typing
import
Optional
import
torch
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
...
@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
=
None
min_p
:
torch
.
Tensor
min_p
:
torch
.
Tensor
=
None
# Still too slow on forward_native!
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# Greedy sampling flag for compiling single xla graph.
all_greedy
:
torch
.
Tensor
=
None
# Generator not supported by xla
generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
all_greedy
:
bool
=
True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs
=
None
...
...
@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask
=
None
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
# Generator not supported by xla
_generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
@property
def generators(self) -> dict[int, torch.Generator]:
    """Return the generator mapping held in ``self._generators``.

    Generator not supported by torch/xla. This field must be immutable,
    which is why it is exposed through a read-only property.
    """
    return self._generators
@
classmethod
def
from_input_batch
(
cls
,
input_batch
:
InputBatch
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
cls
,
input_batch
:
InputBatch
,
padded_num_reqs
:
int
,
xla_device
:
torch
.
device
,
generate_params_if_all_greedy
:
bool
=
False
)
->
"TPUSupportedSamplingMetadata"
:
"""
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
Args:
input_batch: The input batch containing sampling parameters.
padded_num_reqs: The padded number of requests.
xla_device: The XLA device.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. this is useful for cases where
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
# Early return to avoid unnecessary cpu to tpu copy
if
(
input_batch
.
all_greedy
is
True
and
generate_params_if_all_greedy
is
False
):
return
cls
(
all_greedy
=
True
)
num_reqs
=
input_batch
.
num_reqs
padded_num_reqs
=
len
(
indices_do_sample
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
Tensor
:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
def
fill_slice
(
cpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
Tensor
:
# Pad value is the default one.
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
fill_slice
(
input_batch
.
temperature_cpu_tensor
,
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
# TODO Temporarily disabled until sampling options are enabled
#
copy
_slice(input_batch.top_p_cpu_tensor
, input_batch.top_p
)
#
copy
_slice(input_batch.top_k_cpu_tensor
, input_batch.top_k
)
copy
_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
#
fill
_slice(input_batch.top_p_cpu_tensor)
#
fill
_slice(input_batch.top_k_cpu_tensor)
fill
_slice
(
input_batch
.
min_p_cpu_tensor
,
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
xm
.
mark_step
()
xm
.
wait_device_ops
()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return
cls
(
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
# Scalar tensor for xla-friendly tracing.
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
dtype
=
torch
.
bool
,
device
=
input_batch
.
device
),
temperature
=
input_batch
.
temperature_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
),
all_greedy
=
input_batch
.
all_greedy
,
# TODO enable more and avoid returning None values
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
generators
=
input_batch
.
generators
,
indices_do_sample
=
indices_do_sample
)
min_p
=
input_batch
.
min_p_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
))
vllm/v1/serial_utils.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
collections.abc
import
Sequence
from
inspect
import
isclass
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
numpy
as
np
import
torch
import
zmq
from
msgspec
import
msgpack
CUSTOM_TYPE_
TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_
CLOUDPICKLE
=
3
CUSTOM_TYPE_
PICKLE
=
1
CUSTOM_TYPE_
CLOUD
PICKLE
=
2
CUSTOM_TYPE_
RAW_VIEW
=
3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE
=
512
class
MsgpackEncoder
:
"""Encoder with custom torch tensor serialization."""
bytestr
=
Union
[
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
]
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
custom_enc_hook
)
def
encode
(
self
,
obj
:
Any
)
->
bytes
:
return
self
.
encoder
.
encode
(
obj
)
class
MsgpackEncoder
:
"""Encoder with custom torch tensor and numpy array serialization.
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
None
:
self
.
encoder
.
encode_into
(
obj
,
buf
)
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self):
    """Create the msgspec encoder and the (initially unset) buffer stash."""
    # Route every type msgspec cannot encode natively through `enc_hook`.
    self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
    # This is used as a local stash of buffers that we can then access from
    # our custom `msgspec` hook, `enc_hook`. We don't have a way to
    # pass custom data to the hook otherwise.
    self.aux_buffers: Optional[list[bytestr]] = None
def encode(self, obj: Any) -> Sequence[bytestr]:
    """Encode ``obj`` and return the list of resulting buffers.

    Args:
        obj: The object handed to ``self.encoder.encode``.

    Returns:
        A list whose first element is the buffer returned by
        ``self.encoder.encode`` and whose remaining elements are whatever
        was appended to ``self.aux_buffers`` while that call ran.
    """
    try:
        # Slot 0 is a placeholder for the top-level buffer; it must exist
        # before encoding starts so appended buffers get indices >= 1.
        self.aux_buffers = bufs = [b'']
        bufs[0] = self.encoder.encode(obj)
        # This `bufs` list allows us to collect direct pointers to backing
        # buffers of tensors and np arrays, and return them along with the
        # top-level encoded buffer instead of copying their data into the
        # new buffer.
        return bufs
    finally:
        # Always drop the stash, including when encoding raised.
        self.aux_buffers = None
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    """Encode ``obj`` into ``buf`` and return the list of buffers.

    Args:
        obj: The object handed to ``self.encoder.encode_into``.
        buf: Caller-provided buffer that receives the top-level encoding.

    Returns:
        A list whose first element is ``buf`` itself, followed by whatever
        was appended to ``self.aux_buffers`` while encoding ran.
    """
    try:
        # `buf` occupies slot 0 so appended buffers get indices >= 1.
        self.aux_buffers = [buf]
        bufs = self.aux_buffers
        self.encoder.encode_into(obj, buf)
        return bufs
    finally:
        # Always drop the stash, including when encoding raised.
        self.aux_buffers = None
def enc_hook(self, obj: Any) -> Any:
    """Convert ``obj`` into something the msgpack encoder can serialize.

    Dispatch order:
      1. ``torch.Tensor`` -> ``self._encode_ndarray(obj.numpy())``.
      2. ``np.ndarray`` whose dtype kind is neither ``'O'`` nor ``'V'``
         -> ``self._encode_ndarray(obj)``.
      3. ``FunctionType`` -> cloudpickle payload in a ``msgpack.Ext``.
      4. Anything else (including object/void ndarrays) -> pickle payload
         in a ``msgpack.Ext``.
    """
    if isinstance(obj, torch.Tensor):
        # NOTE(review): presumably the tensor must be convertible by
        # `Tensor.numpy()` as-is (no device/dtype handling here) - confirm.
        return self._encode_ndarray(obj.numpy())

    # Fall back to pickle for object or void kind ndarrays.
    if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
        return self._encode_ndarray(obj)

    if isinstance(obj, FunctionType):
        # `pickle` is generally faster than cloudpickle, but can have
        # problems serializing methods.
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

    return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                       pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
def _encode_ndarray(
    self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
    """Describe ``obj`` as a ``(dtype string, shape, data)`` tuple.

    ``data`` is either the array bytes wrapped inline in a ``msgpack.Ext``
    (arrays with an empty shape or fewer than ``MIN_NOCOPY_BUF_SIZE``
    bytes) or the integer index at which the bytes were appended to
    ``self.aux_buffers``.

    NOTE(review): the return annotation lists ``memoryview`` for the third
    element, but the inline branch produces a ``msgpack.Ext`` - confirm
    which is intended.
    """
    # Only valid while an encode call has set up the stash.
    assert self.aux_buffers is not None
    # Use the array's own buffer when it is C-contiguous; otherwise
    # serialize it with `tobytes()`.
    arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
    if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
        # Encode small arrays and scalars inline. Using this extension type
        # ensures we can avoid copying when decoding.
        data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
    else:
        # Otherwise encode index of backing buffer to avoid copy.
        data = len(self.aux_buffers)
        self.aux_buffers.append(arr_data)

    # We serialize the ndarray as a tuple of native types.
    # The data is either inlined if small, or an index into a list of
    # backing buffers that we've stashed in `aux_buffers`.
    return obj.dtype.str, obj.shape, data
class
MsgpackDecoder
:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when decoding tensors / numpy arrays.
"""
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
args
=
()
if
t
is
None
else
(
t
,
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
custom_ext_hook
)
def
decode
(
self
,
obj
:
Any
):
return
self
.
decoder
.
decode
(
obj
)
def
custom_enc_hook
(
obj
:
Any
)
->
Any
:
if
isinstance
(
obj
,
torch
.
Tensor
):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
if
code
==
CUSTOM_TYPE_TENSOR
:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
self
.
ext_hook
,
dec_hook
=
self
.
dec_hook
)
self
.
aux_buffers
:
Sequence
[
bytestr
]
=
()
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
    """Decode a single buffer or a sequence of buffers.

    Args:
        bufs: Either one bytes-like object / ``zmq.Frame``, which is passed
            to ``self.decoder.decode`` directly, or a sequence of buffers
            whose first element is decoded while the whole sequence is
            exposed through ``self.aux_buffers``.

    Returns:
        Whatever ``self.decoder.decode`` returns.
    """
    if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
        # TODO - This check can become `isinstance(bufs, bytestr)`
        # as of Python 3.10.
        return self.decoder.decode(bufs)

    # Make the buffers reachable from the decode hooks for this call only.
    self.aux_buffers = bufs
    try:
        return self.decoder.decode(bufs[0])
    finally:
        # Reset to the empty default, including when decoding raised.
        self.aux_buffers = ()
def dec_hook(self, t: type, obj: Any) -> Any:
    """Convert the native-typed ``obj`` to the requested type ``t``.

    Args:
        t: Target type. Only classes deriving from ``np.ndarray`` or
            ``torch.Tensor`` are converted.
        obj: The decoded native value.

    Returns:
        An ndarray built by ``self._decode_ndarray`` (wrapped with
        ``torch.from_numpy`` for tensor targets), or ``obj`` unchanged for
        every other ``t``.
    """
    # Given native types in `obj`, convert to type `t`.
    if isclass(t):
        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return torch.from_numpy(self._decode_ndarray(obj))
    return obj
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
dtype
,
shape
,
data
=
arr
buffer
=
self
.
aux_buffers
[
data
]
if
isinstance
(
data
,
int
)
else
data
return
np
.
ndarray
(
buffer
=
buffer
,
dtype
=
np
.
dtype
(
dtype
),
shape
=
shape
)
def ext_hook(self, code: int, data: memoryview) -> Any:
    """Decode a msgpack extension payload according to its type code.

    Args:
        code: Extension type code read from the message.
        data: The raw extension payload.

    Returns:
        ``data`` unchanged for ``CUSTOM_TYPE_RAW_VIEW``; the unpickled
        object for ``CUSTOM_TYPE_PICKLE`` and ``CUSTOM_TYPE_CLOUDPICKLE``.

    Raises:
        NotImplementedError: If ``code`` is none of the codes above.
    """
    if code == CUSTOM_TYPE_RAW_VIEW:
        return data
    # SECURITY: unpickling can execute arbitrary code embedded in the
    # payload. NOTE(review): confirm every producer of these buffers is
    # trusted before exposing this decoder to external input.
    if code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)
    if code == CUSTOM_TYPE_CLOUDPICKLE:
        return cloudpickle.loads(data)

    raise NotImplementedError(
        f"Extension type code {code} is not supported")
vllm/v1/spec_decode/eagle.py
View file @
31330101
...
...
@@ -4,8 +4,11 @@ import torch.nn as nn
import
triton
import
triton.language
as
tl
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.llama_eagle
import
EagleLlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -21,8 +24,12 @@ class EagleProposer:
self
.
num_speculative_tokens
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
,
device
=
device
)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
def
propose
(
self
,
...
...
@@ -54,7 +61,9 @@ class EagleProposer:
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids
[
last_token_indices
]
=
next_token_ids
seq_lens
=
target_positions
[
last_token_indices
]
+
1
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
...
...
@@ -98,7 +107,7 @@ class EagleProposer:
hidden_states
=
sample_hidden_states
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
input_ids
=
draft_token_ids_list
[
-
1
]
...
...
@@ -176,26 +185,28 @@ class EagleProposer:
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
self
.
model
=
DummyEagleModel
()
self
.
model
.
get_input_embeddings
=
target_model
.
get_input_embeddings
self
.
model
.
compute_logits
=
target_model
.
compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class
DummyEagleModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
input_embeddings
=
self
.
get_input_embeddings
(
input_ids
)
return
hidden_states
+
input_embeddings
# Dummy return.
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
vllm_config
.
parallel_config
)
draft_model_config
=
\
self
.
vllm_config
.
speculative_config
.
draft_model_config
# FIXME(lily): This does not handle with distributed inference.
target_device
=
self
.
vllm_config
.
device_config
.
device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
self
.
vllm_config
):
self
.
model
=
EagleLlamaForCausalLM
(
model_config
=
draft_model_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
self
.
vllm_config
.
speculative_config
.
draft_model_config
,
self
.
model
))
self
.
model
.
lm_head
=
target_model
.
lm_head
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
31330101
...
...
@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
self
.
vocab_size
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
...
@@ -163,7 +164,6 @@ def validate_guidance_grammar(
tokenizer
:
Optional
[
llguidance
.
LLTokenizer
]
=
None
)
->
None
:
tp
,
grm
=
get_structured_output_key
(
sampling_params
)
guidance_grm
=
serialize_guidance_grammar
(
tp
,
grm
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
=
tokenizer
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
)
if
err
:
raise
ValueError
(
f
"Grammar error:
{
err
}
"
)
vllm/v1/structured_output/backend_xgrammar.py
View file @
31330101
...
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import
torch
import
vllm.envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
...
@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer
,
vocab_size
=
self
.
vocab_size
,
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
,
cache_enabled
=
True
,
cache_limit_bytes
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
,
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
...
vllm/v1/structured_output/utils.py
View file @
31330101
...
...
@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
return
True
# Unsupported keywords for strings
if
obj
.
get
(
"type"
)
==
"string"
and
any
(
key
in
obj
for
key
in
(
"minLength"
,
"maxLength"
,
"format"
)):
if
obj
.
get
(
"type"
)
==
"string"
and
"format"
in
obj
:
return
True
# Unsupported keywords for objects
...
...
vllm/v1/utils.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
os
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
multiprocessing
import
Process
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
...
...
@@ -105,28 +105,22 @@ class BackgroundProcHandle:
process_kwargs
:
dict
[
Any
,
Any
],
):
context
=
get_mp_context
()
self
.
reader
,
writer
=
context
.
Pipe
(
duplex
=
False
)
assert
(
"ready_pipe"
not
in
process_kwargs
and
"input_path"
not
in
process_kwargs
assert
(
"input_path"
not
in
process_kwargs
and
"output_path"
not
in
process_kwargs
)
process_kwargs
[
"ready_pipe"
]
=
writer
process_kwargs
[
"input_path"
]
=
input_path
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
self
.
proc
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
proc
:
Process
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
self
.
proc
.
start
()
def wait_for_startup(self):
    """Block on ``self.reader`` until a startup message arrives.

    Raises:
        RuntimeError: If the received message's ``"status"`` entry is not
            ``"READY"``.
    """
    # Wait for startup.
    if self.reader.recv()["status"] != "READY":
        raise RuntimeError(f"{self.proc.name} initialization failed. "
                           "See root cause above.")
def fileno(self):
    """Return ``self.proc.sentinel`` for the background process."""
    return self.proc.sentinel
def shutdown(self):
    """Invoke the finalizer stored in ``self._finalizer``."""
    self._finalizer()
...
...
@@ -134,7 +128,7 @@ class BackgroundProcHandle:
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def
shutdown
(
proc
:
multiprocessing
.
Process
,
input_path
:
str
,
output_path
:
str
):
def
shutdown
(
proc
:
Process
,
input_path
:
str
,
output_path
:
str
):
# Shutdown the process.
if
proc
.
is_alive
():
proc
.
terminate
()
...
...
@@ -206,4 +200,4 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
return
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
return
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
\ No newline at end of file
vllm/v1/worker/gpu_model_runner.py
View file @
31330101
...
...
@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
0
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
req_ids
=
self
.
input_batch
.
req_ids
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
...
@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_
mm_
encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_input_ids
.
append
((
req_id
,
input_id
))
for
mm_input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
...
...
@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
...
...
@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
...
...
@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
encoder_output
s
)
input_ids
,
mm_embed
s
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
...
...
@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
target_hidden_states
=
hidden_states
target_positions
=
positions
[:
num_scheduled_tokens
]
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
cu_num_tokens
=
attn_metadata
.
query_start_loc
else
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
31330101
...
...
@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
...
...
@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -174,10 +176,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self
.
arange_np
=
np
.
arange
(
self
.
max_num_tokens
,
dtype
=
np
.
int32
)
self
.
num_tokens_paddings
=
_get_paddings
(
self
.
num_tokens_paddings
=
_get_
token_
paddings
(
min_token_size
=
16
,
max_token_size
=
self
.
max_num_tokens
,
padding_gap
=
envs
.
VLLM_TPU_BUCKET_PADDING_GAP
)
self
.
num_reqs_paddings
=
_get_req_paddings
(
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
def
_update_num_xla_graphs
(
self
,
case_str
):
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
...
...
@@ -262,11 +266,6 @@ class TPUModelRunner:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
sampling_params
=
new_req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
else
:
generator
=
None
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
...
...
@@ -275,7 +274,7 @@ class TPUModelRunner:
mm_inputs
=
new_req_data
.
mm_inputs
,
mm_positions
=
new_req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
generator
=
generator
,
generator
=
None
,
block_ids
=
new_req_data
.
block_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
output_token_ids
=
[],
...
...
@@ -505,21 +504,48 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
,
padded_num_reqs
def
_scatter_placeholders
(
self
,
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
embeds
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
_gather_placeholders
(
self
,
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
def
_execute_mm_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_input_ids
.
append
((
req_id
,
input_id
))
for
mm_input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -555,16 +581,23 @@ class TPUModelRunner:
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -572,8 +605,8 @@ class TPUModelRunner:
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
...
...
@@ -595,8 +628,16 @@ class TPUModelRunner:
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
@
torch
.
no_grad
()
def
execute_model
(
...
...
@@ -607,25 +648,26 @@ class TPUModelRunner:
# Update cached state
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
attn_metadata
,
logits_indices
,
padded_num_reqs
=
self
.
_prepare_inputs
(
scheduler_output
)
if
self
.
is_multimodal_model
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
,
encoder_output
s
)
self
.
input_ids
,
mm_embed
s
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
input_ids
=
None
...
...
@@ -637,21 +679,19 @@ class TPUModelRunner:
input_ids
=
self
.
input_ids
inputs_embeds
=
None
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
logits_indices
)
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
kv_caches
=
self
.
kv_caches
,
inputs_embeds
=
inputs_embeds
,
)
selected_token_ids
=
self
.
model
.
sample_from_hidden
(
hidden_states
,
tpu_sampling_metadata
)
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
logits_indices
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
selected_token_ids
=
self
.
sample_from_hidden
(
hidden_states
,
tpu_sampling_metadata
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
...
...
@@ -751,17 +791,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm
.
mark_step
()
xm
.
wait_device_ops
()
model
=
ModelWrapperV1
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
self
.
model
=
model
self
.
sampler
=
TPUSampler
()
@
torch
.
no_grad
()
def
_dummy_run
(
self
,
kv_caches
,
num_tokens
:
int
)
->
None
:
def
_dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
if
self
.
is_multimodal_model
:
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
...
...
@@ -812,65 +850,81 @@ class TPUModelRunner:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
self
.
_hidden_states_dtype
=
out
.
dtype
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
def
_precompile_backbone
(
self
)
->
None
:
logger
.
info
(
"Compiling the model with different input shapes."
)
start
=
time
.
perf_counter
()
for
num_tokens
in
self
.
num_tokens_paddings
:
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
xm
.
mark_step
()
self
.
_dummy_run
(
num_tokens
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"model"
)
self
.
_update_num_xla_graphs
(
"model
backbone
"
)
logger
.
info
(
"Compiling sampling with different input shapes."
)
def
_precompile_select_hidden_states
(
self
)
->
None
:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger
.
info
(
"Compiling select_hidden_states with different input shapes."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
device
=
self
.
device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for
num_tokens
in
self
.
num_tokens_paddings
:
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
dummy_hidden
=
torch
.
zeros
((
num_tokens
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while
True
:
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
device
=
device
,
)
xm
.
mark_step
()
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
indices
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
out
=
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
out
=
out
.
cpu
()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if
num_reqs_to_sample
>=
min
(
num_tokens
,
self
.
max_num_reqs
):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
torch
.
_dynamo
.
mark_dynamic
(
dummy_hidden
,
0
)
for
num_reqs
in
self
.
num_reqs_paddings
:
indices
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
_dynamo
.
mark_dynamic
(
indices
,
0
)
self
.
select_hidden_states
(
dummy_hidden
,
indices
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"select_hidden_states"
)
def
_precompile_sample_from_hidden
(
self
)
->
None
:
logger
.
info
(
"Compiling sampling with different input shapes."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_hidden
=
torch
.
zeros
((
num_reqs
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
for
all_greedy
in
[
False
,
True
]:
generate_params_if_all_greedy
=
not
all_greedy
sampling_metadata
=
(
TPUSupportedSamplingMetadata
.
from_input_batch
(
self
.
input_batch
,
num_reqs
,
self
.
device
,
generate_params_if_all_greedy
,
))
sampling_metadata
.
all_greedy
=
all_greedy
self
.
sample_from_hidden
(
dummy_hidden
,
sampling_metadata
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sampling"
)
def
capture_model
(
self
)
->
None
:
"""
Precompile all the subgraphs with possible input shapes.
"""
# TODO: precompile encoder
self
.
_precompile_backbone
()
self
.
_precompile_select_hidden_states
()
self
.
_precompile_sample_from_hidden
()
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -910,73 +964,39 @@ class TPUModelRunner:
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
class
ModelWrapperV1
(
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
self
.
sampler
=
TPUSampler
()
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
SamplerOutput
:
sampler_out
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
sampler_out
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
reset_dynamo_cache
(
self
):
if
self
.
is_multimodal_model
:
compiled_model
=
self
.
model
.
get_language_model
().
model
else
:
compiled_model
=
self
.
model
.
model
if
isinstance
(
compiled_model
,
TorchCompileWrapperWithCustomDispatcher
):
logger
.
info
(
"Clear dynamo cache and cached dynamo bytecode."
)
torch
.
_dynamo
.
eval_frame
.
remove_from_cache
(
compiled_model
.
original_code_object
)
compiled_model
.
compiled_codes
.
clear
()
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
select_hidden_states
(
self
,
hidden_states
,
indices_do_sample
):
return
hidden_states
[
indices_do_sample
]
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
sample_from_hidden
(
self
,
hidden_states
:
torch
.
Tensor
,
sample_
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
)
->
torch
.
Tensor
:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states
=
\
hidden_states
[
sampling_metadata
.
indices_do_sample
]
logits
=
self
.
compute_logits
(
sample_hidden_states
)
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
all_greedy
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
sampling_metadata
.
all_greedy
:
out_tokens
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
else
:
out_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
).
sampled_token_ids
return
out_tokens
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
return
logits
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
...
...
@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
_get_padded_number
(
n
:
int
,
multiple
:
int
)
->
int
:
return
((
n
+
multiple
-
1
)
//
multiple
)
*
multiple
def
_get_req_paddings
(
min_req_size
:
int
,
max_req_size
:
int
)
->
list
[
int
]:
logger
.
info
(
"Preparing request paddings:"
)
# assert min_req_size is power of 2
assert
(
min_req_size
&
(
min_req_size
-
1
)
==
0
)
and
min_req_size
>
0
paddings
:
list
=
[]
num
=
max
(
MIN_NUM_SEQS
,
min_req_size
)
while
num
<=
max_req_size
and
(
len
(
paddings
)
==
0
or
paddings
[
-
1
]
!=
num
):
paddings
.
append
(
num
)
logger
.
info
(
" %d"
,
num
)
num
=
_get_padded_num_reqs_with_upper_limit
(
num
+
1
,
max_req_size
)
return
paddings
def
_get_padded_num_reqs_with_upper_limit
(
x
,
upper_limit
)
->
int
:
def
_get_padded_num_reqs_with_upper_limit
(
x
:
int
,
upper_limit
:
int
)
->
int
:
res
=
MIN_NUM_SEQS
if
x
<=
MIN_NUM_SEQS
else
1
<<
(
x
-
1
).
bit_length
()
return
min
(
res
,
upper_limit
)
def
_get_paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
def
_get_
token_
paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
...
...
@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
assert
(
min_token_size
&
(
min_token_size
-
1
)
==
0
)
and
min_token_size
>
0
paddings
=
[]
num
=
min_token_size
if
padding_gap
==
0
:
logger
.
info
(
"Using exponential paddings:"
)
logger
.
info
(
"Using exponential
token
paddings:"
)
while
num
<=
max_token_size
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
num
*=
2
else
:
logger
.
info
(
"Using incremental paddings:"
)
logger
.
info
(
"Using incremental
token
paddings:"
)
while
num
<=
padding_gap
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
...
...
Prev
1
…
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment