Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1065 additions
and
507 deletions
+1065
-507
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+57
-19
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+34
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+97
-39
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+8
-7
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+4
-3
vllm/v1/request.py
vllm/v1/request.py
+10
-6
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+10
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+43
-46
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+125
-41
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+36
-25
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+3
-3
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+7
-1
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+1
-2
vllm/v1/utils.py
vllm/v1/utils.py
+7
-10
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+50
-32
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+180
-149
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+9
-3
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+45
-0
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-1
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+314
-113
No files found.
vllm/v1/engine/core_client.py
View file @
9c4ecf15
...
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
...
...
@@ -402,6 +402,36 @@ class MPClient(EngineCoreClient):
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def
_wait_for_engine_startup
(
self
):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
identities
=
set
(
eng
.
index
for
eng
in
self
.
resources
.
core_engines
)
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
for
eng
in
self
.
resources
.
core_engines
:
poller
.
register
(
eng
.
proc_handle
,
zmq
.
POLLIN
)
while
identities
:
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
logger
.
debug
(
"Waiting for %d core engine proc(s) to start: %s"
,
len
(
identities
),
identities
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the core processes exited.
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above."
)
eng_id_bytes
,
msg
=
sync_input_socket
.
recv_multipart
()
eng_id
=
int
.
from_bytes
(
eng_id_bytes
,
byteorder
=
"little"
)
if
eng_id
not
in
identities
:
raise
RuntimeError
(
f
"Unexpected or duplicate engine:
{
eng_id
}
"
)
if
msg
!=
b
'READY'
:
raise
RuntimeError
(
f
"Engine
{
eng_id
}
failed:
{
msg
.
decode
()
}
"
)
logger
.
info
(
"Core engine process %d ready."
,
eng_id
)
identities
.
discard
(
eng_id
)
def
_init_core_engines
(
self
,
vllm_config
:
VllmConfig
,
...
...
@@ -472,8 +502,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
frame
=
out_socket
.
recv
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
frame
s
=
out_socket
.
recv
_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
...
...
@@ -494,10 +524,10 @@ class SyncMPClient(MPClient):
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
core_engine
.
send_multipart
(
msg
)
# (
Identity,
RequestType, SerializedRequest)
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
)
)
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
:
Future
[
Any
]
=
Future
()
...
...
@@ -599,8 +629,8 @@ class AsyncMPClient(MPClient):
async
def
process_outputs_socket
():
while
True
:
(
frame
,
)
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
.
buffer
)
frame
s
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
...
...
@@ -625,12 +655,20 @@ class AsyncMPClient(MPClient):
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
await
self
.
core_engine
.
send_multipart
(
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
)))
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
,
engine
:
Optional
[
CoreEngine
]
=
None
)
->
Awaitable
[
None
]:
if
engine
is
None
:
engine
=
self
.
core_engine
self
.
_ensure_output_queue_task
()
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
return
self
.
_send_input_message
(
message
,
engine
)
def
_send_input_message
(
self
,
message
:
tuple
[
bytestr
,
...],
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
message
=
(
engine
.
identity
,
)
+
message
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
return
await
self
.
_call_utility_async
(
method
,
...
...
@@ -646,9 +684,9 @@ class AsyncMPClient(MPClient):
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
await
engine
.
send_multipart
(
messag
e
)
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engin
e
)
self
.
_ensure_output_queue_task
()
return
await
future
...
...
@@ -721,7 +759,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
*
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
...
...
@@ -755,7 +793,7 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
*
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
...
...
vllm/v1/engine/mm_input_cache.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
typing
import
Optional
from
vllm.envs
import
VLLM_MM_INPUT_CACHE_GIB
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.processing
import
ProcessingCache
from
vllm.utils
import
is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
...
...
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - M
MInputCacheServer to perform caching of the received
MultiModalKwargs.
# - M
irroredProcessingCache to store the
MultiModalKwargs
from P0
.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
...
...
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
class
M
MInputCacheServer
:
class
M
irroredProcessingCache
:
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
ProcessingCache
.
get_lru_cache
(
VLLM_MM_INPUT_CACHE_GIB
,
MultiModalKwargs
)
def
get_and_update
(
def
get_and_update
_p0
(
self
,
mm_inputs
:
list
[
MultiModalKwargs
],
mm_inputs
:
Sequence
[
MultiModalKwargs
],
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargs
]:
)
->
Sequence
[
Optional
[
MultiModalKwargs
]
]
:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
full_mm_inputs
=
[]
full_mm_inputs
=
list
[
Optional
[
MultiModalKwargs
]]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
if
mm_hash
in
self
.
mm_cache
:
mm_input
=
None
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_input
full_mm_inputs
.
append
(
mm_input
)
return
full_mm_inputs
def
get_and_update_p1
(
self
,
mm_inputs
:
Sequence
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
list
[
str
],
)
->
Sequence
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
full_mm_inputs
=
list
[
MultiModalKwargs
]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
if
mm_input
is
None
:
mm_input
=
self
.
mm_cache
[
mm_hash
]
else
:
...
...
vllm/v1/engine/processor.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
import
time
from
collections.abc
import
Mapping
from
typing
import
Optional
,
Union
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs
import
ProcessorInputs
,
PromptType
,
SingletonInputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
merge_and_sort_multimodal_metadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.structured_output.backend_guidance
import
(
validate_guidance_grammar
)
from
vllm.v1.structured_output.utils
import
(
...
...
@@ -46,6 +48,8 @@ class Processor:
self
.
tokenizer
,
mm_registry
)
self
.
mm_input_cache_client
=
MirroredProcessingCache
(
self
.
model_config
)
# Multi-modal hasher (for images)
self
.
use_hash
=
(
not
self
.
model_config
.
disable_mm_preprocessor_cache
)
or
\
...
...
@@ -73,6 +77,7 @@ class Processor:
params
:
SamplingParams
,
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
if
params
.
allowed_token_ids
is
None
:
return
...
...
@@ -83,6 +88,26 @@ class Processor:
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
def
_validate_logit_bias
(
self
,
params
:
SamplingParams
,
)
->
None
:
"""Validate logit_bias token IDs are within vocabulary range."""
if
not
params
.
logit_bias
:
return
vocab_size
=
self
.
model_config
.
get_vocab_size
()
invalid_token_ids
=
[]
for
token_id
in
params
.
logit_bias
:
if
token_id
<
0
or
token_id
>=
vocab_size
:
invalid_token_ids
.
append
(
token_id
)
if
invalid_token_ids
:
raise
ValueError
(
f
"token_id(s)
{
invalid_token_ids
}
in logit_bias contain "
f
"out-of-vocab token ids. Vocabulary size:
{
vocab_size
}
"
)
def
_validate_supported_sampling_params
(
self
,
params
:
SamplingParams
,
...
...
@@ -136,9 +161,6 @@ class Processor:
f
" !=
{
engine_level_backend
}
"
)
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
import
vllm.platforms
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
...
...
@@ -181,6 +203,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from
vllm.platforms
import
current_platform
current_platform
.
validate_request
(
prompt
=
prompt
,
params
=
params
,
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
if
priority
!=
0
:
...
...
@@ -228,7 +255,7 @@ class Processor:
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
# Multimodal related.
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]
]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
...
...
@@ -253,20 +280,28 @@ class Processor:
# are multiple modalities.
unique_modalities
=
set
(
sorted_item_modalities
)
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
orig_
sorted_mm_inputs
=
[]
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
for
modality
in
sorted_item_modalities
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
orig_sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
used_indices
[
modality
]
+=
1
else
:
sorted_mm_inputs
=
[
orig_
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
if
sorted_mm_hashes
is
not
None
:
sorted_mm_inputs
=
self
.
mm_input_cache_client
.
get_and_update_p0
(
orig_sorted_mm_inputs
,
sorted_mm_hashes
)
else
:
sorted_mm_inputs
=
orig_sorted_mm_inputs
return
EngineCoreRequest
(
request_id
=
request_id
,
prompt
=
decoder_inputs
.
get
(
"prompt"
),
...
...
@@ -285,41 +320,64 @@ class Processor:
lora_request
:
Optional
[
LoRARequest
]
=
None
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if
self
.
model_config
.
is_multimodal_model
:
prompt_inputs
=
decoder_inputs
else
:
prompt_inputs
=
encoder_inputs
or
decoder_inputs
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
)
max_allowed
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
max_token_id
if
max_input_id
>
max_allowed
:
raise
ValueError
(
"Token id {} is out of vocabulary"
.
format
(
max_input_id
))
if
encoder_inputs
is
not
None
:
self
.
_validate_model_input
(
encoder_inputs
,
lora_request
,
prompt_type
=
"encoder"
)
if
len
(
prompt_ids
)
>=
self
.
model_config
.
max_model_len
:
raise
ValueError
(
f
"Prompt length of
{
len
(
prompt_ids
)
}
is longer than the "
f
"maximum model length of
{
self
.
model_config
.
max_model_len
}
."
)
self
.
_validate_model_input
(
decoder_inputs
,
lora_request
,
prompt_type
=
"decoder"
)
if
self
.
model_config
.
is_multimodal_model
:
max_prompt_len
=
self
.
model_config
.
max_model_len
def
_validate_model_input
(
self
,
prompt_inputs
:
SingletonInputs
,
lora_request
:
Optional
[
LoRARequest
],
*
,
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
model_config
=
self
.
model_config
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
if
len
(
prompt_ids
)
>
max_prompt_len
:
raise
ValueError
(
f
"The prompt (total length
{
len
(
prompt_ids
)
}
) is too long "
f
"to fit into the model (context length
{
max_prompt_len
}
). "
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
not
prompt_ids
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
pass
# Mllama may have empty encoder inputs for text-only data
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>=
max_prompt_len
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
mm_registry
=
self
.
input_preprocessor
.
mm_registry
mm_processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
=
tokenizer
,
)
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
if
mm_processor
.
pad_dummy_encoder_prompt
:
return
# Skip encoder length check for Whisper
if
model_config
.
is_multimodal_model
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well."
)
else
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens."
)
raise
ValueError
(
f
"The
{
prompt_type
}
prompt (length
{
len
(
prompt_ids
)
}
) is "
f
"longer than the maximum model length of
{
max_prompt_len
}
. "
f
"
{
suggestion
}
"
)
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
...
...
vllm/v1/executor/multiproc_executor.py
View file @
9c4ecf15
...
...
@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
timeout
=
dequeue_timeout
)
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
if
isinstance
(
result
,
Exception
):
raise
result
else
:
raise
RuntimeError
(
"Worker failed"
)
raise
RuntimeError
(
"Worker failed with error %s, please check the"
" stack trace above for the root cause"
,
result
)
responses
[
w
.
rank
]
=
result
...
...
@@ -327,7 +326,7 @@ class WorkerProc:
logger
.
debug
(
"Worker interrupted."
)
except
Exception
:
# worker_busy_loop sends exceptions
exceptons
to Executor
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil
.
Process
().
parent
().
send_signal
(
signal
.
SIGUSR1
)
...
...
@@ -378,9 +377,11 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if
hasattr
(
e
,
"add_note"
):
e
.
add_note
(
traceback
.
format_exc
())
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
str
(
e
)))
continue
self
.
worker_response_mq
.
enqueue
(
...
...
vllm/v1/metrics/loggers.py
View file @
9c4ecf15
...
...
@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time to first token in seconds."
,
buckets
=
[
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
,
160.0
,
640.0
,
2560.0
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
...
@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time per output token in seconds."
,
buckets
=
[
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
1.0
,
2.5
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
request_latency_buckets
=
[
0.3
,
0.5
,
0.8
,
1.0
,
1.5
,
2.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
40.0
,
50.0
,
60.0
,
120.0
,
240.0
,
480.0
,
960.0
,
1920.0
,
7680.0
]
self
.
histogram_e2e_time_request
=
\
prometheus_client
.
Histogram
(
...
...
vllm/v1/request.py
View file @
9c4ecf15
...
...
@@ -3,17 +3,16 @@
import
enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_list_of
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
EngineCoreRequest
,
FinishReason
)
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
class
Request
:
...
...
@@ -23,9 +22,9 @@ class Request:
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
list
[
int
],
multi_modal_inputs
:
Optional
[
list
[
"
MultiModalKwargs
"
]],
multi_modal_inputs
:
Optional
[
list
[
MultiModalKwargs
]],
multi_modal_hashes
:
Optional
[
list
[
str
]],
multi_modal_placeholders
:
Optional
[
list
[
"
PlaceholderRange
"
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
...
...
@@ -75,6 +74,11 @@ class Request:
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
if
request
.
mm_inputs
is
not
None
:
assert
isinstance
(
request
.
mm_inputs
,
list
)
assert
is_list_of
(
request
.
mm_inputs
,
MultiModalKwargs
),
(
"mm_inputs was not updated in EngineCore.add_request"
)
return
cls
(
request_id
=
request
.
request_id
,
prompt
=
request
.
prompt
,
...
...
@@ -121,7 +125,7 @@ class Request:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
]
[
"
length
"
]
num_tokens
=
self
.
mm_positions
[
input_id
]
.
length
return
num_tokens
@
property
...
...
vllm/v1/sample/sampler.py
View file @
9c4ecf15
...
...
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# Get vocabulary size from logits
vocab_size
=
logits
.
shape
[
-
1
]
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
if
logit_bias
:
for
token_id
,
bias
in
logit_bias
.
items
():
# Check token_id bounds to ensure within vocabulary
if
token_id
<
0
or
token_id
>=
vocab_size
:
raise
ValueError
(
f
"token_id
{
token_id
}
in logit_bias contains "
f
"out-of-vocab token id. Vocabulary size: "
f
"
{
vocab_size
}
"
)
logits
[
i
,
token_id
]
+=
bias
return
logits
...
...
vllm/v1/sample/tpu/metadata.py
View file @
9c4ecf15
...
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from
typing
import
Optional
import
torch
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
...
@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
=
None
min_p
:
torch
.
Tensor
min_p
:
torch
.
Tensor
=
None
# Still too slow on forward_native!
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# Greedy sampling flag for compiling single xla graph.
all_greedy
:
torch
.
Tensor
=
None
# Generator not supported by xla
generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
all_greedy
:
bool
=
True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs
=
None
...
...
@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask
=
None
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
# Generator not supported by xla
_generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
@
property
def
generators
(
self
)
->
dict
[
int
,
torch
.
Generator
]:
# Generator not supported by torch/xla. This field must be immutable.
return
self
.
_generators
@
classmethod
def
from_input_batch
(
cls
,
input_batch
:
InputBatch
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
cls
,
input_batch
:
InputBatch
,
padded_num_reqs
:
int
,
xla_device
:
torch
.
device
,
generate_params_if_all_greedy
:
bool
=
False
)
->
"TPUSupportedSamplingMetadata"
:
"""
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
Args:
input_batch: The input batch containing sampling parameters.
padded_num_reqs: The padded number of requests.
xla_device: The XLA device.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. this is useful for cases where
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
# Early return to avoid unnecessary cpu to tpu copy
if
(
input_batch
.
all_greedy
is
True
and
generate_params_if_all_greedy
is
False
):
return
cls
(
all_greedy
=
True
)
num_reqs
=
input_batch
.
num_reqs
padded_num_reqs
=
len
(
indices_do_sample
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
Tensor
:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
def
fill_slice
(
cpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
Tensor
:
# Pad value is the default one.
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
fill_slice
(
input_batch
.
temperature_cpu_tensor
,
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
# TODO Temporarily disabled until sampling options are enabled
#
copy
_slice(input_batch.top_p_cpu_tensor
, input_batch.top_p
)
#
copy
_slice(input_batch.top_k_cpu_tensor
, input_batch.top_k
)
copy
_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
#
fill
_slice(input_batch.top_p_cpu_tensor)
#
fill
_slice(input_batch.top_k_cpu_tensor)
fill
_slice
(
input_batch
.
min_p_cpu_tensor
,
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
xm
.
mark_step
()
xm
.
wait_device_ops
()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return
cls
(
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
# Scalar tensor for xla-friendly tracing.
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
dtype
=
torch
.
bool
,
device
=
input_batch
.
device
),
temperature
=
input_batch
.
temperature_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
),
all_greedy
=
input_batch
.
all_greedy
,
# TODO enable more and avoid returning None values
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
generators
=
input_batch
.
generators
,
indices_do_sample
=
indices_do_sample
)
min_p
=
input_batch
.
min_p_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
))
vllm/v1/serial_utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
collections.abc
import
Sequence
from
inspect
import
isclass
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
numpy
as
np
import
torch
import
zmq
from
msgspec
import
msgpack
CUSTOM_TYPE_
TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_
CLOUDPICKLE
=
3
CUSTOM_TYPE_
PICKLE
=
1
CUSTOM_TYPE_
CLOUD
PICKLE
=
2
CUSTOM_TYPE_
RAW_VIEW
=
3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE
=
512
class
MsgpackEncoder
:
"""Encoder with custom torch tensor serialization."""
bytestr
=
Union
[
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
]
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
custom_enc_hook
)
def
encode
(
self
,
obj
:
Any
)
->
bytes
:
return
self
.
encoder
.
encode
(
obj
)
class
MsgpackEncoder
:
"""Encoder with custom torch tensor and numpy array serialization.
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
None
:
self
.
encoder
.
encode_into
(
obj
,
buf
)
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
self
.
enc_hook
)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self
.
aux_buffers
:
Optional
[
list
[
bytestr
]]
=
None
def
encode
(
self
,
obj
:
Any
)
->
Sequence
[
bytestr
]:
try
:
self
.
aux_buffers
=
bufs
=
[
b
''
]
bufs
[
0
]
=
self
.
encoder
.
encode
(
obj
)
# This `bufs` list allows us to collect direct pointers to backing
# buffers of tensors and np arrays, and return them along with the
# top-level encoded buffer instead of copying their data into the
# new buffer.
return
bufs
finally
:
self
.
aux_buffers
=
None
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
Sequence
[
bytestr
]:
try
:
self
.
aux_buffers
=
[
buf
]
bufs
=
self
.
aux_buffers
self
.
encoder
.
encode_into
(
obj
,
buf
)
return
bufs
finally
:
self
.
aux_buffers
=
None
def
enc_hook
(
self
,
obj
:
Any
)
->
Any
:
if
isinstance
(
obj
,
torch
.
Tensor
):
return
self
.
_encode_ndarray
(
obj
.
numpy
())
# Fall back to pickle for object or void kind ndarrays.
if
isinstance
(
obj
,
np
.
ndarray
)
and
obj
.
dtype
.
kind
not
in
(
'O'
,
'V'
):
return
self
.
_encode_ndarray
(
obj
)
if
isinstance
(
obj
,
FunctionType
):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
))
def
_encode_ndarray
(
self
,
obj
:
np
.
ndarray
)
->
tuple
[
str
,
tuple
[
int
,
...],
Union
[
int
,
memoryview
]]:
assert
self
.
aux_buffers
is
not
None
arr_data
=
obj
.
data
if
obj
.
data
.
c_contiguous
else
obj
.
tobytes
()
if
not
obj
.
shape
or
obj
.
nbytes
<
MIN_NOCOPY_BUF_SIZE
:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr_data
)
else
:
# Otherwise encode index of backing buffer to avoid copy.
data
=
len
(
self
.
aux_buffers
)
self
.
aux_buffers
.
append
(
arr_data
)
# We serialize the ndarray as a tuple of native types.
# The data is either inlined if small, or an index into a list of
# backing buffers that we've stashed in `aux_buffers`.
return
obj
.
dtype
.
str
,
obj
.
shape
,
data
class
MsgpackDecoder
:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
args
=
()
if
t
is
None
else
(
t
,
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
custom_ext_hook
)
def
decode
(
self
,
obj
:
Any
):
return
self
.
decoder
.
decode
(
obj
)
def
custom_enc_hook
(
obj
:
Any
)
->
Any
:
if
isinstance
(
obj
,
torch
.
Tensor
):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
if
code
==
CUSTOM_TYPE_TENSOR
:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
self
.
ext_hook
,
dec_hook
=
self
.
dec_hook
)
self
.
aux_buffers
:
Sequence
[
bytestr
]
=
()
def
decode
(
self
,
bufs
:
Union
[
bytestr
,
Sequence
[
bytestr
]])
->
Any
:
if
isinstance
(
bufs
,
(
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
)):
# TODO - This check can become `isinstance(bufs, bytestr)`
# as of Python 3.10.
return
self
.
decoder
.
decode
(
bufs
)
self
.
aux_buffers
=
bufs
try
:
return
self
.
decoder
.
decode
(
bufs
[
0
])
finally
:
self
.
aux_buffers
=
()
def
dec_hook
(
self
,
t
:
type
,
obj
:
Any
)
->
Any
:
# Given native types in `obj`, convert to type `t`.
if
isclass
(
t
):
if
issubclass
(
t
,
np
.
ndarray
):
return
self
.
_decode_ndarray
(
obj
)
if
issubclass
(
t
,
torch
.
Tensor
):
return
torch
.
from_numpy
(
self
.
_decode_ndarray
(
obj
))
return
obj
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
dtype
,
shape
,
data
=
arr
buffer
=
self
.
aux_buffers
[
data
]
if
isinstance
(
data
,
int
)
else
data
return
np
.
ndarray
(
buffer
=
buffer
,
dtype
=
np
.
dtype
(
dtype
),
shape
=
shape
)
def
ext_hook
(
self
,
code
:
int
,
data
:
memoryview
)
->
Any
:
if
code
==
CUSTOM_TYPE_RAW_VIEW
:
return
data
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
vllm/v1/spec_decode/eagle.py
View file @
9c4ecf15
...
...
@@ -4,8 +4,11 @@ import torch.nn as nn
import
triton
import
triton.language
as
tl
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.llama_eagle
import
EagleLlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -21,8 +24,12 @@ class EagleProposer:
self
.
num_speculative_tokens
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
,
device
=
device
)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
def
propose
(
self
,
...
...
@@ -54,7 +61,9 @@ class EagleProposer:
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids
[
last_token_indices
]
=
next_token_ids
seq_lens
=
target_positions
[
last_token_indices
]
+
1
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
...
...
@@ -98,7 +107,7 @@ class EagleProposer:
hidden_states
=
sample_hidden_states
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
input_ids
=
draft_token_ids_list
[
-
1
]
...
...
@@ -176,26 +185,28 @@ class EagleProposer:
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
self
.
model
=
DummyEagleModel
()
self
.
model
.
get_input_embeddings
=
target_model
.
get_input_embeddings
self
.
model
.
compute_logits
=
target_model
.
compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class
DummyEagleModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
input_embeddings
=
self
.
get_input_embeddings
(
input_ids
)
return
hidden_states
+
input_embeddings
# Dummy return.
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
vllm_config
.
parallel_config
)
draft_model_config
=
\
self
.
vllm_config
.
speculative_config
.
draft_model_config
# FIXME(lily): This does not handle distributed inference.
target_device
=
self
.
vllm_config
.
device_config
.
device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
self
.
vllm_config
):
self
.
model
=
EagleLlamaForCausalLM
(
model_config
=
draft_model_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
self
.
vllm_config
.
speculative_config
.
draft_model_config
,
self
.
model
))
self
.
model
.
lm_head
=
target_model
.
lm_head
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
9c4ecf15
...
...
@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
self
.
vocab_size
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
...
@@ -163,7 +164,6 @@ def validate_guidance_grammar(
tokenizer
:
Optional
[
llguidance
.
LLTokenizer
]
=
None
)
->
None
:
tp
,
grm
=
get_structured_output_key
(
sampling_params
)
guidance_grm
=
serialize_guidance_grammar
(
tp
,
grm
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
=
tokenizer
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
)
if
err
:
raise
ValueError
(
f
"Grammar error:
{
err
}
"
)
vllm/v1/structured_output/backend_xgrammar.py
View file @
9c4ecf15
...
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import
torch
import
vllm.envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
...
@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer
,
vocab_size
=
self
.
vocab_size
,
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
,
cache_enabled
=
True
,
cache_limit_bytes
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
,
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
...
vllm/v1/structured_output/utils.py
View file @
9c4ecf15
...
...
@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
return
True
# Unsupported keywords for strings
if
obj
.
get
(
"type"
)
==
"string"
and
any
(
key
in
obj
for
key
in
(
"minLength"
,
"maxLength"
,
"format"
)):
if
obj
.
get
(
"type"
)
==
"string"
and
"format"
in
obj
:
return
True
# Unsupported keywords for objects
...
...
vllm/v1/utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
os
import
weakref
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
multiprocessing
import
Process
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
...
...
@@ -115,18 +115,15 @@ class BackgroundProcHandle:
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
self
.
proc
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
proc
:
Process
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
self
.
proc
.
start
()
def wait_for_startup(self):
    """Block until the background process reports its startup status.

    Reads one message from ``self.reader`` and checks its ``"status"``
    entry.

    Raises:
        RuntimeError: If the reported status is anything but ``"READY"``.
    """
    # Wait for startup.
    status = self.reader.recv()["status"]
    if status != "READY":
        message = (f"{self.proc.name} initialization failed. "
                   "See root cause above.")
        raise RuntimeError(message)
def
fileno
(
self
):
return
self
.
proc
.
sentinel
def
shutdown
(
self
):
self
.
_finalizer
()
...
...
@@ -134,7 +131,7 @@ class BackgroundProcHandle:
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def
shutdown
(
proc
:
multiprocessing
.
Process
,
input_path
:
str
,
output_path
:
str
):
def
shutdown
(
proc
:
Process
,
input_path
:
str
,
output_path
:
str
):
# Shutdown the process.
if
proc
.
is_alive
():
proc
.
terminate
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9c4ecf15
...
...
@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
0
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
req_ids
=
self
.
input_batch
.
req_ids
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
...
@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_
mm_
encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_input_ids
.
append
((
req_id
,
input_id
))
for
mm_input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
...
...
@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
...
...
@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
...
...
@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
encoder_output
s
)
input_ids
,
mm_embed
s
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
...
...
@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
target_hidden_states
=
hidden_states
target_positions
=
positions
[:
num_scheduled_tokens
]
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
cu_num_tokens
=
attn_metadata
.
query_start_loc
else
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
9c4ecf15
...
...
@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
...
...
@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -174,10 +176,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self
.
arange_np
=
np
.
arange
(
self
.
max_num_tokens
,
dtype
=
np
.
int32
)
self
.
num_tokens_paddings
=
_get_paddings
(
self
.
num_tokens_paddings
=
_get_
token_
paddings
(
min_token_size
=
16
,
max_token_size
=
self
.
max_num_tokens
,
padding_gap
=
envs
.
VLLM_TPU_BUCKET_PADDING_GAP
)
self
.
num_reqs_paddings
=
_get_req_paddings
(
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
def
_update_num_xla_graphs
(
self
,
case_str
):
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
...
...
@@ -262,11 +266,6 @@ class TPUModelRunner:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
sampling_params
=
new_req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
else
:
generator
=
None
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
...
...
@@ -275,7 +274,7 @@ class TPUModelRunner:
mm_inputs
=
new_req_data
.
mm_inputs
,
mm_positions
=
new_req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
generator
=
generator
,
generator
=
None
,
block_ids
=
new_req_data
.
block_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
output_token_ids
=
[],
...
...
@@ -505,21 +504,48 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
,
padded_num_reqs
def
_scatter_placeholders
(
self
,
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
embeds
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
_gather_placeholders
(
self
,
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
def
_execute_mm_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_input_ids
.
append
((
req_id
,
input_id
))
for
mm_input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -555,16 +581,23 @@ class TPUModelRunner:
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -572,8 +605,8 @@ class TPUModelRunner:
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
...
...
@@ -595,8 +628,16 @@ class TPUModelRunner:
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
@
torch
.
no_grad
()
def
execute_model
(
...
...
@@ -607,25 +648,26 @@ class TPUModelRunner:
# Update cached state
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
attn_metadata
,
logits_indices
,
padded_num_reqs
=
self
.
_prepare_inputs
(
scheduler_output
)
if
self
.
is_multimodal_model
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
,
encoder_output
s
)
self
.
input_ids
,
mm_embed
s
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
input_ids
=
None
...
...
@@ -637,21 +679,19 @@ class TPUModelRunner:
input_ids
=
self
.
input_ids
inputs_embeds
=
None
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
logits_indices
)
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
kv_caches
=
self
.
kv_caches
,
inputs_embeds
=
inputs_embeds
,
)
selected_token_ids
=
self
.
model
.
sample_from_hidden
(
hidden_states
,
tpu_sampling_metadata
)
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
logits_indices
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
selected_token_ids
=
self
.
sample_from_hidden
(
hidden_states
,
tpu_sampling_metadata
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
...
...
@@ -751,17 +791,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm
.
mark_step
()
xm
.
wait_device_ops
()
model
=
ModelWrapperV1
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
self
.
model
=
model
self
.
sampler
=
TPUSampler
()
@
torch
.
no_grad
()
def
_dummy_run
(
self
,
kv_caches
,
num_tokens
:
int
)
->
None
:
def
_dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
if
self
.
is_multimodal_model
:
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
...
...
@@ -812,65 +850,81 @@ class TPUModelRunner:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
self
.
_hidden_states_dtype
=
out
.
dtype
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
def
_precompile_backbone
(
self
)
->
None
:
logger
.
info
(
"Compiling the model with different input shapes."
)
start
=
time
.
perf_counter
()
for
num_tokens
in
self
.
num_tokens_paddings
:
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
xm
.
mark_step
()
self
.
_dummy_run
(
num_tokens
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"model"
)
self
.
_update_num_xla_graphs
(
"model
backbone
"
)
logger
.
info
(
"Compiling sampling with different input shapes."
)
def
_precompile_select_hidden_states
(
self
)
->
None
:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger
.
info
(
"Compiling select_hidden_states with different input shapes."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
device
=
self
.
device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for
num_tokens
in
self
.
num_tokens_paddings
:
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
dummy_hidden
=
torch
.
zeros
((
num_tokens
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while
True
:
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
device
=
device
,
)
xm
.
mark_step
()
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
indices
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
out
=
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
out
=
out
.
cpu
()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if
num_reqs_to_sample
>=
min
(
num_tokens
,
self
.
max_num_reqs
):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
torch
.
_dynamo
.
mark_dynamic
(
dummy_hidden
,
0
)
for
num_reqs
in
self
.
num_reqs_paddings
:
indices
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
_dynamo
.
mark_dynamic
(
indices
,
0
)
self
.
select_hidden_states
(
dummy_hidden
,
indices
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"select_hidden_states"
)
def
_precompile_sample_from_hidden
(
self
)
->
None
:
logger
.
info
(
"Compiling sampling with different input shapes."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_hidden
=
torch
.
zeros
((
num_reqs
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
for
all_greedy
in
[
False
,
True
]:
generate_params_if_all_greedy
=
not
all_greedy
sampling_metadata
=
(
TPUSupportedSamplingMetadata
.
from_input_batch
(
self
.
input_batch
,
num_reqs
,
self
.
device
,
generate_params_if_all_greedy
,
))
sampling_metadata
.
all_greedy
=
all_greedy
self
.
sample_from_hidden
(
dummy_hidden
,
sampling_metadata
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sampling"
)
def capture_model(self) -> None:
    """
    Run every precompilation step, one after another.

    The steps are, in order: the backbone, the hidden-state selection and
    the sampling-from-hidden precompile methods of this object.
    """
    # TODO: precompile encoder
    precompile_steps = (
        "_precompile_backbone",
        "_precompile_select_hidden_states",
        "_precompile_sample_from_hidden",
    )
    for step_name in precompile_steps:
        # Look each step up lazily so they run strictly in sequence.
        getattr(self, step_name)()
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -910,73 +964,39 @@ class TPUModelRunner:
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
class
ModelWrapperV1
(
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
self
.
sampler
=
TPUSampler
()
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
SamplerOutput
:
sampler_out
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
sampler_out
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def reset_dynamo_cache(self):
    """Drop cached dynamo state for the wrapped inner model, if any.

    The inner model is ``self.model.get_language_model().model`` for
    multimodal models and ``self.model.model`` otherwise. Nothing happens
    unless that object is a ``TorchCompileWrapperWithCustomDispatcher``.
    """
    if self.is_multimodal_model:
        owner = self.model.get_language_model()
    else:
        owner = self.model
    compiled_model = owner.model

    if not isinstance(compiled_model,
                      TorchCompileWrapperWithCustomDispatcher):
        return

    logger.info("Clear dynamo cache and cached dynamo bytecode.")
    original_code = compiled_model.original_code_object
    torch._dynamo.eval_frame.remove_from_cache(original_code)
    compiled_model.compiled_codes.clear()
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
select_hidden_states
(
self
,
hidden_states
,
indices_do_sample
):
return
hidden_states
[
indices_do_sample
]
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
sample_from_hidden
(
self
,
hidden_states
:
torch
.
Tensor
,
sample_
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
)
->
torch
.
Tensor
:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states
=
\
hidden_states
[
sampling_metadata
.
indices_do_sample
]
logits
=
self
.
compute_logits
(
sample_hidden_states
)
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
all_greedy
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
sampling_metadata
.
all_greedy
:
out_tokens
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
else
:
out_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
).
sampled_token_ids
return
out_tokens
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
return
logits
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
...
...
@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
_get_padded_number
(
n
:
int
,
multiple
:
int
)
->
int
:
return
((
n
+
multiple
-
1
)
//
multiple
)
*
multiple
def
_get_req_paddings
(
min_req_size
:
int
,
max_req_size
:
int
)
->
list
[
int
]:
logger
.
info
(
"Preparing request paddings:"
)
# assert min_req_size is power of 2
assert
(
min_req_size
&
(
min_req_size
-
1
)
==
0
)
and
min_req_size
>
0
paddings
:
list
=
[]
num
=
max
(
MIN_NUM_SEQS
,
min_req_size
)
while
num
<=
max_req_size
and
(
len
(
paddings
)
==
0
or
paddings
[
-
1
]
!=
num
):
paddings
.
append
(
num
)
logger
.
info
(
" %d"
,
num
)
num
=
_get_padded_num_reqs_with_upper_limit
(
num
+
1
,
max_req_size
)
return
paddings
def
_get_padded_num_reqs_with_upper_limit
(
x
,
upper_limit
)
->
int
:
def
_get_padded_num_reqs_with_upper_limit
(
x
:
int
,
upper_limit
:
int
)
->
int
:
res
=
MIN_NUM_SEQS
if
x
<=
MIN_NUM_SEQS
else
1
<<
(
x
-
1
).
bit_length
()
return
min
(
res
,
upper_limit
)
def
_get_paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
def
_get_
token_
paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
...
...
@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
assert
(
min_token_size
&
(
min_token_size
-
1
)
==
0
)
and
min_token_size
>
0
paddings
=
[]
num
=
min_token_size
if
padding_gap
==
0
:
logger
.
info
(
"Using exponential paddings:"
)
logger
.
info
(
"Using exponential
token
paddings:"
)
while
num
<=
max_token_size
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
num
*=
2
else
:
logger
.
info
(
"Using incremental paddings:"
)
logger
.
info
(
"Using incremental
token
paddings:"
)
while
num
<=
padding_gap
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
...
...
vllm/v1/worker/tpu_worker.py
View file @
9c4ecf15
...
...
@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches
)
self
.
model_runner
.
_dummy_run
(
runner_kv_caches
,
num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
,
)
self
.
scheduler_config
.
max_num_batched_tokens
)
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self
.
model_runner
.
reset_dynamo_cache
()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
...
...
vllm/v1/worker/utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
...
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method."
)
def
scatter_mm_placeholders
(
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if
is_embed
is
None
:
return
embeds
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
gather_mm_placeholders
(
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of :func:`scatter_mm_placeholders`.
"""
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
vllm/worker/enc_dec_model_runner.py
View file @
9c4ecf15
...
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
...
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
...
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables
=
None
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
cross_block_table
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
or
encoder_dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
decoder_dummy_data
.
...
...
vllm/worker/hpu_model_runner.py
View file @
9c4ecf15
...
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
...
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SequenceGroupToSample
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
make_tensor_with_pad
)
from
vllm.worker.model_runner_base
import
(
...
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
if
to_override
is
None
:
to_override
=
{}
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
type
(
obj
)
is
dict
:
values
=
{
key
:
obj
[
key
]
for
key
in
fields
if
key
in
obj
}
else
:
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
typename
not
in
_TYPE_CACHE
:
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
' '
.
join
(
fields
))
...
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine
:
int
=
0
lora_ids
:
Optional
[
List
[
int
]]
=
None
async_callback
:
Optional
[
Callable
]
=
None
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
...
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded"
:
self
.
batch_size_padded
,
"virtual_engine"
:
self
.
virtual_engine
,
"lora_ids"
:
self
.
lora_ids
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
...
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
_set_gc_threshold
()
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
def
_set_gc_threshold
(
self
)
->
None
:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
...
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
output
=
None
,
)
->
PrepareDecodeMetadata
:
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
...
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
if
output
is
None
:
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
...
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
num_fully_occupied_blocks
=
position
//
self
.
block_size
block_table
=
block_table
[:
num_fully_occupied_blocks
+
1
]
if
len
(
block_table
)
==
0
:
block_number
=
_PAD_BLOCK_ID
else
:
...
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
output
is
None
:
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
else
:
real_batch_size
=
len
(
seq_group_metadata_list
)
input_tokens
=
output
[:
real_batch_size
]
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
...
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler
.
start
()
for
_
in
range
(
times
):
inputs
=
self
.
prepare_model_input
(
seqs
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
is_single_step
=
\
self
.
vllm_config
.
scheduler_config
.
num_scheduler_steps
==
1
if
is_prompt
or
is_single_step
:
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
else
:
# decode with multi-step
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
True
,
is_last_step
=
False
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
False
,
is_last_step
=
True
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
torch
.
hpu
.
synchronize
()
if
profiler
:
profiler
.
step
()
...
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
warmup_mode
=
False
,
seqs
=
None
,
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in HPUModelRunner"
)
if
not
model_input
.
is_first_multi_step
:
if
not
model_input
.
is_last_step
:
# not first or last multi-step
return
[]
# last multi-step
output
=
self
.
_decode_sampler_outputs
(
model_input
)
if
self
.
is_driver_worker
else
[]
torch
.
hpu
.
synchronize
()
if
model_input
.
is_first_multi_step
:
# first multi-step
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
(
{
"bypass_hpu_graphs"
:
not
use_graphs
})
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
:
model_event_name
=
(
"model_"
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
f
"bs
{
batch_size
}
_"
f
"seq
{
seq_len
}
_"
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
else
:
model_event_name
=
'model_executable'
if
num_steps
>
1
:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata
.
skip_sampler_cpu_output
=
True
self
.
model
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
cache_orig_output_tokens_len
:
List
[
Dict
]
=
[]
def
try_revert_dummy_output_tokens
():
if
len
(
cache_orig_output_tokens_len
)
>
0
:
# Reuse the original output token ids length
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
orig_output_tokens_len
=
\
cache_orig_output_tokens_len
[
i
][
j
]
data
.
output_token_ids
=
\
data
.
output_token_ids
[:
orig_output_tokens_len
]
for
i
in
range
(
num_steps
):
if
i
!=
0
and
not
self
.
is_driver_worker
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
'early_exit'
in
broadcast_data
and
broadcast_data
[
'early_exit'
]:
return
[
output
]
if
num_steps
==
1
else
[]
execute_model_kwargs
.
update
({
"input_ids"
:
broadcast_data
[
"input_ids"
],
"positions"
:
broadcast_data
[
"positions"
],
"attn_metadata"
:
self
.
trim_attn_metadata
(
broadcast_data
[
"attn_metadata"
])
})
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
if
num_steps
==
1
:
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
continue
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
({
"bypass_hpu_graphs"
:
not
use_graphs
})
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
:
model_event_name
=
(
"model_"
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
f
"bs
{
batch_size
}
_"
f
"seq
{
seq_len
}
_"
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
if
num_steps
>
1
:
output
=
output
.
sampled_token_ids
self
.
cached_step_outputs
.
append
(
output
.
detach
().
clone
())
htorch
.
core
.
mark_step
()
if
i
<
num_steps
-
1
:
if
i
==
0
:
if
model_input
.
async_callback
is
not
None
:
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
seq_group_metadata_list
=
\
ctx
.
seq_group_metadata_list
elif
seqs
is
not
None
:
seq_group_metadata_list
=
seqs
else
:
raise
RuntimeError
(
"seq_group_metadata_list is uninitialized"
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# Skip empty steps
seq_group_metadata
.
state
.
current_step
+=
(
num_steps
-
2
)
# Cache the original output token ids
cache_orig_output_tokens_len
.
append
({})
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
cache_orig_output_tokens_len
[
i
][
j
]
=
\
len
(
data
.
output_token_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
data
in
seq_group_metadata
.
seq_data
.
values
():
max_output_len
=
sampling_metadata
.
seq_groups
[
0
].
sampling_params
.
max_tokens
if
len
(
data
.
output_token_ids
)
<
max_output_len
-
1
:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token
=
(
540
,
)
data
.
output_token_ids
+=
(
dummy_token
)
else
:
broadcast_tensor_dict
({
'early_exit'
:
True
},
src
=
0
)
if
num_steps
==
1
:
return
[
output
]
else
:
try_revert_dummy_output_tokens
()
return
[]
result
=
self
.
_prepare_decode
(
seq_group_metadata_list
,
output
=
output
)
execute_model_kwargs
.
update
({
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
result
.
attn_metadata
)
})
model_kwargs_broadcast_data
=
{
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
vars
(
result
.
attn_metadata
)
}
broadcast_tensor_dict
(
model_kwargs_broadcast_data
,
src
=
0
)
else
:
try_revert_dummy_output_tokens
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
if
num_steps
==
1
:
return
[
output
]
if
self
.
is_driver_worker
else
[]
else
:
return
[]
return
output
if
type
(
output
)
is
list
else
[
output
]
def
_decode_sampler_outputs
(
self
,
model_input
):
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
self
.
_make_decode_output
(
next_token_ids
,
model_input
.
sampling_metadata
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
,
is_first_step_output
=
False
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
model_event_name
=
'model_executable'
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
return
sampler_outputs
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
output
.
outputs
=
output
.
outputs
[:
real_batch_size
]
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
return
[
output
]
def
_make_decode_output
(
self
,
next_token_ids
:
List
[
List
[
int
]],
seq_groups
:
List
[
SequenceGroupToSample
],
)
->
SamplerOutput
:
zero_logprob
=
Logprob
(
0.0
)
sampler_outputs
=
[]
batch_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_outputs
=
[]
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
batch_idx
][
0
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
SamplerOutput
(
sampler_outputs
)
def
shutdown_inc
(
self
):
can_finalize_inc
=
False
...
...
Prev
1
…
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment