Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
942 additions
and
348 deletions
+942
-348
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+65
-104
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+8
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+261
-0
vllm/v1/spec_decode/metrics.py
vllm/v1/spec_decode/metrics.py
+62
-0
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+23
-12
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+5
-5
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+12
-7
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+14
-8
vllm/v1/structured_output/utils.py
vllm/v1/structured_output/utils.py
+0
-4
vllm/v1/utils.py
vllm/v1/utils.py
+7
-4
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+13
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+146
-63
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+16
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+166
-89
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+41
-15
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+29
-0
vllm/version.py
vllm/version.py
+9
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+1
-0
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+9
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+55
-27
No files found.
vllm/v1/sample/tpu/metadata.py
View file @
fcfc474d
...
...
@@ -5,7 +5,18 @@ from typing import Optional
import
torch
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
DEFAULT_SAMPLING_PARAMS
=
dict
(
temperature
=-
1.0
,
min_p
=
0.0
,
# strictly disabled for now
# top_k=-1,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
# repetition_penalties=0.0,
)
@
dataclass
...
...
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# XLA-unfriendly control flow in Sampler
all_greedy
:
bool
=
False
all_random
:
bool
=
False
# Greedy sampling flag for compiling single xla graph.
do_argmax
:
torch
.
Tensor
=
None
# speculation not supported
spec_token_ids
=
None
all_greedy
:
torch
.
Tensor
=
None
# Generator not supported by xla
generators
:
dict
[
int
,
...
...
@@ -54,106 +59,62 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
def
__post_init__
(
self
):
temp
=
self
.
temperature
if
self
.
indices_do_sample
is
None
:
self
.
indices_do_sample
=
torch
.
zeros
(
temp
.
shape
[
0
],
device
=
temp
.
device
,
dtype
=
torch
.
int32
)
if
self
.
do_argmax
is
None
:
self
.
do_argmax
=
torch
.
tensor
(
0
,
dtype
=
torch
.
bool
,
device
=
temp
.
device
)
@
classmethod
def
from_sampling_metadata
(
cls
,
metadata
:
SamplingMetadata
,
padded_do_sample_indices
:
torch
.
Tensor
,
num_do_sample
:
int
,
device
:
torch
.
device
)
->
"TPUSupportedSamplingMetadata"
:
def
from_input_batch
(
cls
,
input_batch
:
InputBatch
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that
recompilation is not triggered for optional values (None/torch.Tensor).
In order to handle different sizes for the params that range from 1 up
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape.
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
"""
metadata
=
cls
.
_validate_sampling_metadata
(
metadata
)
# NOTE we have to initialize default tensor-based params first and
# skip None values altogether to produce the same xla graph.
num_samples
=
len
(
padded_do_sample_indices
)
do_argmax
=
torch
.
t
ensor
(
metadata
.
all_greedy
,
dtype
=
torch
.
bool
,
device
=
device
)
new_metadata
=
cls
.
get_default_sampling_params
(
num_samples
,
device
,
indices_do_sample
=
\
padded_do_sample_indices
,
do_argmax
=
do_argmax
)
supported_params
=
\
TPUSupportedSa
mpling
Metadata
.
_get_default_params_values
()
# Copy input non-None values into `new_metadata` fixed-sized tensors.
for
p_name
in
supported_params
:
old_val
=
getattr
(
metadata
,
p_name
)
new_val
=
getattr
(
new_metadata
,
p_name
)
if
isinstance
(
old_val
,
torch
.
Tensor
):
new_val
[:
num_do_sample
]
=
old_val
setattr
(
new_metadata
,
p_name
,
new_val
)
num_reqs
=
input_batch
.
num_reqs
padded_num_reqs
=
len
(
indices_do_sample
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
fill_val
)
->
torch
.
T
ensor
:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
# Pad value is the default one.
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# reco
mp
i
ling
.
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
# TODO Temporarily disabled until sampling options are enabled
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p
)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
copy_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
DEFAULT_SAMPLING_PARAMS
[
"min_p"
]
)
xm
.
mark_step
()
xm
.
wait_device_ops
()
return
new_metadata
@
classmethod
def
get_default_sampling_params
(
cls
,
num_samples
:
int
,
device
:
torch
.
device
,
indices_do_sample
=
None
,
do_argmax
=
None
)
->
"TPUSupportedSamplingMetadata"
:
# As sampling happens on a single traced graph, options
# are "disabled" by having them evaluate to an Identity op.
# Note that initialization is dependent on num_samples.
sampling_metadata_disable_value
=
\
TPUSupportedSamplingMetadata
.
_get_default_params_values
()
init_kwargs
=
dict
()
for
p_name
,
(
default_val
,
dtype
)
in
sampling_metadata_disable_value
.
items
():
default_tensor
=
torch
.
full
((
num_samples
,
),
default_val
,
dtype
=
dtype
,
device
=
device
)
init_kwargs
[
p_name
]
=
default_tensor
return
cls
(
**
init_kwargs
,
indices_do_sample
=
indices_do_sample
,
do_argmax
=
do_argmax
)
@
staticmethod
def
_validate_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
)
->
SamplingMetadata
:
if
sampling_metadata
.
all_greedy
:
# Set to None since #13587. Make sure default isn't overruled.
assert
sampling_metadata
.
temperature
is
None
return
sampling_metadata
@
staticmethod
def
_get_default_params_values
():
return
dict
(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature
=
(
1.0
,
torch
.
float32
),
min_p
=
(
0.0
,
torch
.
float32
),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
\ No newline at end of file
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return
cls
(
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
# Scalar tensor for xla-friendly tracing.
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
dtype
=
torch
.
bool
,
device
=
input_batch
.
device
),
# TODO enable more and avoid returning None values
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
generators
=
input_batch
.
generators
,
indices_do_sample
=
indices_do_sample
)
vllm/v1/serial_utils.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
types
import
FunctionType
from
typing
import
Any
,
Optional
import
cloudpickle
import
torch
from
msgspec
import
msgpack
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
3
class
MsgpackEncoder
:
...
...
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
if
isinstance
(
obj
,
FunctionType
):
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
...
...
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
if
code
==
CUSTOM_TYPE_PICKLE
:
return
pickle
.
loads
(
data
)
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
cloudpickle
.
loads
(
data
)
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
vllm/v1/spec_decode/eagle.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
torch
import
torch.nn
as
nn
import
triton
import
triton.language
as
tl
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
class EagleProposer:
    """Draft-token proposer that runs an EAGLE-style draft model.

    The proposer first runs the draft model over all target tokens, samples
    one draft token per request, and then (when more than one speculative
    token is requested) keeps decoding one token per request per step,
    feeding the previous step's hidden states back into the draft model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the config values and helper tensors used by `propose`.

        Args:
            vllm_config: Engine configuration; the speculative, cache and
                scheduler sub-configs are read here.
            device: Device on which the cached `arange` tensor is allocated.
        """
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.block_size = vllm_config.cache_config.block_size
        # Pre-computed [0, 1, ..., max_num_seqs]. It is sliced to
        # `batch_size + 1` entries and used as `query_start_loc` when every
        # request contributes exactly one query token, so it needs one more
        # element than the maximum batch size.
        # NOTE(review): int32 is used to match the dtype FlashAttention's
        # varlen interface expects for cumulative sequence lengths; confirm
        # it also matches the dtype of the caller-provided `cu_num_tokens`.
        self.arange = torch.arange(
            vllm_config.scheduler_config.max_num_seqs + 1,
            device=device,
            dtype=torch.int32,
        )

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate `num_speculative_tokens` draft tokens per request.

        Returns:
            A tuple `(draft_token_ids, draft_probs)` shaped
            `[batch_size, num_speculative_tokens]` and
            `[batch_size, num_speculative_tokens, vocab_size]`.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        # Index of the last token of each request in the flattened batch.
        last_token_indices = cu_num_tokens[1:] - 1

        input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        input_ids[last_token_indices] = next_token_ids

        seq_lens = target_positions[last_token_indices] + 1
        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
        max_seq_len = seq_lens.max().item()
        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_tokens,
            max_query_len=max_num_tokens,
            query_start_loc=cu_num_tokens,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=target_slot_mapping,
            # TODO(woosuk): Support cascade attention.
            use_cascade=False,
            common_prefix_len=0,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
        )

        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                hidden_states=target_hidden_states,
                positions=target_positions,
            )
        sample_hidden_states = hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
            logits, sampling_metadata)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1] and [batch_size, 1, vocab_size]
            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
        draft_probs_list = [draft_probs]

        # Advanced indexing yields a fresh tensor, so the in-place `+= 1`
        # below does not modify `target_positions`.
        positions = target_positions[last_token_indices]
        hidden_states = sample_hidden_states
        # From here on each request contributes exactly one query token.
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        # Cumulative query offsets need batch_size + 1 entries
        # ([0, 1, ..., batch_size]), like `cu_num_tokens` above.
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            input_ids = draft_token_ids_list[-1]
            positions += 1
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Compute the slot mapping.
            block_numbers = positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          positions % self.block_size)

            # Run the model.
            with set_forward_context(attn_metadata, self.vllm_config):
                hidden_states = self.model(
                    input_ids=input_ids,
                    hidden_states=hidden_states,
                    positions=positions,
                )
            logits = self.model.compute_logits(hidden_states, None)
            draft_token_ids, probs = compute_probs_and_sample_next_token(
                logits, sampling_metadata)
            draft_token_ids_list.append(draft_token_ids)
            draft_probs_list.append(probs)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        # [batch_size, num_speculative_tokens, vocab_size]
        draft_probs = torch.stack(draft_probs_list, dim=1)
        return draft_token_ids, draft_probs

    @staticmethod
    def prepare_inputs(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the per-request token layout after dropping rejections.

        Returns:
            A tuple `(cu_num_tokens, token_indices)`: the cumulative number
            of kept tokens per request (`[batch_size + 1]`, starting at 0)
            and the flat indices of the kept tokens in the target batch.
        """
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens

        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.empty(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        batch_size = num_rejected_tokens.shape[0]
        BLOCK_SIZE = 1024
        # One kernel program per request fills that request's index range.
        prepare_input_kernel[(batch_size, )](
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return cu_num_tokens, token_indices

    def load_model(self, target_model: nn.Module) -> None:
        """Install the draft model, borrowing hooks from `target_model`.

        The draft model reuses the target model's `get_input_embeddings`
        and `compute_logits` callables.
        """
        self.model = DummyEagleModel()
        self.model.get_input_embeddings = target_model.get_input_embeddings
        self.model.compute_logits = target_model.compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):
    """Stand-in draft model: adds token embeddings to the hidden states.

    `get_input_embeddings` is not defined on this class; it is expected to
    be attached to the instance before `forward` is called.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return `hidden_states` plus the embeddings of `input_ids`.

        `positions` is accepted for interface compatibility and is unused.
        """
        token_embeds = self.get_input_embeddings(input_ids)
        # Dummy return.
        combined = hidden_states + token_embeds
        return combined
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of `logits` and return its probs.

    Args:
        logits: Per-request logits. NOTE: scaled in place by the
            temperature when not every request is greedy.
        sampling_metadata: Only `all_greedy`, `all_random` and
            `temperature` are read. A temperature of -1 marks a greedy
            request.

    Returns:
        A tuple `(next_token_ids, probs)`. In the all-greedy case `probs`
        is the raw logits; otherwise it is the float32 softmax of the
        temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    is_greedy = sampling_metadata.temperature == -1
    # Greedy rows get a no-op temperature of 1.0.
    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # Use the out-of-place `div` here: `probs` is returned to the caller and
    # is also used for the greedy argmax below, so it must stay a valid
    # probability distribution rather than being overwritten by probs / q.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs
# Triton kernel launched with one program per request (see the
# `[(batch_size, )]` grid in `EagleProposer.prepare_inputs`). Program `pid`
# writes consecutive indices starting at `cu_query_lens_ptr[pid]` into the
# output slice [cu_num_tokens_ptr[pid], cu_num_tokens_ptr[pid + 1]).
@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    # [start_pos, end_pos)
    start_pos = tl.load(cu_num_tokens_ptr + pid)
    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
    # Number of output entries this program is responsible for.
    num_tokens = end_pos - start_pos

    # First index value to write for this request.
    index_start = tl.load(cu_query_lens_ptr + pid)

    # Process the request's range in chunks of BLOCK_SIZE elements.
    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
    for i in tl.range(num_blocks):
        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The mask drops lanes past the end of this request's range.
        tl.store(
            out_ptr + start_pos + offset,
            index_start + offset,
            mask=offset < num_tokens,
        )
vllm/v1/spec_decode/metrics.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
numpy
as
np
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@dataclass
class SpecDecodingStats:
    """Running counters of drafted and accepted speculative tokens."""

    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0

    def take(self):
        """Return a copy of the current counters, then zero this instance."""
        snapshot = SpecDecodingStats(
            num_draft_tokens=self.num_draft_tokens,
            num_accepted_tokens=self.num_accepted_tokens,
        )
        self.reset()
        return snapshot

    def reset(self):
        """Set both counters back to zero."""
        self.num_draft_tokens, self.num_accepted_tokens = 0, 0

    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Add the given counts to the running totals."""
        self.num_draft_tokens = self.num_draft_tokens + num_draft_tokens
        self.num_accepted_tokens = (self.num_accepted_tokens +
                                    num_accepted_tokens)
class SpecDecodingMetrics:
    """Collects per-observation spec-decoding counts and logs a summary."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all recorded observations."""
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Record the counters carried by `spec_decoding_stats`."""
        stats = spec_decoding_stats
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)

    def log(self):
        """Log the aggregated acceptance rate, then clear the observations."""
        drafted = np.sum(self.num_draft_tokens)
        accepted = np.sum(self.num_accepted_tokens)

        # Acceptance rate is reported as a percentage; NaN when nothing
        # was drafted.
        if drafted > 0:
            acceptance_pct = accepted / drafted * 100
        else:
            acceptance_pct = float("nan")

        logger.info(
            "SpecDecoding metrics: "
            "Draft acceptance rate: %.1f%%, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens",
            acceptance_pct,
            accepted,
            drafted,
        )
        self.reset()
vllm/v1/spec_decode/ngram_proposer.py
View file @
fcfc474d
...
...
@@ -4,15 +4,27 @@ from typing import Optional
import
numpy
as
np
from
numba
import
jit
from
vllm.config
import
VllmConfig
class
NgramProposer
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
# Minimum length of the n-gram to match.
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
# Maximum length of the n-gram to match.
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
# Number of tokens follow the match. If there are less than k
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
))
def
propose
(
self
,
context_token_ids
:
np
.
ndarray
,
min_n
:
int
,
max_n
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
...
...
@@ -22,17 +34,12 @@ class NgramProposer:
Args:
context_token_ids: Numpy array of token IDs representing the
context sequence.
min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return
the maximum amount of tokens until the end.
Returns:
np.ndarray: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
Example:
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4:
...
...
@@ -44,12 +51,16 @@ class NgramProposer:
we only have three tokens after the match.
"""
# TODO(woosuk): Optimize this.
for
n
in
range
(
max_n
,
min_n
-
1
,
-
1
):
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
for
n
in
range
(
self
.
max_n
,
self
.
min_n
-
1
,
-
1
):
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
self
.
k
)
if
result
is
not
None
:
return
result
return
None
def
load_model
(
self
,
*
args
,
**
kwargs
):
# No model to load.
pass
@
jit
(
nopython
=
True
)
def
_kmp_lps_array
(
pattern
:
np
.
ndarray
)
->
np
.
ndarray
:
...
...
vllm/v1/structured_output/__init__.py
View file @
fcfc474d
...
...
@@ -2,7 +2,7 @@
from
__future__
import
annotations
import
multiprocessing
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.config
import
VllmConfig
...
...
@@ -57,13 +57,13 @@ class StructuredOutputManager:
raise
ValueError
(
f
"Unsupported structured output backend:
{
backend_name
}
"
)
grammar
:
Future
[
StructuredOutputGrammar
]
=
self
.
executor
.
submit
(
self
.
_async_create_grammar
,
request
,
self
.
backend
)
grammar
=
self
.
executor
.
submit
(
self
.
_async_create_grammar
,
request
)
request
.
structured_output_request
.
grammar
=
grammar
# type: ignore[assignment]
def
_async_create_grammar
(
self
,
request
:
Request
,
backend
:
StructuredOutputBackend
)
->
StructuredOutputGrammar
:
self
,
request
:
Request
,
)
->
StructuredOutputGrammar
:
key
=
request
.
structured_output_request
.
structured_output_key
# type: ignore[union-attr]
# Note that the request was validated in the engine core client,
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
fcfc474d
...
...
@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
tokenizer_group
.
ping
()
self
.
vllm_config
=
vllm_config
self
.
vocab_size
=
vllm_config
.
model_config
.
get_vocab_size
()
self
.
disable_any_whitespace
=
(
"disable-any-whitespace"
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
...
...
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
self
.
serialized_grammar
=
serialize_guidance_grammar
(
request_type
,
grammar_spec
)
request_type
,
grammar_spec
,
self
.
disable_any_whitespace
)
ll_matcher
=
llguidance
.
LLMatcher
(
self
.
ll_tokenizer
,
...
...
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
def
serialize_guidance_grammar
(
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
str
:
grammar_spec
:
str
,
disable_any_whitespace
:
bool
=
False
)
->
str
:
if
request_type
==
StructuredOutputOptions
.
JSON
:
# TODO: make whitespace_flexible configurable
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
grammar_spec
,
defaults
=
{
"whitespace_flexible"
:
True
,
grammar_spec
,
defaults
=
{
"whitespace_flexible"
:
not
disable_any_whitespace
,
})
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
'{"type": "object"}'
,
defaults
=
{
"whitespace_flexible"
:
True
,
'{"type": "object"}'
,
defaults
=
{
"whitespace_flexible"
:
not
disable_any_whitespace
,
})
else
:
if
request_type
==
StructuredOutputOptions
.
REGEX
:
...
...
vllm/v1/structured_output/backend_xgrammar.py
View file @
fcfc474d
...
...
@@ -42,12 +42,15 @@ class XgrammarBackend(StructuredOutputBackend):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try
:
encoded_vocab
=
[
token
for
token
,
_
in
sorted
(
tokenizer
.
get_vocab
().
items
(),
key
=
lambda
x
:
x
[
1
],
)
]
if
tokenizer
.
is_tekken
:
encoded_vocab
=
tokenizer
.
_vocab
else
:
encoded_vocab
=
[
token
for
token
,
_
in
sorted
(
tokenizer
.
get_vocab
().
items
(),
key
=
lambda
x
:
x
[
1
],
)
]
stop_token_ids
=
None
if
hasattr
(
tokenizer
,
...
...
@@ -62,7 +65,8 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer_info
=
xgr
.
TokenizerInfo
(
# type: ignore
encoded_vocab
=
encoded_vocab
,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type
=
xgr
.
VocabType
.
BYTE_FALLBACK
,
vocab_type
=
xgr
.
VocabType
.
RAW
if
tokenizer
.
is_tekken
else
xgr
.
VocabType
.
BYTE_FALLBACK
,
vocab_size
=
self
.
vocab_size
,
stop_token_ids
=
stop_token_ids
,
add_prefix_space
=
True
,
...
...
@@ -80,7 +84,9 @@ class XgrammarBackend(StructuredOutputBackend):
ctx
=
self
.
compiler
.
compile_json_schema
(
grammar_spec
,
any_whitespace
=
not
self
.
disable_any_whitespace
)
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
ctx
=
self
.
compiler
.
compile_builtin_json_grammar
()
ctx
=
self
.
compiler
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
not
self
.
disable_any_whitespace
)
elif
request_type
==
StructuredOutputOptions
.
GRAMMAR
:
ctx
=
self
.
compiler
.
compile_grammar
(
grammar_spec
)
elif
request_type
==
StructuredOutputOptions
.
REGEX
:
...
...
vllm/v1/structured_output/utils.py
View file @
fcfc474d
...
...
@@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
if
"pattern"
in
obj
:
return
True
# Check for enum restrictions
if
"enum"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
...
...
vllm/v1/utils.py
View file @
fcfc474d
...
...
@@ -105,7 +105,7 @@ class BackgroundProcHandle:
process_kwargs
:
dict
[
Any
,
Any
],
):
context
=
get_mp_context
()
reader
,
writer
=
context
.
Pipe
(
duplex
=
False
)
self
.
reader
,
writer
=
context
.
Pipe
(
duplex
=
False
)
assert
(
"ready_pipe"
not
in
process_kwargs
and
"input_path"
not
in
process_kwargs
...
...
@@ -115,14 +115,17 @@ class BackgroundProcHandle:
process_kwargs
[
"output_path"
]
=
output_path
# Run busy loop in background process.
self
.
proc
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
)
self
.
proc
=
context
.
Process
(
target
=
target_fn
,
kwargs
=
process_kwargs
,
name
=
process_name
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
self
.
proc
,
input_path
,
output_path
)
self
.
proc
.
start
()
def
wait_for_startup
(
self
):
# Wait for startup.
if
reader
.
recv
()[
"status"
]
!=
"READY"
:
raise
RuntimeError
(
f
"
{
process_
name
}
initialization failed. "
if
self
.
reader
.
recv
()[
"status"
]
!=
"READY"
:
raise
RuntimeError
(
f
"
{
self
.
proc
.
name
}
initialization failed. "
"See root cause above."
)
def
shutdown
(
self
):
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
fcfc474d
...
...
@@ -2,13 +2,13 @@
# Datastructures defining an input batch
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
Optional
,
cast
import
numpy
as
np
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
...
...
@@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable
_SAMPLING_EPS
=
1e-5
if
TYPE_CHECKING
:
from
vllm.multimodal.inputs
import
PlaceholderRange
@
dataclass
class
CachedRequestState
:
...
...
@@ -29,7 +26,7 @@ class CachedRequestState:
prompt_token_ids
:
list
[
int
]
prompt
:
Optional
[
str
]
mm_inputs
:
list
[
MultiModalKwargs
]
mm_positions
:
list
[
"
PlaceholderRange
"
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
...
...
@@ -42,9 +39,18 @@ class CachedRequestState:
lora_request
:
Optional
[
LoRARequest
]
=
None
def
__post_init__
(
self
):
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
@
property
def
num_tokens
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
+
len
(
self
.
output_token_ids
)
return
self
.
num_prompt_tokens
+
len
(
self
.
output_token_ids
)
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
if
idx
<
self
.
num_prompt_tokens
:
return
self
.
prompt_token_ids
[
idx
]
else
:
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
class
InputBatch
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fcfc474d
...
...
@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
...
@@ -25,16 +24,18 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
LazyLoader
,
cdiv
,
is_pin_memory_available
)
GiB_bytes
,
LayerBlockType
,
LazyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
...
...
@@ -42,6 +43,8 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
sanity_check_mm_encoder_outputs
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -70,6 +73,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
model_config
=
self
.
model_config
cache_config
=
self
.
cache_config
scheduler_config
=
self
.
scheduler_config
...
...
@@ -106,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
...
...
@@ -130,13 +138,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
mm_registry
=
self
.
mm_registry
,
)
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
encoder_cache_size
=
encoder_cache_size
...
...
@@ -151,18 +159,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
assert
self
.
speculative_config
.
method
==
"ngram"
,
\
"Currently, only ngram spec decode is supported in V1."
if
get_pp_group
().
is_last_rank
:
self
.
drafter
=
NgramProposer
()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self
.
drafter
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
),
self
.
speculative_config
.
prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_max
,
self
.
speculative_config
.
num_speculative_tokens
,
)
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
method
==
"eagle"
:
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
)
# type: ignore
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
rejection_sampler
=
RejectionSampler
()
# Request states.
...
...
@@ -223,6 +228,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
# Only relevant for models using ALiBi (e.g, MPT)
self
.
use_alibi
=
check_use_alibi
(
model_config
)
self
.
inputs_embeds
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
...
...
@@ -671,7 +679,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# use two kernels for cascade attention. Let's imagine:
# Request 3's input query: [D]
# Request 3's kv cache: [A, B, C, D]
# Request 3's num_computed_tokens:
4
(i.e., [A, B, C
, D
])
# Request 3's num_computed_tokens:
3
(i.e., [A, B, C])
# If we use [A, B, C, D] as the common prefix for Request 1-3,
# then Request 3 will be processed only by the first kernel,
# and the second kernel will get an empty input. While this is not
...
...
@@ -689,7 +697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
use_alibi
=
False
,
# FIXME
use_alibi
=
self
.
use_alibi
,
use_sliding_window
=
self
.
window_size
is
not
None
,
num_sms
=
self
.
num_sms
,
)
...
...
@@ -861,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
curr_group_outputs
=
self
.
model
.
get_multimodal_embeddings
(
**
batched_mm_inputs
)
sanity_check_mm_encoder_outputs
(
curr_group_outputs
,
expected_num_items
=
len
(
grouped_mm_inputs
),
)
for
output
in
curr_group_outputs
:
encoder_outputs
.
append
(
output
)
...
...
@@ -1085,8 +1098,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for
i
,
generator
in
self
.
input_batch
.
generators
.
items
():
req_id
=
self
.
input_batch
.
req_ids
[
i
]
discard_sampled_tokens_req_indices
=
[]
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
...
...
@@ -1094,7 +1107,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
...
...
@@ -1117,13 +1135,83 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
self
.
input_batch
.
vocab_size
)
sampled_token_ids
,
self
.
input_batch
.
vocab_size
,
)
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
if
not
self
.
use_spec_decode
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
else
:
elif
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
spec_token_ids
=
self
.
generate_draft_token_ids
(
valid_sampled_token_ids
,
sampling_metadata
)
elif
self
.
speculative_config
.
method
==
"eagle"
:
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
valid_sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_positions
=
positions
target_hidden_states
=
hidden_states
target_slot_mapping
=
attn_metadata
.
slot_mapping
cu_num_tokens
=
attn_metadata
.
query_start_loc
else
:
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
num_rejected_tokens
=
[
n
+
1
-
len
(
valid_sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens
=
torch
.
tensor
(
num_rejected_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
attn_metadata
.
query_start_loc
,
num_rejected_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_positions
=
positions
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
attn_metadata
.
block_table
,
sampling_metadata
=
sampling_metadata
,
)
spec_token_ids
=
draft_token_ids
.
tolist
()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del
draft_probs
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
...
...
@@ -1159,11 +1247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
end_idx
=
start_idx
+
num_sampled_ids
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
speculative_config
.
prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_max
,
self
.
speculative_config
.
num_speculative_tokens
,
)
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
])
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
draft_token_ids
.
append
([])
else
:
...
...
@@ -1181,10 +1265,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
scheduler_config
,
self
.
lora_config
,
self
.
device
)
if
hasattr
(
self
,
"drafter"
):
logger
.
info
(
"Loading drafter model..."
)
self
.
drafter
.
load_model
(
self
.
model
)
time_after_load
=
time
.
perf_counter
()
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Model loading took %.4f GB and %.6f seconds"
,
self
.
model_memory_usage
/
float
(
2
**
30
)
,
logger
.
info
(
"Model loading took %.4f G
i
B and %.6f seconds"
,
self
.
model_memory_usage
/
GiB_bytes
,
time_after_load
-
time_before_load
)
def
_get_prompt_logprobs_dict
(
...
...
@@ -1425,9 +1512,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict
=
(
MULTIMODAL_REGISTRY
.
get_max_tokens_per_item_by_nonzero_modality
(
self
.
model_config
))
max_tokens_by_modality_dict
=
self
.
mm_registry
\
.
get_max_tokens_per_item_by_nonzero_modality
(
self
.
model_config
)
dummy_data_modality
,
max_tokens_per_mm_item
=
max
(
max_tokens_by_modality_dict
.
items
(),
key
=
lambda
item
:
item
[
1
])
...
...
@@ -1459,24 +1545,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_budget
,
max_num_mm_items
,
dummy_data_modality
)
# Create dummy batch of multimodal inputs.
dummy_
request_data
=
self
.
input
_registry
.
dummy_data_for_profiling
(
dummy_
mm_kwargs
=
self
.
mm
_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
mm_registry
=
self
.
mm_registry
,
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
if
not
isinstance
(
dummy_mm_data
,
MultiModalKwargs
):
# TODO: Delete this check once input mapper is fully removed.
raise
RuntimeError
(
"Legacy input mapper is not supported in V1"
)
# Dummy data definition may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
dummy_mm_item
=
dummy_mm_data
.
get_item
(
modality
=
dummy_data_modality
,
item_index
=
0
)
dummy_mm_kwargs
=
MultiModalKwargs
.
from_items
([
dummy_mm_item
])
mm_counts
=
{
dummy_data_modality
:
1
},
).
multi_modal_data
batched_dummy_mm_inputs
=
MultiModalKwargs
.
batch
(
[
dummy_mm_kwargs
]
*
max_num_mm_items
)
...
...
@@ -1486,12 +1561,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run multimodal encoder.
dummy_encoder_outputs
=
self
.
model
.
get_multimodal_embeddings
(
**
batched_dummy_mm_inputs
)
assert
len
(
dummy_encoder_outputs
)
==
max_num_mm_items
,
(
"Expected dimension 0 of encoder outputs to match the number "
f
"of multimodal data items:
{
max_num_mm_items
}
, got "
f
"
{
len
(
dummy_encoder_outputs
)
=
}
instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly."
)
sanity_check_mm_encoder_outputs
(
dummy_encoder_outputs
,
expected_num_items
=
max_num_mm_items
,
)
# Cache the dummy encoder outputs.
self
.
encoder_cache
[
"tmp"
]
=
dict
(
enumerate
(
dummy_encoder_outputs
))
...
...
@@ -1562,7 +1636,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
Full
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
...
...
@@ -1601,12 +1675,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
use_mla
=
use_mla
)
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
use_mla
)
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
use_mla
=
use_mla
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
# encoder-only attention does not need KV cache.
...
...
vllm/v1/worker/gpu_worker.py
View file @
fcfc474d
...
...
@@ -83,9 +83,9 @@ class Worker(WorkerBase):
"%.2f GiB memory is still in use."
,
freed_bytes
/
GiB_bytes
,
used_bytes
/
GiB_bytes
)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Ask the shared ``CuMemAllocator`` instance to wake up.

    Args:
        tags: Forwarded unchanged to ``CuMemAllocator.wake_up``. Defaults
            to ``None``, so existing zero-argument callers keep working.
            NOTE(review): the meaning of individual tags is defined by
            ``CuMemAllocator`` and is not visible here - confirm there.
    """
    allocator = CuMemAllocator.get_instance()
    # Single call that honours ``tags``; the stale argument-less call that
    # sat next to it has been dropped.
    allocator.wake_up(tags)
def
init_device
(
self
):
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
...
@@ -269,6 +269,20 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Save the model held by ``self.model_runner`` via ``ShardedStateLoader``.

    ``path``, ``pattern`` and ``max_size`` are forwarded unchanged to
    ``ShardedStateLoader.save_model``.
    """
    # Function-scope import: the loader module is only imported when a
    # save is actually requested.
    from vllm.model_executor.model_loader.loader import (
        ShardedStateLoader as _Loader)

    model = self.model_runner.model
    _Loader.save_model(model, path, pattern=pattern, max_size=max_size)
def
init_worker_distributed_environment
(
parallel_config
:
ParallelConfig
,
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
bisect
import
time
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
unittest.mock
import
patch
...
...
@@ -16,7 +17,6 @@ from vllm.attention.backends.abstract import AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
...
...
@@ -24,12 +24,11 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
NUM_KV_PAGES_PER_BLOCK
,
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
...
...
@@ -37,6 +36,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
sanity_check_mm_encoder_outputs
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -75,11 +76,15 @@ class TPUModelRunner:
parallel_config
=
self
.
parallel_config
self
.
device
=
device
self
.
check_recompilation
=
envs
.
VLLM_XLA_CHECK_RECOMPILATION
if
self
.
check_recompilation
:
self
.
num_xla_graphs
=
xr
.
get_num_cached_compilation_graph
()
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
num_xla_graphs
=
0
self
.
_update_num_xla_graphs
(
"init"
)
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
_hidden_states_dtype
=
self
.
dtype
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
sliding_window
=
model_config
.
get_sliding_window
()
...
...
@@ -87,7 +92,9 @@ class TPUModelRunner:
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self
.
max_num_reqs
=
max
(
scheduler_config
.
max_num_seqs
,
MIN_NUM_SEQS
)
# Model-related.
self
.
num_attn_layers
=
model_config
.
get_num_layers_by_block_type
(
...
...
@@ -99,7 +106,6 @@ class TPUModelRunner:
self
.
hidden_size
=
model_config
.
get_hidden_size
()
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL)
...
...
@@ -108,6 +114,7 @@ class TPUModelRunner:
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
mm_registry
=
self
.
mm_registry
,
)
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
encoder_cache_size
=
encoder_cache_size
...
...
@@ -147,11 +154,8 @@ class TPUModelRunner:
dtype
=
torch
.
int64
,
device
=
"cpu"
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
padded_max_num_blocks_per_req
=
_get_padded_number
(
self
.
max_num_blocks_per_req
,
NUM_KV_PAGES_PER_BLOCK
)
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
padded_
max_num_blocks_per_req
),
(
self
.
max_num_tokens
,
self
.
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
device
=
"cpu"
)
...
...
@@ -170,6 +174,35 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self
.
arange_np
=
np
.
arange
(
self
.
max_num_tokens
,
dtype
=
np
.
int32
)
self
.
num_tokens_paddings
=
_get_paddings
(
min_token_size
=
16
,
max_token_size
=
self
.
max_num_tokens
,
padding_gap
=
envs
.
VLLM_TPU_BUCKET_PADDING_GAP
)
def _update_num_xla_graphs(self, case_str):
    """Fold newly compiled XLA graphs into ``self.num_xla_graphs``.

    Does nothing unless ``self.check_recompilation`` is set and
    ``self.enforce_eager`` is not. Otherwise the value reported by
    ``xr.get_num_cached_compilation_graph()`` is compared with the stored
    counter; a non-zero difference is logged together with ``case_str``
    and added to the counter.
    """
    # Recompilation tracking is off, or irrelevant in eager mode.
    if not self.check_recompilation or self.enforce_eager:
        return

    delta = xr.get_num_cached_compilation_graph() - self.num_xla_graphs
    if delta != 0:
        logger.info("Add new %d compiled XLA graphs due to %s", delta,
                    case_str)
        self.num_xla_graphs += delta
def _verify_num_xla_graphs(self, case_str):
    """Assert that the cached XLA graph count still equals the stored counter.

    Does nothing unless ``self.check_recompilation`` is set and
    ``self.enforce_eager`` is not. Otherwise the value reported by
    ``xr.get_num_cached_compilation_graph()`` must equal
    ``self.num_xla_graphs``; a mismatch raises ``AssertionError`` whose
    message includes ``case_str`` and both counts.
    """
    # Recompilation tracking is off, or irrelevant in eager mode.
    if not self.check_recompilation or self.enforce_eager:
        return

    cached_now = xr.get_num_cached_compilation_graph()
    message = ("Recompilation after warm up is detected during {}."
               " num_xla_graphs = {} curr_cached_graph = {}")
    # The message is only formatted when the assertion fails.
    assert self.num_xla_graphs == cached_now, message.format(
        case_str, self.num_xla_graphs, cached_now)
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""Update the cached states and the persistent batch with the scheduler
...
...
@@ -279,9 +312,6 @@ class TPUModelRunner:
req_data
.
num_computed_tokens
)
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
req_index
)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed
=
len
(
removed_req_indices
)
>
0
or
len
(
req_ids_to_add
)
>
0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
...
...
@@ -300,9 +330,6 @@ class TPUModelRunner:
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
# TODO This slices tensors to copy to device, triggering recompilation.
if
batch_changed
:
self
.
input_batch
.
refresh_sampling_metadata
()
return
len
(
unscheduled_req_ids
)
>
0
or
len
(
req_ids_to_add
)
>
0
def
get_model
(
self
)
->
nn
.
Module
:
...
...
@@ -322,17 +349,25 @@ class TPUModelRunner:
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
use_mla
=
False
,
)
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
False
,
)
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
use_mla
=
False
,
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
# encoder-only attention does not need KV cache.
...
...
@@ -428,7 +463,7 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens
=
_get_padded_token_len
(
total_num_scheduled_tokens
)
self
.
num_tokens_paddings
,
total_num_scheduled_tokens
)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self
.
input_ids_cpu
[
total_num_scheduled_tokens
:
padded_total_num_scheduled_tokens
]
=
0
...
...
@@ -511,6 +546,11 @@ class TPUModelRunner:
curr_group_outputs
=
self
.
model
.
get_multimodal_embeddings
(
**
batched_mm_inputs
)
sanity_check_mm_encoder_outputs
(
curr_group_outputs
,
expected_num_items
=
len
(
grouped_mm_inputs
),
)
for
output
in
curr_group_outputs
:
encoder_outputs
.
append
(
output
)
...
...
@@ -579,7 +619,6 @@ class TPUModelRunner:
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
if
self
.
is_multimodal_model
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
...
...
@@ -597,14 +636,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
inputs_embeds
=
None
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU:
if there's any shape
#
mismatch in pre-processing, it will trigger a small
recompil
ation
#
of the code thus far. Forward graph remains untouched
.
# NOTE (NickLucche) here we sync with TPU:
sampling params tensors
#
are copied to device in chunks of p
re
-
compil
ed padded shape to
#
avoid recompilations
.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
sampling_metadata
,
logits_indices
,
num_reqs
,
self
.
device
)
from_input_batch
(
self
.
input_batch
,
logits_indices
)
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
...
...
@@ -621,6 +658,7 @@ class TPUModelRunner:
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens
:
list
[
tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
discard_sampled_tokens_req_indices
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -636,6 +674,10 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
assert
all
(
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
...
...
@@ -649,11 +691,19 @@ class TPUModelRunner:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
selected_token_ids
.
tolist
()
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Append sampled tokens
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
valid_sampled_token_ids
[
i
][
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
req_state
.
output_token_ids
.
append
(
token_id
)
self
.
input_batch
.
num_tokens
[
i
]
+=
1
else
:
valid_mask
=
selected_token_ids
!=
INVALID_TOKEN_ID
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
...
...
@@ -676,12 +726,11 @@ class TPUModelRunner:
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if
self
.
check_recompilation
and
not
self
.
enforce_eager
:
curr_cached_graph
=
xr
.
get_num_cached_compilation_graph
()
assert
self
.
num_xla_graphs
==
curr_cached_graph
,
(
"Recompilation after warm up is detected."
)
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
self
.
_verify_num_xla_graphs
(
"execute_model"
)
return
model_runner_output
def
load_model
(
self
)
->
None
:
...
...
@@ -761,10 +810,11 @@ class TPUModelRunner:
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
self
.
_hidden_states_dtype
=
out
.
dtype
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
...
...
@@ -772,63 +822,54 @@ class TPUModelRunner:
logger
.
info
(
"Compiling the model with different input shapes."
)
start
=
time
.
perf_counter
()
num_tokens
=
16
while
True
:
for
num_tokens
in
self
.
num_tokens_paddings
:
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
xm
.
mark_step
()
if
num_tokens
>=
self
.
max_num_tokens
:
break
num_tokens
*=
2
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"model"
)
logger
.
info
(
"Compiling sampling with different input shapes."
)
start
=
time
.
perf_counter
()
num_tokens
=
16
hsize
=
self
.
model_config
.
get_hidden_size
()
device
=
self
.
device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while
True
:
for
num_tokens
in
self
.
num_tokens_paddings
:
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
dtype
=
self
.
_hidden_states_dtype
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while
True
:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta
=
self
.
input_batch
.
sampling_metadata
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
device
=
device
,
)
xm
.
mark_step
()
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
meta
,
indices
,
num_reqs_to_sample
,
device
)
from_input_batch
(
self
.
input_batch
,
indices
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
xm
.
mark_step
()
if
num_reqs_to_sample
>=
self
.
max_num_reqs
:
out
=
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
out
=
out
.
cpu
()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if
num_reqs_to_sample
>=
min
(
num_tokens
,
self
.
max_num_reqs
):
break
num_reqs_to_sample
*=
2
if
num_tokens
>=
self
.
max_num_tokens
:
break
num_tokens
*=
2
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if
self
.
check_recompilation
:
total_cached_graphs
=
xr
.
get_num_cached_compilation_graph
()
num_compiled_graphs
=
total_cached_graphs
-
self
.
num_xla_graphs
logger
.
info
(
"Compiled %d XLA graphs."
,
num_compiled_graphs
)
self
.
num_xla_graphs
+=
num_compiled_graphs
self
.
_update_num_xla_graphs
(
"sampling"
)
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
...
...
@@ -856,12 +897,11 @@ class TPUModelRunner:
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
tpu_k_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
tpu_kv_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
(
tpu_k
_cache
,
tpu_
v_cache
)
kv_caches
[
layer_name
]
=
tpu_kv_cache
else
:
raise
NotImplementedError
...
...
@@ -888,7 +928,7 @@ class ModelWrapperV1(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
kv_caches
:
list
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model.
...
...
@@ -923,10 +963,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states
=
\
hidden_states
[
sampling_metadata
.
indices_do_sample
]
logits
=
self
.
compute_logits
(
sample_hidden_states
)
# Greedy sampling can't be run without branching the graph on Sampler.
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
do_argmax
,
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
all_greedy
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
...
...
@@ -949,12 +988,50 @@ def _get_padded_number(n: int, multiple: int) -> int:
return
((
n
+
multiple
-
1
)
//
multiple
)
*
multiple
def
_get_padded_token_len
(
x
:
int
)
->
int
:
if
x
<=
16
:
return
16
return
1
<<
(
x
-
1
).
bit_length
()
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    """Pad a request count to a power of two, then cap it.

    Counts at or below ``MIN_NUM_SEQS`` are padded to ``MIN_NUM_SEQS``;
    larger counts are rounded up to the next power of two. The padded value
    is finally limited to ``upper_limit``.
    """
    if x <= MIN_NUM_SEQS:
        padded = MIN_NUM_SEQS
    else:
        # Smallest power of two that is >= x.
        padded = 1 << (x - 1).bit_length()
    return min(padded, upper_limit)
def
_get_paddings
(
min_token_size
:
int
,
max_token_size
:
int
,
padding_gap
:
int
)
->
list
[
int
]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
If padding_gap == 0 then:
increase 2X each time (exponential)
else:
first increase the size to twice,
then increase the padding size by padding_gap.
"""
paddings
=
[]
num
=
min_token_size
if
padding_gap
==
0
:
logger
.
info
(
"Using exponential paddings:"
)
while
num
<=
max_token_size
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
num
*=
2
else
:
logger
.
info
(
"Using incremental paddings:"
)
while
num
<=
padding_gap
:
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
num
*=
2
num
//=
2
while
num
<
max_token_size
:
num
+=
padding_gap
logger
.
info
(
" %d"
,
num
)
paddings
.
append
(
num
)
return
paddings
def
_get_padded_token_len
(
paddings
:
list
[
int
],
x
:
int
)
->
int
:
"""Return the first element in paddings list greater or equal to x.
"""
index
=
bisect
.
bisect_left
(
paddings
,
x
)
assert
index
<
len
(
paddings
)
return
paddings
[
index
]
vllm/v1/worker/tpu_worker.py
View file @
fcfc474d
...
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
Full
AttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
bind_kv_cache
...
...
@@ -66,20 +66,30 @@ class TPUWorker:
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self
.
profiler
=
None
self
.
profile_dir
=
None
if
envs
.
VLLM_TORCH_PROFILER_DIR
and
self
.
rank
<
1
:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self
.
profile_dir
=
envs
.
VLLM_TORCH_PROFILER_DIR
logger
.
info
(
"Profiling enabled. Traces will be saved to: %s"
,
self
.
profile_dir
)
self
.
profiler
=
xp
.
start_server
(
9012
)
if
self
.
model_config
.
seed
is
None
:
self
.
model_config
.
seed
=
0
def
init_device
(
self
):
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os
.
environ
[
"LIBTPU_INIT_ARGS"
]
=
(
"--xla_tpu_force_1d_allreduce_at_chunk_count=1"
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
...
...
@@ -101,17 +111,24 @@ class TPUWorker:
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
#
NOTE(woosuk): Usually,
we compile
10-15
graphs
for prefill and
#
30-40 graphs for decode. 128 is an arbitrary safe number
.
#
TODO (NickLucche) On gsm
we compile
80+
graphs
.
#
Re-evaluate limit, with MM we may get close to this limit
.
torch
.
_dynamo
.
config
.
cache_size_limit
=
128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size
=
self
.
parallel_config
.
world_size
rank
=
xr
.
global_ordinal
()
per_rank_path
=
os
.
path
.
join
(
envs
.
VLLM_XLA_CACHE_PATH
,
f
"tp
{
world_size
}
_rank
{
rank
}
"
)
xr
.
initialize_cache
(
per_rank_path
,
readonly
=
False
)
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended. We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if
envs
.
VLLM_XLA_CACHE_PATH
:
per_rank_path
=
os
.
path
.
join
(
envs
.
VLLM_XLA_CACHE_PATH
,
f
"tp
{
world_size
}
_rank
{
rank
}
"
)
xr
.
initialize_cache
(
per_rank_path
,
readonly
=
False
)
# Init ModelRunner here, so that we have access to self.device.
self
.
model_runner
=
TPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
...
...
@@ -120,17 +137,18 @@ class TPUWorker:
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_cache_spec
=
self
.
model_runner
.
get_kv_cache_spec
()
for
layer_name
,
layer_spec
in
kv_cache_spec
.
items
():
if
isinstance
(
layer_spec
,
Full
AttentionSpec
):
if
isinstance
(
layer_spec
,
AttentionSpec
):
dtype
=
layer_spec
.
dtype
# Use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
tpu_k_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
(
tpu_k
_cache
,
tpu_
v_cache
)
tpu_k
v
_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
tpu_kv_cache
else
:
raise
NotImplementedError
raise
NotImplementedError
(
f
"Unsupported KV cache spec '
{
type
(
layer_spec
)
}
'"
)
runner_kv_caches
:
list
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
...
...
@@ -150,7 +168,13 @@ class TPUWorker:
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
total_memory_size
=
m
[
"bytes_limit"
]
profiled
=
m
[
"peak_bytes_used"
]
# Weights + intermediate activations.
current_mem
=
m
[
"bytes_used"
]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, so we
# use the heuristic of 2% of weights.
profiled
=
current_mem
*
1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size
=
int
(
total_memory_size
*
...
...
@@ -168,9 +192,11 @@ class TPUWorker:
def
profile
(
self
,
is_start
:
bool
=
True
):
if
self
.
rank
<
1
:
if
self
.
profiler
is
None
:
if
self
.
profile
_di
r
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
if
is_start
:
if
self
.
profiler
is
None
:
self
.
profiler
=
xp
.
start_server
(
9012
)
xp
.
start_trace
(
self
.
profile_dir
)
else
:
xp
.
stop_trace
()
...
...
vllm/v1/worker/utils.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
torch
def sanity_check_mm_encoder_outputs(
    mm_embeddings: object,
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    :meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.

    Three conditions are asserted, in this order:

    1. ``mm_embeddings`` is a list, a tuple or a ``torch.Tensor``;
    2. its length equals ``expected_num_items``;
    3. every element obtained by iterating over it has ``ndim == 2``.

    An ``AssertionError`` with a descriptive message is raised by the first
    condition that does not hold.
    """
    # Common tail shared by all three failure messages.
    blame = ("instead. This is most likely due to incorrect implementation "
             "of the model's `get_multimodal_embeddings` method.")

    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} " + blame)

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        + blame)

    assert not any(item.ndim != 2 for item in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        + blame)
vllm/version.py
View file @
fcfc474d
...
...
@@ -28,4 +28,13 @@ def _prev_minor_version_was(version_str):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def _prev_minor_version():
    """Return the previous minor version as a string, for use in tests.

    The result is ``"<major>.<minor - 1>"`` built from the first two fields
    of ``__version_tuple__``.
    """
    # In a dev tree this yields "0.-1", which still works fine for tests.
    minor = __version_tuple__[1]
    assert isinstance(minor, int)
    major = __version_tuple__[0]
    return f"{major}.{minor - 1}"
vllm/worker/cpu_model_runner.py
View file @
fcfc474d
...
...
@@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
needs_attn_backend
else
None
# Multi-modal data support
...
...
vllm/worker/cpu_worker.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
"""A CPU worker class."""
import
os
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
...
...
@@ -67,6 +68,7 @@ class CPUCacheEngine:
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
# Initialize the cache.
...
...
@@ -106,7 +108,7 @@ class CPUCacheEngine:
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
key_cache_block
=
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
value_cache_block
=
key_cache_block
if
not
model_config
.
use_mla
else
0
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
if
cache_dtype
==
"auto"
:
dtype
=
model_config
.
dtype
...
...
@@ -139,6 +141,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
self
.
local_rank
=
local_rank
self
.
rank
=
rank
vllm_config
.
parallel_config
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
...
...
@@ -217,6 +221,10 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
if
ret
:
logger
.
info
(
ret
)
# Note: unique identifier for creating allreduce shared memory
os
.
environ
[
"VLLM_DIST_IDENT"
]
=
self
.
distributed_init_method
.
split
(
":"
)[
-
1
]
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
init_distributed_environment
()
# Set random seed.
...
...
vllm/worker/model_runner.py
View file @
fcfc474d
...
...
@@ -1145,8 +1145,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
time_after_load
=
time
.
perf_counter
()
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Model loading took %.4f GB and %.6f seconds"
,
self
.
model_memory_usage
/
float
(
2
**
30
)
,
logger
.
info
(
"Model loading took %.4f G
i
B and %.6f seconds"
,
self
.
model_memory_usage
/
GiB_bytes
,
time_after_load
-
time_before_load
)
if
self
.
prompt_adapter_config
:
self
.
prompt_adapter_manager
=
LRUCacheWorkerPromptAdapterManager
(
...
...
@@ -1244,6 +1244,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
_dummy_run
(
max_num_batched_tokens
,
max_num_seqs
)
def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
    """Register ``num_loras`` dummy LoRA adapters with the LoRA manager.

    Each adapter is described by a ``LoRARequest`` named ``warmup_<id>``
    (ids start at 1) that points at a placeholder path, and is added through
    ``lora_manager.add_dummy_lora`` with rank ``LORA_WARMUP_RANK`` while the
    manager's ``dummy_lora_cache()`` context is active.

    Args:
        num_loras: Number of dummy adapters to create; must be positive.

    Returns:
        The created requests, in id order.
    """
    assert num_loras > 0
    assert self.lora_manager is not None

    warmup_requests: list[LoRARequest] = []
    with self.lora_manager.dummy_lora_cache():
        # LoRA ids are 1-based.
        for lora_id in range(1, num_loras + 1):
            request = LoRARequest(
                lora_name=f"warmup_{lora_id}",
                lora_int_id=lora_id,
                lora_path="/not/a/real/path",
            )
            self.lora_manager.add_dummy_lora(request, rank=LORA_WARMUP_RANK)
            warmup_requests.append(request)
    return warmup_requests
def
_remove_dummy_loras
(
self
):
# Remove dummy loras.
assert
self
.
lora_manager
is
not
None
self
.
remove_all_loras
()
def
_dummy_run
(
self
,
max_num_batched_tokens
:
int
,
max_num_seqs
:
int
=
1
)
->
None
:
...
...
@@ -1253,28 +1276,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
assert
self
.
lora_manager
is
not
None
with
self
.
lora_manager
.
dummy_lora_cache
():
for
idx
in
range
(
self
.
lora_config
.
max_loras
):
lora_id
=
idx
+
1
dummy_lora_request
=
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_path
=
"/not/a/real/path"
,
)
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
rank
=
LORA_WARMUP_RANK
)
dummy_lora_requests
.
append
(
dummy_lora_request
)
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the
# total number of tokens equal to max_num_batched_tokens.
...
...
@@ -1356,9 +1371,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
if
self
.
lora_config
:
# Remove dummy loras.
assert
self
.
lora_manager
is
not
None
self
.
remove_all_loras
()
self
.
_remove_dummy_loras
()
return
def
remove_all_loras
(
self
):
...
...
@@ -1481,6 +1495,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
dummy_lora_id
:
Optional
[
int
]
=
None
dummy_lora_request
:
LoRARequest
=
[]
if
self
.
lora_config
:
# The goal is to capture the LoRA kernels in cuda graphs.
# For this purpose, a single dummy lora is sufficient.
dummy_lora_requests
=
self
.
_add_dummy_loras
(
num_loras
=
1
)
assert
len
(
dummy_lora_requests
)
==
1
dummy_lora_request
=
dummy_lora_requests
[
0
]
dummy_lora_id
=
dummy_lora_request
.
lora_int_id
with
self
.
attn_state
.
graph_capture
(
max_batch_size
),
graph_capture
(
self
.
device
)
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
...
...
@@ -1505,10 +1529,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
attn_metadata
.
enable_kv_scales_calculation
=
False
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
**
dict
(
index_mapping
=
[
0
]
*
batch_size
,
prompt_mapping
=
[
0
]
*
batch_size
,
**
dict
(
index_mapping
=
[
dummy_lora_id
]
*
batch_size
,
prompt_mapping
=
[
dummy_lora_id
]
*
batch_size
,
is_prefill
=
False
))
self
.
set_active_loras
(
set
(),
lora_mapping
)
self
.
set_active_loras
(
set
([
dummy_lora_request
]),
lora_mapping
)
if
self
.
prompt_adapter_config
:
prompt_adapter_mapping
=
PromptAdapterMapping
(
...
...
@@ -1564,6 +1589,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
graph_runner
)
if
self
.
lora_config
:
self
.
_remove_dummy_loras
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
elapsed_time
=
end_time
-
start_time
...
...
Prev
1
…
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment