Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1065 additions
and
507 deletions
+1065
-507
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+57
-19
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+34
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+97
-39
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+8
-7
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+4
-3
vllm/v1/request.py
vllm/v1/request.py
+10
-6
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+10
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+43
-46
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+125
-41
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+36
-25
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+3
-3
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+7
-1
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+1
-2
vllm/v1/utils.py
vllm/v1/utils.py
+7
-10
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+50
-32
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+180
-149
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+9
-3
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+45
-0
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-1
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+314
-113
No files found.
vllm/v1/engine/core_client.py
View file @
9c4ecf15
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
BackgroundProcHandle
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -402,6 +402,36 @@ class MPClient(EngineCoreClient):
...
@@ -402,6 +402,36 @@ class MPClient(EngineCoreClient):
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def
_wait_for_engine_startup
(
self
):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket
=
zmq
.
Socket
.
shadow
(
self
.
input_socket
)
# Wait for engine core process(es) to send ready messages.
identities
=
set
(
eng
.
index
for
eng
in
self
.
resources
.
core_engines
)
poller
=
zmq
.
Poller
()
poller
.
register
(
sync_input_socket
,
zmq
.
POLLIN
)
for
eng
in
self
.
resources
.
core_engines
:
poller
.
register
(
eng
.
proc_handle
,
zmq
.
POLLIN
)
while
identities
:
events
=
poller
.
poll
(
STARTUP_POLL_PERIOD_MS
)
if
not
events
:
logger
.
debug
(
"Waiting for %d core engine proc(s) to start: %s"
,
len
(
identities
),
identities
)
continue
if
len
(
events
)
>
1
or
events
[
0
][
0
]
!=
sync_input_socket
:
# One of the core processes exited.
raise
RuntimeError
(
"Engine core initialization failed. "
"See root cause above."
)
eng_id_bytes
,
msg
=
sync_input_socket
.
recv_multipart
()
eng_id
=
int
.
from_bytes
(
eng_id_bytes
,
byteorder
=
"little"
)
if
eng_id
not
in
identities
:
raise
RuntimeError
(
f
"Unexpected or duplicate engine:
{
eng_id
}
"
)
if
msg
!=
b
'READY'
:
raise
RuntimeError
(
f
"Engine
{
eng_id
}
failed:
{
msg
.
decode
()
}
"
)
logger
.
info
(
"Core engine process %d ready."
,
eng_id
)
identities
.
discard
(
eng_id
)
def
_init_core_engines
(
def
_init_core_engines
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
@@ -472,8 +502,8 @@ class SyncMPClient(MPClient):
...
@@ -472,8 +502,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
# shutdown signal, exit thread.
break
break
frame
=
out_socket
.
recv
(
copy
=
False
)
frame
s
=
out_socket
.
recv
_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
...
@@ -494,10 +524,10 @@ class SyncMPClient(MPClient):
...
@@ -494,10 +524,10 @@ class SyncMPClient(MPClient):
return
self
.
outputs_queue
.
get
()
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
# (RequestType, SerializedRequest)
# (
Identity,
RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
self
.
core_engine
.
send_multipart
(
msg
)
*
self
.
encoder
.
encode
(
request
)
)
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
future
:
Future
[
Any
]
=
Future
()
future
:
Future
[
Any
]
=
Future
()
...
@@ -599,8 +629,8 @@ class AsyncMPClient(MPClient):
...
@@ -599,8 +629,8 @@ class AsyncMPClient(MPClient):
async
def
process_outputs_socket
():
async
def
process_outputs_socket
():
while
True
:
while
True
:
(
frame
,
)
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
frame
s
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
...
@@ -625,12 +655,20 @@ class AsyncMPClient(MPClient):
...
@@ -625,12 +655,20 @@ class AsyncMPClient(MPClient):
assert
self
.
outputs_queue
is
not
None
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
def
_send_input
(
self
,
request
:
Any
)
->
None
:
request_type
:
EngineCoreRequestType
,
await
self
.
core_engine
.
send_multipart
(
request
:
Any
,
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
)))
engine
:
Optional
[
CoreEngine
]
=
None
)
->
Awaitable
[
None
]:
if
engine
is
None
:
engine
=
self
.
core_engine
self
.
_ensure_output_queue_task
()
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
return
self
.
_send_input_message
(
message
,
engine
)
def
_send_input_message
(
self
,
message
:
tuple
[
bytestr
,
...],
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
message
=
(
engine
.
identity
,
)
+
message
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
return
await
self
.
_call_utility_async
(
method
,
return
await
self
.
_call_utility_async
(
method
,
...
@@ -646,9 +684,9 @@ class AsyncMPClient(MPClient):
...
@@ -646,9 +684,9 @@ class AsyncMPClient(MPClient):
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
(
call_id
,
method
,
args
)))
await
engine
.
send_multipart
(
messag
e
)
await
self
.
_send_input_message
(
message
,
engin
e
)
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
return
await
future
return
await
future
...
@@ -721,7 +759,7 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -721,7 +759,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
*
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
...
@@ -755,7 +793,7 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -755,7 +793,7 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
# tokenized.
request
.
prompt
=
None
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
*
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
...
...
vllm/v1/engine/mm_input_cache.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Sequence
from
typing
import
Optional
from
vllm.envs
import
VLLM_MM_INPUT_CACHE_GIB
from
vllm.envs
import
VLLM_MM_INPUT_CACHE_GIB
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.processing
import
ProcessingCache
from
vllm.multimodal.processing
import
ProcessingCache
from
vllm.utils
import
is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# a server, where the client executes in the frontend process (=P0) and the
...
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
...
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
#
# -- Server:
# -- Server:
# - M
MInputCacheServer to perform caching of the received
MultiModalKwargs.
# - M
irroredProcessingCache to store the
MultiModalKwargs
from P0
.
#
#
# The caching for both client and server is mirrored, and this allows us
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# to avoid the serialization of "mm_inputs" (like pixel values) between
...
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
...
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
# variable VLLM_MM_INPUT_CACHE_GIB.
class
M
MInputCacheServer
:
class
M
irroredProcessingCache
:
def
__init__
(
self
,
model_config
):
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
ProcessingCache
.
get_lru_cache
(
VLLM_MM_INPUT_CACHE_GIB
,
self
.
mm_cache
=
ProcessingCache
.
get_lru_cache
(
VLLM_MM_INPUT_CACHE_GIB
,
MultiModalKwargs
)
MultiModalKwargs
)
def
get_and_update
(
def
get_and_update
_p0
(
self
,
self
,
mm_inputs
:
list
[
MultiModalKwargs
],
mm_inputs
:
Sequence
[
MultiModalKwargs
],
mm_hashes
:
list
[
str
],
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargs
]:
)
->
Sequence
[
Optional
[
MultiModalKwargs
]
]
:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
return
mm_inputs
full_mm_inputs
=
[]
full_mm_inputs
=
list
[
Optional
[
MultiModalKwargs
]]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
if
mm_hash
in
self
.
mm_cache
:
mm_input
=
None
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_input
full_mm_inputs
.
append
(
mm_input
)
return
full_mm_inputs
def
get_and_update_p1
(
self
,
mm_inputs
:
Sequence
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
list
[
str
],
)
->
Sequence
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
assert
is_list_of
(
mm_inputs
,
MultiModalKwargs
)
return
mm_inputs
full_mm_inputs
=
list
[
MultiModalKwargs
]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
if
mm_input
is
None
:
if
mm_input
is
None
:
mm_input
=
self
.
mm_cache
[
mm_hash
]
mm_input
=
self
.
mm_cache
[
mm_hash
]
else
:
else
:
...
...
vllm/v1/engine/processor.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
time
import
time
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Optional
,
Union
from
typing
import
Literal
,
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs
import
ProcessorInputs
,
PromptType
,
SingletonInputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
MultiModalRegistry
)
MultiModalRegistry
)
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
merge_and_sort_multimodal_metadata
from
vllm.multimodal.utils
import
merge_and_sort_multimodal_metadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.structured_output.backend_guidance
import
(
from
vllm.v1.structured_output.backend_guidance
import
(
validate_guidance_grammar
)
validate_guidance_grammar
)
from
vllm.v1.structured_output.utils
import
(
from
vllm.v1.structured_output.utils
import
(
...
@@ -46,6 +48,8 @@ class Processor:
...
@@ -46,6 +48,8 @@ class Processor:
self
.
tokenizer
,
self
.
tokenizer
,
mm_registry
)
mm_registry
)
self
.
mm_input_cache_client
=
MirroredProcessingCache
(
self
.
model_config
)
# Multi-modal hasher (for images)
# Multi-modal hasher (for images)
self
.
use_hash
=
(
self
.
use_hash
=
(
not
self
.
model_config
.
disable_mm_preprocessor_cache
)
or
\
not
self
.
model_config
.
disable_mm_preprocessor_cache
)
or
\
...
@@ -73,6 +77,7 @@ class Processor:
...
@@ -73,6 +77,7 @@ class Processor:
params
:
SamplingParams
,
params
:
SamplingParams
,
)
->
None
:
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
if
params
.
allowed_token_ids
is
None
:
if
params
.
allowed_token_ids
is
None
:
return
return
...
@@ -83,6 +88,26 @@ class Processor:
...
@@ -83,6 +88,26 @@ class Processor:
raise
ValueError
(
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
"allowed_token_ids contains out-of-vocab token id!"
)
def
_validate_logit_bias
(
self
,
params
:
SamplingParams
,
)
->
None
:
"""Validate logit_bias token IDs are within vocabulary range."""
if
not
params
.
logit_bias
:
return
vocab_size
=
self
.
model_config
.
get_vocab_size
()
invalid_token_ids
=
[]
for
token_id
in
params
.
logit_bias
:
if
token_id
<
0
or
token_id
>=
vocab_size
:
invalid_token_ids
.
append
(
token_id
)
if
invalid_token_ids
:
raise
ValueError
(
f
"token_id(s)
{
invalid_token_ids
}
in logit_bias contain "
f
"out-of-vocab token ids. Vocabulary size:
{
vocab_size
}
"
)
def
_validate_supported_sampling_params
(
def
_validate_supported_sampling_params
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
...
@@ -136,9 +161,6 @@ class Processor:
...
@@ -136,9 +161,6 @@ class Processor:
f
" !=
{
engine_level_backend
}
"
)
f
" !=
{
engine_level_backend
}
"
)
else
:
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
params
.
guided_decoding
.
backend
=
engine_level_backend
import
vllm.platforms
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
if
engine_level_backend
.
startswith
(
"xgrammar"
):
...
@@ -181,6 +203,11 @@ class Processor:
...
@@ -181,6 +203,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
# TODO(woosuk): Support encoder-decoder models.
from
vllm.platforms
import
current_platform
current_platform
.
validate_request
(
prompt
=
prompt
,
params
=
params
,
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
self
.
_validate_params
(
params
)
if
priority
!=
0
:
if
priority
!=
0
:
...
@@ -228,7 +255,7 @@ class Processor:
...
@@ -228,7 +255,7 @@ class Processor:
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
# Multimodal related.
# Multimodal related.
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]
]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
...
@@ -253,20 +280,28 @@ class Processor:
...
@@ -253,20 +280,28 @@ class Processor:
# are multiple modalities.
# are multiple modalities.
unique_modalities
=
set
(
sorted_item_modalities
)
unique_modalities
=
set
(
sorted_item_modalities
)
if
len
(
unique_modalities
)
>
1
:
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
orig_
sorted_mm_inputs
=
[]
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
for
modality
in
sorted_item_modalities
:
for
modality
in
sorted_item_modalities
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
items
=
decoder_mm_inputs
.
get_items
(
modality
)
item
=
items
[
used_indices
[
modality
]]
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
orig_sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
]))
used_indices
[
modality
]
+=
1
used_indices
[
modality
]
+=
1
else
:
else
:
sorted_mm_inputs
=
[
orig_
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
]
if
sorted_mm_hashes
is
not
None
:
sorted_mm_inputs
=
self
.
mm_input_cache_client
.
get_and_update_p0
(
orig_sorted_mm_inputs
,
sorted_mm_hashes
)
else
:
sorted_mm_inputs
=
orig_sorted_mm_inputs
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
decoder_inputs
.
get
(
"prompt"
),
prompt
=
decoder_inputs
.
get
(
"prompt"
),
...
@@ -285,41 +320,64 @@ class Processor:
...
@@ -285,41 +320,64 @@ class Processor:
lora_request
:
Optional
[
LoRARequest
]
=
None
):
lora_request
:
Optional
[
LoRARequest
]
=
None
):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
# For encoder-decoder multimodal models, the max_prompt_len
if
encoder_inputs
is
not
None
:
# restricts the decoder prompt length
self
.
_validate_model_input
(
encoder_inputs
,
if
self
.
model_config
.
is_multimodal_model
:
lora_request
,
prompt_inputs
=
decoder_inputs
prompt_type
=
"encoder"
)
else
:
prompt_inputs
=
encoder_inputs
or
decoder_inputs
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
prompt_ids
is
None
or
len
(
prompt_ids
)
==
0
:
raise
ValueError
(
"Prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
)
max_allowed
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
max_token_id
if
max_input_id
>
max_allowed
:
raise
ValueError
(
"Token id {} is out of vocabulary"
.
format
(
max_input_id
))
if
len
(
prompt_ids
)
>=
self
.
model_config
.
max_model_len
:
self
.
_validate_model_input
(
decoder_inputs
,
raise
ValueError
(
lora_request
,
f
"Prompt length of
{
len
(
prompt_ids
)
}
is longer than the "
prompt_type
=
"decoder"
)
f
"maximum model length of
{
self
.
model_config
.
max_model_len
}
."
)
if
self
.
model_config
.
is_multimodal_model
:
def
_validate_model_input
(
max_prompt_len
=
self
.
model_config
.
max_model_len
self
,
prompt_inputs
:
SingletonInputs
,
lora_request
:
Optional
[
LoRARequest
],
*
,
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
model_config
=
self
.
model_config
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
if
len
(
prompt_ids
)
>
max_prompt_len
:
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
raise
ValueError
(
if
not
prompt_ids
:
f
"The prompt (total length
{
len
(
prompt_ids
)
}
) is too long "
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
f
"to fit into the model (context length
{
max_prompt_len
}
). "
pass
# Mllama may have empty encoder inputs for text-only data
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>=
max_prompt_len
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
mm_registry
=
self
.
input_preprocessor
.
mm_registry
mm_processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
=
tokenizer
,
)
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
if
mm_processor
.
pad_dummy_encoder_prompt
:
return
# Skip encoder length check for Whisper
if
model_config
.
is_multimodal_model
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well."
)
"of images, and possibly their aspect ratios as well."
)
else
:
suggestion
=
(
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens."
)
raise
ValueError
(
f
"The
{
prompt_type
}
prompt (length
{
len
(
prompt_ids
)
}
) is "
f
"longer than the maximum model length of
{
max_prompt_len
}
. "
f
"
{
suggestion
}
"
)
# TODO: Find out how many placeholder tokens are there so we can
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# check that chunked prefill does not truncate them
...
...
vllm/v1/executor/multiproc_executor.py
View file @
9c4ecf15
...
@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
...
@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
timeout
=
dequeue_timeout
)
timeout
=
dequeue_timeout
)
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
if
status
!=
WorkerProc
.
ResponseStatus
.
SUCCESS
:
if
isinstance
(
result
,
Exception
):
raise
RuntimeError
(
raise
result
"Worker failed with error %s, please check the"
else
:
" stack trace above for the root cause"
,
result
)
raise
RuntimeError
(
"Worker failed"
)
responses
[
w
.
rank
]
=
result
responses
[
w
.
rank
]
=
result
...
@@ -327,7 +326,7 @@ class WorkerProc:
...
@@ -327,7 +326,7 @@ class WorkerProc:
logger
.
debug
(
"Worker interrupted."
)
logger
.
debug
(
"Worker interrupted."
)
except
Exception
:
except
Exception
:
# worker_busy_loop sends exceptions
exceptons
to Executor
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
# error with IPC itself, we need to alert the parent.
psutil
.
Process
().
parent
().
send_signal
(
signal
.
SIGUSR1
)
psutil
.
Process
().
parent
().
send_signal
(
signal
.
SIGUSR1
)
...
@@ -378,9 +377,11 @@ class WorkerProc:
...
@@ -378,9 +377,11 @@ class WorkerProc:
# Notes have been introduced in python 3.11
# Notes have been introduced in python 3.11
if
hasattr
(
e
,
"add_note"
):
if
hasattr
(
e
,
"add_note"
):
e
.
add_note
(
traceback
.
format_exc
())
e
.
add_note
(
traceback
.
format_exc
())
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
str
(
e
)))
continue
continue
self
.
worker_response_mq
.
enqueue
(
self
.
worker_response_mq
.
enqueue
(
...
...
vllm/v1/metrics/loggers.py
View file @
9c4ecf15
...
@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time to first token in seconds."
,
documentation
=
"Histogram of time to first token in seconds."
,
buckets
=
[
buckets
=
[
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
,
160.0
,
640.0
,
2560.0
],
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Histogram of time per output token in seconds."
,
documentation
=
"Histogram of time per output token in seconds."
,
buckets
=
[
buckets
=
[
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
1.0
,
2.5
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
],
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
request_latency_buckets
=
[
request_latency_buckets
=
[
0.3
,
0.5
,
0.8
,
1.0
,
1.5
,
2.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
0.3
,
0.5
,
0.8
,
1.0
,
1.5
,
2.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
40.0
,
50.0
,
60.0
,
120.0
,
240.0
,
480.0
,
960.0
,
1920.0
,
7680.0
]
]
self
.
histogram_e2e_time_request
=
\
self
.
histogram_e2e_time_request
=
\
prometheus_client
.
Histogram
(
prometheus_client
.
Histogram
(
...
...
vllm/v1/request.py
View file @
9c4ecf15
...
@@ -3,17 +3,16 @@
...
@@ -3,17 +3,16 @@
import
enum
import
enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_list_of
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
EngineCoreRequest
,
FinishReason
)
EngineCoreRequest
,
FinishReason
)
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
vllm.v1.utils
import
ConstantList
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
class
Request
:
class
Request
:
...
@@ -23,9 +22,9 @@ class Request:
...
@@ -23,9 +22,9 @@ class Request:
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
],
multi_modal_inputs
:
Optional
[
list
[
"
MultiModalKwargs
"
]],
multi_modal_inputs
:
Optional
[
list
[
MultiModalKwargs
]],
multi_modal_hashes
:
Optional
[
list
[
str
]],
multi_modal_hashes
:
Optional
[
list
[
str
]],
multi_modal_placeholders
:
Optional
[
list
[
"
PlaceholderRange
"
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
arrival_time
:
float
,
...
@@ -75,6 +74,11 @@ class Request:
...
@@ -75,6 +74,11 @@ class Request:
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
if
request
.
mm_inputs
is
not
None
:
assert
isinstance
(
request
.
mm_inputs
,
list
)
assert
is_list_of
(
request
.
mm_inputs
,
MultiModalKwargs
),
(
"mm_inputs was not updated in EngineCore.add_request"
)
return
cls
(
return
cls
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
prompt
=
request
.
prompt
,
prompt
=
request
.
prompt
,
...
@@ -121,7 +125,7 @@ class Request:
...
@@ -121,7 +125,7 @@ class Request:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
]
[
"
length
"
]
num_tokens
=
self
.
mm_positions
[
input_id
]
.
length
return
num_tokens
return
num_tokens
@
property
@
property
...
...
vllm/v1/sample/sampler.py
View file @
9c4ecf15
...
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
...
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
# TODO(houseroad): this implementation is extremely inefficient.
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# even optimize the logit_bias layout.
# Get vocabulary size from logits
vocab_size
=
logits
.
shape
[
-
1
]
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
if
logit_bias
:
if
logit_bias
:
for
token_id
,
bias
in
logit_bias
.
items
():
for
token_id
,
bias
in
logit_bias
.
items
():
# Check token_id bounds to ensure within vocabulary
if
token_id
<
0
or
token_id
>=
vocab_size
:
raise
ValueError
(
f
"token_id
{
token_id
}
in logit_bias contains "
f
"out-of-vocab token id. Vocabulary size: "
f
"
{
vocab_size
}
"
)
logits
[
i
,
token_id
]
+=
bias
logits
[
i
,
token_id
]
+=
bias
return
logits
return
logits
...
...
vllm/v1/sample/tpu/metadata.py
View file @
9c4ecf15
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
...
@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
# are allowed, to avoid graph recompilation on Nones.
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
=
None
min_p
:
torch
.
Tensor
min_p
:
torch
.
Tensor
=
None
# Still too slow on forward_native!
# Still too slow on forward_native!
top_k
:
torch
.
Tensor
=
None
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# Greedy sampling flag for compiling single xla graph.
# Greedy sampling flag for compiling single xla graph.
all_greedy
:
torch
.
Tensor
=
None
all_greedy
:
bool
=
True
# Generator not supported by xla
generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
# unsupported, you need to return an extra tensor of static size BxV
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs
=
None
max_num_logprobs
=
None
...
@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
...
@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask
=
None
allowed_token_ids_mask
=
None
bad_words_token_ids
=
None
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
# Generator not supported by xla
_generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
@
property
def
generators
(
self
)
->
dict
[
int
,
torch
.
Generator
]:
# Generator not supported by torch/xla. This field must be immutable.
return
self
.
_generators
@
classmethod
@
classmethod
def
from_input_batch
(
def
from_input_batch
(
cls
,
input_batch
:
InputBatch
,
cls
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
input_batch
:
InputBatch
,
padded_num_reqs
:
int
,
xla_device
:
torch
.
device
,
generate_params_if_all_greedy
:
bool
=
False
)
->
"TPUSupportedSamplingMetadata"
:
"""
"""
Copy sampling tensors slices from `input_batch` to on device tensors.
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
Args:
input_batch: The input batch containing sampling parameters.
`indices_do_sample` contains the indices to be fed to the Sampler,
padded_num_reqs: The padded number of requests.
normally one per request, here padded to the closest pre-compiled shape
xla_device: The XLA device.
We expect sampling params tensors to be padded to the same fixed shape.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. this is useful for cases where
Eg. 3 requests, tensors padded to 4
we want to pre-compile a graph with sampling parameters, even if
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
they are not strictly needed for greedy decoding.
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
"""
"""
# Early return to avoid unnecessary cpu to tpu copy
if
(
input_batch
.
all_greedy
is
True
and
generate_params_if_all_greedy
is
False
):
return
cls
(
all_greedy
=
True
)
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
padded_num_reqs
=
len
(
indices_do_sample
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
def
fill_slice
(
cpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
Tensor
:
fill_val
)
->
torch
.
Tensor
:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
# Pad value is the default one.
# Pad value is the default one.
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
fill_slice
(
input_batch
.
temperature_cpu_tensor
,
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
# TODO Temporarily disabled until sampling options are enabled
# TODO Temporarily disabled until sampling options are enabled
#
copy
_slice(input_batch.top_p_cpu_tensor
, input_batch.top_p
)
#
fill
_slice(input_batch.top_p_cpu_tensor)
#
copy
_slice(input_batch.top_k_cpu_tensor
, input_batch.top_k
)
#
fill
_slice(input_batch.top_k_cpu_tensor)
copy
_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
fill
_slice
(
input_batch
.
min_p_cpu_tensor
,
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
xm
.
mark_step
()
xm
.
wait_device_ops
()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return
cls
(
return
cls
(
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
temperature
=
input_batch
.
temperature_cpu_tensor
[:
padded_num_reqs
].
# Scalar tensor for xla-friendly tracing.
to
(
xla_device
),
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
all_greedy
=
input_batch
.
all_greedy
,
dtype
=
torch
.
bool
,
device
=
input_batch
.
device
),
# TODO enable more and avoid returning None values
# TODO enable more and avoid returning None values
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
min_p
=
input_batch
.
min_p_cpu_tensor
[:
padded_num_reqs
].
to
(
generators
=
input_batch
.
generators
,
xla_device
))
indices_do_sample
=
indices_do_sample
)
vllm/v1/serial_utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
pickle
from
collections.abc
import
Sequence
from
inspect
import
isclass
from
types
import
FunctionType
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
import
numpy
as
np
import
torch
import
torch
import
zmq
from
msgspec
import
msgpack
from
msgspec
import
msgpack
CUSTOM_TYPE_
TENSOR
=
1
CUSTOM_TYPE_
PICKLE
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_
CLOUD
PICKLE
=
2
CUSTOM_TYPE_
CLOUDPICKLE
=
3
CUSTOM_TYPE_
RAW_VIEW
=
3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE
=
512
class
MsgpackEncoder
:
bytestr
=
Union
[
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
]
"""Encoder with custom torch tensor serialization."""
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
custom_enc_hook
)
def
encode
(
self
,
obj
:
Any
)
->
bytes
:
class
MsgpackEncoder
:
return
self
.
encoder
.
encode
(
obj
)
"""Encoder with custom torch tensor and numpy array serialization.
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
None
:
Note that unlike vanilla `msgspec` Encoders, this interface is generally
self
.
encoder
.
encode_into
(
obj
,
buf
)
not thread-safe when encoding tensors / numpy arrays.
"""
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
self
.
enc_hook
)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self
.
aux_buffers
:
Optional
[
list
[
bytestr
]]
=
None
def
encode
(
self
,
obj
:
Any
)
->
Sequence
[
bytestr
]:
try
:
self
.
aux_buffers
=
bufs
=
[
b
''
]
bufs
[
0
]
=
self
.
encoder
.
encode
(
obj
)
# This `bufs` list allows us to collect direct pointers to backing
# buffers of tensors and np arrays, and return them along with the
# top-level encoded buffer instead of copying their data into the
# new buffer.
return
bufs
finally
:
self
.
aux_buffers
=
None
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
Sequence
[
bytestr
]:
try
:
self
.
aux_buffers
=
[
buf
]
bufs
=
self
.
aux_buffers
self
.
encoder
.
encode_into
(
obj
,
buf
)
return
bufs
finally
:
self
.
aux_buffers
=
None
def
enc_hook
(
self
,
obj
:
Any
)
->
Any
:
if
isinstance
(
obj
,
torch
.
Tensor
):
return
self
.
_encode_ndarray
(
obj
.
numpy
())
# Fall back to pickle for object or void kind ndarrays.
if
isinstance
(
obj
,
np
.
ndarray
)
and
obj
.
dtype
.
kind
not
in
(
'O'
,
'V'
):
return
self
.
_encode_ndarray
(
obj
)
if
isinstance
(
obj
,
FunctionType
):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
))
def
_encode_ndarray
(
self
,
obj
:
np
.
ndarray
)
->
tuple
[
str
,
tuple
[
int
,
...],
Union
[
int
,
memoryview
]]:
assert
self
.
aux_buffers
is
not
None
arr_data
=
obj
.
data
if
obj
.
data
.
c_contiguous
else
obj
.
tobytes
()
if
not
obj
.
shape
or
obj
.
nbytes
<
MIN_NOCOPY_BUF_SIZE
:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr_data
)
else
:
# Otherwise encode index of backing buffer to avoid copy.
data
=
len
(
self
.
aux_buffers
)
self
.
aux_buffers
.
append
(
arr_data
)
# We serialize the ndarray as a tuple of native types.
# The data is either inlined if small, or an index into a list of
# backing buffers that we've stashed in `aux_buffers`.
return
obj
.
dtype
.
str
,
obj
.
shape
,
data
class
MsgpackDecoder
:
class
MsgpackDecoder
:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
args
=
()
if
t
is
None
else
(
t
,
)
args
=
()
if
t
is
None
else
(
t
,
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
custom_ext_hook
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
self
.
ext_hook
,
def
decode
(
self
,
obj
:
Any
):
dec_hook
=
self
.
dec_hook
)
return
self
.
decoder
.
decode
(
obj
)
self
.
aux_buffers
:
Sequence
[
bytestr
]
=
()
def
decode
(
self
,
bufs
:
Union
[
bytestr
,
Sequence
[
bytestr
]])
->
Any
:
def
custom_enc_hook
(
obj
:
Any
)
->
Any
:
if
isinstance
(
bufs
,
(
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
)):
if
isinstance
(
obj
,
torch
.
Tensor
):
# TODO - This check can become `isinstance(bufs, bytestr)`
# NOTE(rob): it is fastest to use numpy + pickle
# as of Python 3.10.
# when serializing torch tensors.
return
self
.
decoder
.
decode
(
bufs
)
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
self
.
aux_buffers
=
bufs
try
:
if
isinstance
(
obj
,
FunctionType
):
return
self
.
decoder
.
decode
(
bufs
[
0
])
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
finally
:
self
.
aux_buffers
=
()
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
def
dec_hook
(
self
,
t
:
type
,
obj
:
Any
)
->
Any
:
# Given native types in `obj`, convert to type `t`.
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
if
isclass
(
t
):
if
code
==
CUSTOM_TYPE_TENSOR
:
if
issubclass
(
t
,
np
.
ndarray
):
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
return
self
.
_decode_ndarray
(
obj
)
if
code
==
CUSTOM_TYPE_PICKLE
:
if
issubclass
(
t
,
torch
.
Tensor
):
return
pickle
.
loads
(
data
)
return
torch
.
from_numpy
(
self
.
_decode_ndarray
(
obj
))
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
obj
return
cloudpickle
.
loads
(
data
)
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
dtype
,
shape
,
data
=
arr
buffer
=
self
.
aux_buffers
[
data
]
if
isinstance
(
data
,
int
)
else
data
return
np
.
ndarray
(
buffer
=
buffer
,
dtype
=
np
.
dtype
(
dtype
),
shape
=
shape
)
def
ext_hook
(
self
,
code
:
int
,
data
:
memoryview
)
->
Any
:
if
code
==
CUSTOM_TYPE_RAW_VIEW
:
return
data
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
vllm/v1/spec_decode/eagle.py
View file @
9c4ecf15
...
@@ -4,8 +4,11 @@ import torch.nn as nn
...
@@ -4,8 +4,11 @@ import torch.nn as nn
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.model_loader.loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.llama_eagle
import
EagleLlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -21,8 +24,12 @@ class EagleProposer:
...
@@ -21,8 +24,12 @@ class EagleProposer:
self
.
num_speculative_tokens
=
(
self
.
num_speculative_tokens
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
)
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
,
# We need +1 here because the arange is used to set query_start_loc,
device
=
device
)
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
def
propose
(
def
propose
(
self
,
self
,
...
@@ -54,7 +61,9 @@ class EagleProposer:
...
@@ -54,7 +61,9 @@ class EagleProposer:
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids
[
last_token_indices
]
=
next_token_ids
input_ids
[
last_token_indices
]
=
next_token_ids
seq_lens
=
target_positions
[
last_token_indices
]
+
1
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
...
@@ -98,7 +107,7 @@ class EagleProposer:
...
@@ -98,7 +107,7 @@ class EagleProposer:
hidden_states
=
sample_hidden_states
hidden_states
=
sample_hidden_states
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
input_ids
=
draft_token_ids_list
[
-
1
]
input_ids
=
draft_token_ids_list
[
-
1
]
...
@@ -176,26 +185,28 @@ class EagleProposer:
...
@@ -176,26 +185,28 @@ class EagleProposer:
return
cu_num_tokens
,
token_indices
return
cu_num_tokens
,
token_indices
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
self
.
model
=
DummyEagleModel
()
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
self
.
model
.
get_input_embeddings
=
target_model
.
get_input_embeddings
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
model
.
compute_logits
=
target_model
.
compute_logits
self
.
vllm_config
.
parallel_config
)
draft_model_config
=
\
# FIXME(woosuk): This is a dummy model for testing.
self
.
vllm_config
.
speculative_config
.
draft_model_config
# Remove this once we have a real model.
# FIXME(lily): This does not handle with distributed inference.
class
DummyEagleModel
(
nn
.
Module
):
target_device
=
self
.
vllm_config
.
device_config
.
device
# We need to set the vllm_config here to register attention
def
__init__
(
self
):
# layers in the forward context.
super
().
__init__
()
with
set_default_torch_dtype
(
draft_model_config
.
dtype
),
set_current_vllm_config
(
def
forward
(
self
.
vllm_config
):
self
,
self
.
model
=
EagleLlamaForCausalLM
(
input_ids
:
torch
.
Tensor
,
model_config
=
draft_model_config
,
hidden_states
:
torch
.
Tensor
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
self
.
model
.
load_weights
(
input_embeddings
=
self
.
get_input_embeddings
(
input_ids
)
loader
.
get_all_weights
(
return
hidden_states
+
input_embeddings
# Dummy return.
self
.
vllm_config
.
speculative_config
.
draft_model_config
,
self
.
model
))
self
.
model
.
lm_head
=
target_model
.
lm_head
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
9c4ecf15
...
@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
...
@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
self
.
vocab_size
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
@@ -163,7 +164,6 @@ def validate_guidance_grammar(
...
@@ -163,7 +164,6 @@ def validate_guidance_grammar(
tokenizer
:
Optional
[
llguidance
.
LLTokenizer
]
=
None
)
->
None
:
tokenizer
:
Optional
[
llguidance
.
LLTokenizer
]
=
None
)
->
None
:
tp
,
grm
=
get_structured_output_key
(
sampling_params
)
tp
,
grm
=
get_structured_output_key
(
sampling_params
)
guidance_grm
=
serialize_guidance_grammar
(
tp
,
grm
)
guidance_grm
=
serialize_guidance_grammar
(
tp
,
grm
)
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
err
=
llguidance
.
LLMatcher
.
validate_grammar
(
guidance_grm
,
tokenizer
)
tokenizer
=
tokenizer
)
if
err
:
if
err
:
raise
ValueError
(
f
"Grammar error:
{
err
}
"
)
raise
ValueError
(
f
"Grammar error:
{
err
}
"
)
vllm/v1/structured_output/backend_xgrammar.py
View file @
9c4ecf15
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import
torch
import
torch
import
vllm.envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
...
@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
...
@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer
,
tokenizer
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
)
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
,
cache_enabled
=
True
,
cache_limit_bytes
=
vllm
.
envs
.
VLLM_XGRAMMAR_CACHE_MB
*
1024
*
1024
,
)
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
...
...
vllm/v1/structured_output/utils.py
View file @
9c4ecf15
...
@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
...
@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
return
True
return
True
# Unsupported keywords for strings
# Unsupported keywords for strings
if
obj
.
get
(
"type"
)
==
"string"
and
any
(
if
obj
.
get
(
"type"
)
==
"string"
and
"format"
in
obj
:
key
in
obj
for
key
in
(
"minLength"
,
"maxLength"
,
"format"
)):
return
True
return
True
# Unsupported keywords for objects
# Unsupported keywords for objects
...
...
vllm/v1/utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
os
import
os
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
multiprocessing
import
Process
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
)
Union
,
overload
)
...
@@ -115,18 +115,15 @@ class BackgroundProcHandle:
...
@@ -115,18 +115,15 @@ class BackgroundProcHandle:
process_kwargs
[
"output_path"
]
=
output_path
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
# Run busy loop in background process.
self
.
proc
=
context
.
Process
(
target
=
target_fn
,
self
.
proc
:
Process
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
kwargs
=
process_kwargs
,
name
=
process_name
)
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
input_path
,
output_path
)
self
.
proc
.
start
()
self
.
proc
.
start
()
def
wait_for_startup
(
self
):
def
fileno
(
self
):
# Wait for startup.
return
self
.
proc
.
sentinel
if
self
.
reader
.
recv
()[
"status"
]
!=
"READY"
:
raise
RuntimeError
(
f
"
{
self
.
proc
.
name
}
initialization failed. "
"See root cause above."
)
def
shutdown
(
self
):
def
shutdown
(
self
):
self
.
_finalizer
()
self
.
_finalizer
()
...
@@ -134,7 +131,7 @@ class BackgroundProcHandle:
...
@@ -134,7 +131,7 @@ class BackgroundProcHandle:
# Note(rob): shutdown function cannot be a bound method,
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
# else the gc cannot collect the object.
def
shutdown
(
proc
:
multiprocessing
.
Process
,
input_path
:
str
,
output_path
:
str
):
def
shutdown
(
proc
:
Process
,
input_path
:
str
,
output_path
:
str
):
# Shutdown the process.
# Shutdown the process.
if
proc
.
is_alive
():
if
proc
.
is_alive
():
proc
.
terminate
()
proc
.
terminate
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9c4ecf15
...
@@ -19,7 +19,8 @@ from vllm.logger import init_logger
...
@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
...
@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
...
@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
# Get the number of scheduled tokens for each request.
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
req_ids
=
self
.
input_batch
.
req_ids
num_scheduled_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
max_num_scheduled_tokens
=
0
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
max_num_scheduled_tokens
=
max
(
tokens
)
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
[
i
]
=
num_tokens
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
num_tokens
)
# Get request indices.
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
return
metadata
return
metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_
mm_
encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
if
not
scheduled_encoder_inputs
:
return
return
# Batch the multi-modal inputs.
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
for
mm_input_id
in
encoder_input_ids
:
req_input_ids
.
append
((
req_id
,
input_id
))
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# multiple modalities or a different modality than the previous one,
...
@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
.
append
(
output
)
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
...
@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens
=
req_state
.
num_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
[
"
length
"
]
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# [num_computed_tokens,
...
@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
req_id
in
self
.
encoder_cache
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
return
self
.
model
...
@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare the decoder inputs.
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
...
@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
encoder_output
s
)
input_ids
,
mm_embed
s
)
else
:
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
# TODO(woosuk): Avoid the copy. Optimize.
...
@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
target_positions
=
positions
[:
num_scheduled_tokens
]
target_hidden_states
=
hidden_states
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
target_slot_mapping
=
attn_metadata
.
slot_mapping
cu_num_tokens
=
attn_metadata
.
query_start_loc
cu_num_tokens
=
attn_metadata
.
query_start_loc
else
:
else
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
9c4ecf15
...
@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
...
@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
...
@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
...
@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -174,10 +176,12 @@ class TPUModelRunner:
...
@@ -174,10 +176,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Used to initialize positions / context_lens / seq_lens
self
.
arange_np
=
np
.
arange
(
self
.
max_num_tokens
,
dtype
=
np
.
int32
)
self
.
arange_np
=
np
.
arange
(
self
.
max_num_tokens
,
dtype
=
np
.
int32
)
self
.
num_tokens_paddings
=
_get_paddings
(
self
.
num_tokens_paddings
=
_get_
token_
paddings
(
min_token_size
=
16
,
min_token_size
=
16
,
max_token_size
=
self
.
max_num_tokens
,
max_token_size
=
self
.
max_num_tokens
,
padding_gap
=
envs
.
VLLM_TPU_BUCKET_PADDING_GAP
)
padding_gap
=
envs
.
VLLM_TPU_BUCKET_PADDING_GAP
)
self
.
num_reqs_paddings
=
_get_req_paddings
(
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
def
_update_num_xla_graphs
(
self
,
case_str
):
def
_update_num_xla_graphs
(
self
,
case_str
):
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
...
@@ -262,11 +266,6 @@ class TPUModelRunner:
...
@@ -262,11 +266,6 @@ class TPUModelRunner:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
req_id
=
new_req_data
.
req_id
sampling_params
=
new_req_data
.
sampling_params
sampling_params
=
new_req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
else
:
generator
=
None
self
.
requests
[
req_id
]
=
CachedRequestState
(
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
req_id
=
req_id
,
...
@@ -275,7 +274,7 @@ class TPUModelRunner:
...
@@ -275,7 +274,7 @@ class TPUModelRunner:
mm_inputs
=
new_req_data
.
mm_inputs
,
mm_inputs
=
new_req_data
.
mm_inputs
,
mm_positions
=
new_req_data
.
mm_positions
,
mm_positions
=
new_req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
generator
=
generator
,
generator
=
None
,
block_ids
=
new_req_data
.
block_ids
,
block_ids
=
new_req_data
.
block_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
output_token_ids
=
[],
output_token_ids
=
[],
...
@@ -505,21 +504,48 @@ class TPUModelRunner:
...
@@ -505,21 +504,48 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
logits_indices
.
to
(
self
.
device
)
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
,
padded_num_reqs
def
_scatter_placeholders
(
self
,
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
embeds
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
_gather_placeholders
(
self
,
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
def
_execute_mm_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
if
not
scheduled_encoder_inputs
:
return
return
# Batch the multi-modal inputs.
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
for
mm_input_id
in
encoder_input_ids
:
req_input_ids
.
append
((
req_id
,
input_id
))
mm_inputs
.
append
(
req_state
.
mm_inputs
[
mm_input_id
])
req_ids_pos
.
append
(
(
req_id
,
mm_input_id
,
req_state
.
mm_positions
[
mm_input_id
]))
# Batch mm inputs as much as we can: if a request in the batch has
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# multiple modalities or a different modality than the previous one,
...
@@ -555,16 +581,23 @@ class TPUModelRunner:
...
@@ -555,16 +581,23 @@ class TPUModelRunner:
encoder_outputs
.
append
(
output
)
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
...
@@ -572,8 +605,8 @@ class TPUModelRunner:
...
@@ -572,8 +605,8 @@ class TPUModelRunner:
num_computed_tokens
=
req_state
.
num_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
[
"
length
"
]
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# [num_computed_tokens,
...
@@ -595,8 +628,16 @@ class TPUModelRunner:
...
@@ -595,8 +628,16 @@ class TPUModelRunner:
assert
req_id
in
self
.
encoder_cache
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
execute_model
(
def
execute_model
(
...
@@ -607,25 +648,26 @@ class TPUModelRunner:
...
@@ -607,25 +648,26 @@ class TPUModelRunner:
# Update cached state
# Update cached state
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOu
p
tut if there's no work to do.
# Return empty ModelRunnerOut
p
ut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
self
.
_execute_
mm_
encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
else
:
else
:
encoder_output
s
=
[]
mm_embed
s
=
[]
# Prepare inputs
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
attn_metadata
,
logits_indices
,
padded_num_reqs
=
self
.
_prepare_inputs
(
scheduler_output
)
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
# as input to the multimodal model, even when the input is text.
if
encoder_output
s
:
if
mm_embed
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
,
encoder_output
s
)
self
.
input_ids
,
mm_embed
s
)
else
:
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
input_ids
=
None
input_ids
=
None
...
@@ -637,21 +679,19 @@ class TPUModelRunner:
...
@@ -637,21 +679,19 @@ class TPUModelRunner:
input_ids
=
self
.
input_ids
input_ids
=
self
.
input_ids
inputs_embeds
=
None
inputs_embeds
=
None
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
logits_indices
)
# Run the decoder
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
positions
=
self
.
position_ids
,
kv_caches
=
self
.
kv_caches
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
selected_token_ids
=
self
.
model
.
sample_from_hidden
(
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
hidden_states
,
tpu_sampling_metadata
)
logits_indices
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
selected_token_ids
=
self
.
sample_from_hidden
(
hidden_states
,
tpu_sampling_metadata
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
...
@@ -751,17 +791,15 @@ class TPUModelRunner:
...
@@ -751,17 +791,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank"
,
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
return_value
=
xm_tp_rank
):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm
.
mark_step
()
xm
.
mark_step
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
model
=
ModelWrapperV1
(
model
)
self
.
model
=
model
self
.
model
=
torch
.
compile
(
model
,
self
.
sampler
=
TPUSampler
()
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_dummy_run
(
self
,
kv_caches
,
num_tokens
:
int
)
->
None
:
def
_dummy_run
(
self
,
num_tokens
:
int
)
->
None
:
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
...
@@ -812,65 +850,81 @@ class TPUModelRunner:
...
@@ -812,65 +850,81 @@ class TPUModelRunner:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
out
=
self
.
model
(
input_ids
=
input_ids
,
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
self
.
_hidden_states_dtype
=
out
.
dtype
self
.
_hidden_states_dtype
=
out
.
dtype
def
capture_model
(
self
)
->
None
:
def
_precompile_backbone
(
self
)
->
None
:
"""Compile the model."""
logger
.
info
(
"Compiling the model with different input shapes."
)
logger
.
info
(
"Compiling the model with different input shapes."
)
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
for
num_tokens
in
self
.
num_tokens_paddings
:
for
num_tokens
in
self
.
num_tokens_paddings
:
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
self
.
_dummy_run
(
num_tokens
)
xm
.
mark_step
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"model"
)
self
.
_update_num_xla_graphs
(
"model
backbone
"
)
logger
.
info
(
"Compiling sampling with different input shapes."
)
def
_precompile_select_hidden_states
(
self
)
->
None
:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger
.
info
(
"Compiling select_hidden_states with different input shapes."
)
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
hsize
=
self
.
model_config
.
get_hidden_size
()
device
=
self
.
device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for
num_tokens
in
self
.
num_tokens_paddings
:
for
num_tokens
in
self
.
num_tokens_paddings
:
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
zeros
((
num_tokens
,
hsize
),
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
self
.
device
,
device
=
device
,
dtype
=
self
.
_hidden_states_dtype
)
dtype
=
self
.
_hidden_states_dtype
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
torch
.
_dynamo
.
mark_dynamic
(
dummy_hidden
,
0
)
while
True
:
for
num_reqs
in
self
.
num_reqs_paddings
:
indices
=
torch
.
zeros
(
indices
=
torch
.
zeros
(
num_reqs
,
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
device
,
torch
.
_dynamo
.
mark_dynamic
(
indices
,
0
)
)
self
.
select_hidden_states
(
dummy_hidden
,
indices
)
xm
.
mark_step
()
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
indices
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
out
=
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
out
=
out
.
cpu
()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if
num_reqs_to_sample
>=
min
(
num_tokens
,
self
.
max_num_reqs
):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"select_hidden_states"
)
def
_precompile_sample_from_hidden
(
self
)
->
None
:
logger
.
info
(
"Compiling sampling with different input shapes."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_hidden
=
torch
.
zeros
((
num_reqs
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
for
all_greedy
in
[
False
,
True
]:
generate_params_if_all_greedy
=
not
all_greedy
sampling_metadata
=
(
TPUSupportedSamplingMetadata
.
from_input_batch
(
self
.
input_batch
,
num_reqs
,
self
.
device
,
generate_params_if_all_greedy
,
))
sampling_metadata
.
all_greedy
=
all_greedy
self
.
sample_from_hidden
(
dummy_hidden
,
sampling_metadata
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sampling"
)
self
.
_update_num_xla_graphs
(
"sampling"
)
def
capture_model
(
self
)
->
None
:
"""
Precompile all the subgraphs with possible input shapes.
"""
# TODO: precompile encoder
self
.
_precompile_backbone
()
self
.
_precompile_select_hidden_states
()
self
.
_precompile_sample_from_hidden
()
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initialize KV cache based on `kv_cache_config`.
...
@@ -910,73 +964,39 @@ class TPUModelRunner:
...
@@ -910,73 +964,39 @@ class TPUModelRunner:
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
self
.
kv_caches
)
def
reset_dynamo_cache
(
self
):
class
ModelWrapperV1
(
nn
.
Module
):
if
self
.
is_multimodal_model
:
compiled_model
=
self
.
model
.
get_language_model
().
model
def
__init__
(
self
,
model
:
nn
.
Module
):
else
:
super
().
__init__
()
compiled_model
=
self
.
model
.
model
self
.
model
=
model
if
isinstance
(
compiled_model
,
TorchCompileWrapperWithCustomDispatcher
):
self
.
sampler
=
TPUSampler
()
logger
.
info
(
"Clear dynamo cache and cached dynamo bytecode."
)
torch
.
_dynamo
.
eval_frame
.
remove_from_cache
(
def
sample
(
compiled_model
.
original_code_object
)
self
,
logits
:
torch
.
Tensor
,
compiled_model
.
compiled_codes
.
clear
()
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
SamplerOutput
:
sampler_out
=
self
.
sampler
(
logits
,
sampling_metadata
)
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
return
sampler_out
def
select_hidden_states
(
self
,
hidden_states
,
indices_do_sample
):
return
hidden_states
[
indices_do_sample
]
def
forward
(
self
,
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
sample_from_hidden
(
def
sample_from_hidden
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
sample_
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Sample with xla-friendly function. This function is to be traced
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
separately from `forward` for lighter compilation overhead.
"""
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
sample_hidden_states
=
\
if
sampling_metadata
.
all_greedy
:
hidden_states
[
sampling_metadata
.
indices_do_sample
]
out_tokens
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
logits
=
self
.
compute_logits
(
sample_hidden_states
)
else
:
# Optimized greedy sampling branch, tracing both paths in a single pass
out_tokens
=
self
.
sampler
(
logits
,
# NOTE all_greedy is a scalar, this is just an optimized if/else.
sampling_metadata
).
sampled_token_ids
out_tokens
=
torch
.
where
(
sampling_metadata
.
all_greedy
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
return
out_tokens
return
out_tokens
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
return
logits
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
...
@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
...
@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
_get_padded_number
(
n
:
int
,
multiple
:
int
)
->
int
:
def
_get_req_paddings
(
min_req_size
:
int
,
max_req_size
:
int
)
->
list
[
int
]:
return
((
n
+
multiple
-
1
)
//
multiple
)
*
multiple
logger
.
info
(
"Preparing request paddings:"
)
# assert min_req_size is power of 2
assert
(
min_req_size
&
(
min_req_size
-
1
)
==
0
)
and
min_req_size
>
0
paddings
:
list
=
[]
num
=
max
(
MIN_NUM_SEQS
,
min_req_size
)
while
num
<=
max_req_size
and
(
len
(
paddings
)
==
0
or
paddings
[
-
1
]
!=
num
):
paddings
.
append
(
num
)
logger
.
info
(
" %d"
,
num
)
num
=
_get_padded_num_reqs_with_upper_limit
(
num
+
1
,
max_req_size
)
return
paddings
def
_get_padded_num_reqs_with_upper_limit
(
x
,
upper_limit
)
->
int
:
def
_get_padded_num_reqs_with_upper_limit
(
x
:
int
,
upper_limit
:
int
)
->
int
:
res
=
MIN_NUM_SEQS
if
x
<=
MIN_NUM_SEQS
else
1
<<
(
x
-
1
).
bit_length
()
res
=
MIN_NUM_SEQS
if
x
<=
MIN_NUM_SEQS
else
1
<<
(
x
-
1
).
bit_length
()
return
min
(
res
,
upper_limit
)
return
min
(
res
,
upper_limit
)
def
_get_paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
def
_get_
token_
paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
padding_gap
:
int
)
->
list
[
int
]:
"""Generate a list of padding size, starting from min_token_size,
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
ending with a number that can cover max_token_size
...
@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
...
@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first increase the size to twice,
first increase the size to twice,
then increase the padding size by padding_gap.
then increase the padding size by padding_gap.
"""
"""
# assert min_token_size is power of 2
assert
(
min_token_size
&
(
min_token_size
-
1
)
==
0
)
and
min_token_size
>
0
paddings
=
[]
paddings
=
[]
num
=
min_token_size
num
=
min_token_size
if
padding_gap
==
0
:
if
padding_gap
==
0
:
logger
.
info
(
"Using exponential paddings:"
)
logger
.
info
(
"Using exponential
token
paddings:"
)
while
num
<=
max_token_size
:
while
num
<=
max_token_size
:
logger
.
info
(
" %d"
,
num
)
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
paddings
.
append
(
num
)
num
*=
2
num
*=
2
else
:
else
:
logger
.
info
(
"Using incremental paddings:"
)
logger
.
info
(
"Using incremental
token
paddings:"
)
while
num
<=
padding_gap
:
while
num
<=
padding_gap
:
logger
.
info
(
" %d"
,
num
)
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
paddings
.
append
(
num
)
...
...
vllm/v1/worker/tpu_worker.py
View file @
9c4ecf15
...
@@ -157,13 +157,19 @@ class TPUWorker:
...
@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches
)
runner_kv_caches
)
self
.
model_runner
.
_dummy_run
(
self
.
model_runner
.
_dummy_run
(
runner_kv_caches
,
self
.
scheduler_config
.
max_num_batched_tokens
)
num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
,
)
# Synchronize before measuring the memory usage.
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self
.
model_runner
.
reset_dynamo_cache
()
# Get the maximum amount of memory used by the model weights and
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
m
=
xm
.
get_memory_info
(
self
.
device
)
...
...
vllm/v1/worker/utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
torch
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
"instead. This is most likely due to incorrect implementation "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method."
)
"of the model's `get_multimodal_embeddings` method."
)
def
scatter_mm_placeholders
(
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if
is_embed
is
None
:
return
embeds
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
gather_mm_placeholders
(
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of :func:`scatter_mm_placeholders`.
"""
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
vllm/worker/enc_dec_model_runner.py
View file @
9c4ecf15
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
num_steps
>
1
:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in "
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
"EncoderDecoderModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
(
model_input
.
attn_metadata
is
not
None
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables
=
None
,
block_tables
=
None
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
cross_block_table
=
None
,
cross_block_table
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
or
encoder_dummy_data
.
multi_modal_data
,
or
encoder_dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
decoder_dummy_data
.
multi_modal_placeholders
=
decoder_dummy_data
.
...
...
vllm/worker/hpu_model_runner.py
View file @
9c4ecf15
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SequenceGroupToSample
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
MultiModalKwargs
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
SequenceGroupMetadata
)
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
make_tensor_with_pad
)
make_tensor_with_pad
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
if
to_override
is
None
:
if
to_override
is
None
:
to_override
=
{}
to_override
=
{}
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
type
(
obj
)
is
dict
:
values
=
{
key
:
obj
[
key
]
for
key
in
fields
if
key
in
obj
}
else
:
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
typename
not
in
_TYPE_CACHE
:
if
typename
not
in
_TYPE_CACHE
:
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
' '
.
join
(
fields
))
' '
.
join
(
fields
))
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
lora_ids
:
Optional
[
List
[
int
]]
=
None
lora_ids
:
Optional
[
List
[
int
]]
=
None
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded"
:
self
.
batch_size_padded
,
"batch_size_padded"
:
self
.
batch_size_padded
,
"virtual_engine"
:
self
.
virtual_engine
,
"virtual_engine"
:
self
.
virtual_engine
,
"lora_ids"
:
self
.
lora_ids
,
"lora_ids"
:
self
.
lora_ids
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
}
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
return
tensor_dict
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
_set_gc_threshold
()
self
.
_set_gc_threshold
()
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
def
_set_gc_threshold
(
self
)
->
None
:
def
_set_gc_threshold
(
self
)
->
None
:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
# for comprehensive description of gc generations.
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
output
=
None
,
)
->
PrepareDecodeMetadata
:
)
->
PrepareDecodeMetadata
:
input_tokens
:
List
[
List
[
int
]]
=
[]
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
if
output
is
None
:
input_tokens
.
append
([
generation_token
])
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
position
=
seq_len
-
1
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
num_fully_occupied_blocks
=
position
//
self
.
block_size
block_table
=
block_table
[:
num_fully_occupied_blocks
+
1
]
if
len
(
block_table
)
==
0
:
if
len
(
block_table
)
==
0
:
block_number
=
_PAD_BLOCK_ID
block_number
=
_PAD_BLOCK_ID
else
:
else
:
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
block_tables
.
append
(
block_table
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
if
output
is
None
:
dtype
=
torch
.
long
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
device
=
self
.
device
)
dtype
=
torch
.
long
,
device
=
self
.
device
)
else
:
real_batch_size
=
len
(
seq_group_metadata_list
)
input_tokens
=
output
[:
real_batch_size
]
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler
.
start
()
profiler
.
start
()
for
_
in
range
(
times
):
for
_
in
range
(
times
):
inputs
=
self
.
prepare_model_input
(
seqs
)
inputs
=
self
.
prepare_model_input
(
seqs
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
is_single_step
=
\
self
.
vllm_config
.
scheduler_config
.
num_scheduler_steps
==
1
if
is_prompt
or
is_single_step
:
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
else
:
# decode with multi-step
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
True
,
is_last_step
=
False
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
False
,
is_last_step
=
True
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
torch
.
hpu
.
synchronize
()
torch
.
hpu
.
synchronize
()
if
profiler
:
if
profiler
:
profiler
.
step
()
profiler
.
step
()
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
warmup_mode
=
False
,
warmup_mode
=
False
,
seqs
=
None
,
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
if
not
model_input
.
is_first_multi_step
:
raise
ValueError
(
if
not
model_input
.
is_last_step
:
"num_steps > 1 is not supported in HPUModelRunner"
)
# not first or last multi-step
return
[]
# last multi-step
output
=
self
.
_decode_sampler_outputs
(
model_input
)
if
self
.
is_driver_worker
else
[]
torch
.
hpu
.
synchronize
()
if
model_input
.
is_first_multi_step
:
# first multi-step
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
(
{
"bypass_hpu_graphs"
:
not
use_graphs
})
if
self
.
lora_config
:
htorch
.
core
.
mark_step
()
assert
model_input
.
lora_requests
is
not
None
if
self
.
is_driver_worker
:
assert
model_input
.
lora_mapping
is
not
None
model_event_name
=
(
"model_"
self
.
set_active_loras
(
model_input
.
lora_requests
,
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
model_input
.
lora_mapping
)
f
"bs
{
batch_size
}
_"
input_tokens
=
model_input
.
input_tokens
f
"seq
{
seq_len
}
_"
input_positions
=
model_input
.
input_positions
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
attn_metadata
=
model_input
.
attn_metadata
else
:
sampling_metadata
=
model_input
.
sampling_metadata
model_event_name
=
'model_executable'
real_batch_size
=
model_input
.
real_batch_size
if
num_steps
>
1
:
batch_size_padded
=
model_input
.
batch_size_padded
# in case of multi-step scheduling
assert
input_tokens
is
not
None
# we only want to pythonize in the last step
assert
input_positions
is
not
None
sampling_metadata
.
skip_sampler_cpu_output
=
True
assert
sampling_metadata
is
not
None
self
.
model
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
assert
attn_metadata
is
not
None
cache_orig_output_tokens_len
:
List
[
Dict
]
=
[]
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
def
try_revert_dummy_output_tokens
():
batch_size
=
input_tokens
.
size
(
0
)
if
len
(
cache_orig_output_tokens_len
)
>
0
:
seq_len
=
self
.
_seq_len
(
attn_metadata
)
# Reuse the original output token ids length
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
for
i
,
seq_group_metadata
in
enumerate
(
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
seq_group_metadata_list
):
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
orig_output_tokens_len
=
\
cache_orig_output_tokens_len
[
i
][
j
]
data
.
output_token_ids
=
\
data
.
output_token_ids
[:
orig_output_tokens_len
]
for
i
in
range
(
num_steps
):
if
i
!=
0
and
not
self
.
is_driver_worker
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
'early_exit'
in
broadcast_data
and
broadcast_data
[
'early_exit'
]:
return
[
output
]
if
num_steps
==
1
else
[]
execute_model_kwargs
.
update
({
"input_ids"
:
broadcast_data
[
"input_ids"
],
"positions"
:
broadcast_data
[
"positions"
],
"attn_metadata"
:
self
.
trim_attn_metadata
(
broadcast_data
[
"attn_metadata"
])
})
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
if
num_steps
==
1
:
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
continue
lora_mask
:
torch
.
Tensor
=
None
if
model_input
.
async_callback
is
not
None
:
lora_logits_mask
:
torch
.
Tensor
=
None
model_input
.
async_callback
()
if
self
.
lora_config
:
# Sample the next token.
assert
model_input
.
lora_ids
is
not
None
with
self
.
profiler
.
record_event
(
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
'internal'
,
(
'sample_'
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
execute_model_kwargs
=
{
f
'seq
{
seq_len
}
'
)):
"input_ids"
:
input_tokens
,
output
=
self
.
model
.
sample
(
"positions"
:
input_positions
,
logits
=
logits
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
sampling_metadata
=
sampling_metadata
,
"intermediate_tensors"
:
intermediate_tensors
,
)
"lora_mask"
:
lora_mask
,
if
num_steps
>
1
:
"virtual_engine"
:
model_input
.
virtual_engine
,
output
=
output
.
sampled_token_ids
**
(
model_input
.
multi_modal_kwargs
or
{}),
self
.
cached_step_outputs
.
append
(
}
output
.
detach
().
clone
())
if
htorch
.
utils
.
internal
.
is_lazy
():
htorch
.
core
.
mark_step
()
execute_model_kwargs
.
update
({
"bypass_hpu_graphs"
:
not
use_graphs
})
if
i
<
num_steps
-
1
:
if
i
==
0
:
htorch
.
core
.
mark_step
()
if
model_input
.
async_callback
is
not
None
:
if
self
.
is_driver_worker
:
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
model_event_name
=
(
"model_"
"ctx"
]
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
seq_group_metadata_list
=
\
f
"bs
{
batch_size
}
_"
ctx
.
seq_group_metadata_list
f
"seq
{
seq_len
}
_"
elif
seqs
is
not
None
:
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
seq_group_metadata_list
=
seqs
else
:
raise
RuntimeError
(
"seq_group_metadata_list is uninitialized"
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# Skip empty steps
seq_group_metadata
.
state
.
current_step
+=
(
num_steps
-
2
)
# Cache the original output token ids
cache_orig_output_tokens_len
.
append
({})
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
cache_orig_output_tokens_len
[
i
][
j
]
=
\
len
(
data
.
output_token_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
data
in
seq_group_metadata
.
seq_data
.
values
():
max_output_len
=
sampling_metadata
.
seq_groups
[
0
].
sampling_params
.
max_tokens
if
len
(
data
.
output_token_ids
)
<
max_output_len
-
1
:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token
=
(
540
,
)
data
.
output_token_ids
+=
(
dummy_token
)
else
:
broadcast_tensor_dict
({
'early_exit'
:
True
},
src
=
0
)
if
num_steps
==
1
:
return
[
output
]
else
:
try_revert_dummy_output_tokens
()
return
[]
result
=
self
.
_prepare_decode
(
seq_group_metadata_list
,
output
=
output
)
execute_model_kwargs
.
update
({
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
result
.
attn_metadata
)
})
model_kwargs_broadcast_data
=
{
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
vars
(
result
.
attn_metadata
)
}
broadcast_tensor_dict
(
model_kwargs_broadcast_data
,
src
=
0
)
else
:
try_revert_dummy_output_tokens
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
if
num_steps
==
1
:
return
[
output
]
if
self
.
is_driver_worker
else
[]
else
:
return
[]
return
output
if
type
(
output
)
is
list
else
[
output
]
def
_decode_sampler_outputs
(
self
,
model_input
):
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
self
.
_make_decode_output
(
next_token_ids
,
model_input
.
sampling_metadata
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
,
is_first_step_output
=
False
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
else
:
model_event_name
=
'model_executable'
return
sampler_outputs
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
def
_make_decode_output
(
LoraMask
.
setLoraMask
(
self
,
lora_logits_mask
.
index_select
(
next_token_ids
:
List
[
List
[
int
]],
0
,
sampling_metadata
.
selected_token_indices
))
seq_groups
:
List
[
SequenceGroupToSample
],
)
->
SamplerOutput
:
# Compute the logits.
zero_logprob
=
Logprob
(
0.0
)
with
self
.
profiler
.
record_event
(
sampler_outputs
=
[]
'internal'
,
(
'compute_logits_'
batch_idx
=
0
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
for
seq_group
in
seq_groups
:
f
'
{
batch_size
}
_'
seq_ids
=
seq_group
.
seq_ids
f
'seq
{
seq_len
}
'
)):
seq_outputs
=
[]
sampling_metadata
.
selected_token_indices
=
None
for
seq_id
in
seq_ids
:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
next_token_id
=
next_token_ids
[
batch_idx
][
0
]
sampling_metadata
)
seq_outputs
.
append
(
htorch
.
core
.
mark_step
()
SequenceOutput
(
seq_id
,
next_token_id
,
# Only perform sampling in the driver worker.
{
next_token_id
:
zero_logprob
}))
if
not
self
.
is_driver_worker
:
batch_idx
+=
1
return
[]
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
if
model_input
.
async_callback
is
not
None
:
return
SamplerOutput
(
sampler_outputs
)
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
output
.
outputs
=
output
.
outputs
[:
real_batch_size
]
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
return
[
output
]
def
shutdown_inc
(
self
):
def
shutdown_inc
(
self
):
can_finalize_inc
=
False
can_finalize_inc
=
False
...
...
Prev
1
…
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment