Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3de379de
Commit
3de379de
authored
Jul 31, 2025
by
zhuwenwen
Browse files
update unused code
parent
5ad884ee
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1 addition
and
7145 deletions
+1
-7145
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+0
-357
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+0
-217
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+0
-165
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+0
-426
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+0
-1568
vllm/spec_decode/tree_style_proposer.py
vllm/spec_decode/tree_style_proposer.py
+0
-318
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+0
-307
vllm/triton_utils/custom_cache_manager.py
vllm/triton_utils/custom_cache_manager.py
+0
-55
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+0
-2
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+0
-326
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-671
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+0
-125
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+0
-457
vllm/worker/multi_step_tpu_worker.py
vllm/worker/multi_step_tpu_worker.py
+0
-108
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+0
-909
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+0
-341
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-606
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+0
-186
No files found.
vllm/spec_decode/medusa_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
,
SpeculativeProposer
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
DelegateWorkerBase
from
vllm.spec_decode.tree_style_proposer
import
TreeStyleProposer
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.worker.worker_base
import
WorkerWrapperBase
TOPK
=
10
# topk for sparse tree (10 is a placeholder and it is sufficient)
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
DelegateWorkerBase
):
"""Worker for Medusa.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# skip lora config in medusa
DelegateWorkerBase
.
__init__
(
self
,
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
):
self
.
worker
.
init_device
()
def
load_model
(
self
):
super
().
load_model
()
# get medusa choices and generate medusa_buffers
self
.
medusa_buffers
=
None
if
self
.
tree_decoding
and
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
if
self
.
medusa_choices
is
not
None
:
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
self
.
medusa_choices
,
device
=
self
.
device
)
if
self
.
medusa_buffers
is
None
:
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
else
:
self
.
_proposer
=
TreeStyleProposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
self
.
medusa_buffers
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
):
pass
def
set_should_modify_greedy_probs_inplace
(
self
):
pass
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Dict
[
str
,
torch
.
Tensor
]:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
,
generators
)
sample_indices_list
=
[]
for
seq_group
in
sampling_metadata
.
seq_groups
:
sample_indices_list
.
append
(
seq_group
.
sample_indices
)
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
previous_logits
=
execute_model_req
.
previous_logits
.
logits
if
\
execute_model_req
.
previous_logits
is
not
None
else
None
tensor_dict
=
{
"previous_hidden_states"
:
previous_hidden_states
,
"previous_logits"
:
previous_logits
,
"sample_indices_list"
:
sample_indices_list
,
"seq_lens"
:
seq_lens
}
if
self
.
do_metadata_broadcast
:
broadcast_tensor_dict
(
tensor_dict
,
src
=
0
)
return
tensor_dict
def
_get_worker_input_from_broadcast
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
""" Get the worker input from the broadcasted tensor dict. """
assert
self
.
do_metadata_broadcast
assert
not
self
.
is_driver_worker
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
return
broadcast_data
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
# Unused parameter.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
if
self
.
is_driver_worker
:
tensor_dict
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
)
else
:
tensor_dict
=
self
.
_get_worker_input_from_broadcast
()
if
tensor_dict
is
None
:
raise
ValueError
(
"Can not get inputs of medusa worker!!!"
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
tensor_dict
[
"previous_hidden_states"
],
sample_indices_list
=
tensor_dict
[
"sample_indices_list"
],
previous_logits
=
tensor_dict
[
"previous_logits"
],
medusa_buffers
=
self
.
medusa_buffers
)
# create tree attn masks
if
self
.
is_driver_worker
and
self
.
medusa_buffers
is
not
None
:
seq_lens
=
tensor_dict
[
"seq_lens"
]
max_context_len
=
max
(
seq_lens
)
for
sampler_output
,
seq_len
in
zip
(
model_outputs
,
seq_lens
):
context_len
=
seq_len
attn_masks
=
self
.
medusa_buffers
[
'tree_attn_masks'
]
left_mask
=
torch
.
ones
(
attn_masks
.
shape
[
0
],
context_len
,
dtype
=
attn_masks
.
dtype
,
device
=
attn_masks
.
device
)
attn_masks
=
torch
.
cat
([
left_mask
,
attn_masks
],
dim
=-
1
)
right_pad
=
max_context_len
-
context_len
if
right_pad
>
0
:
attn_masks
=
F
.
pad
(
attn_masks
,
(
0
,
right_pad
),
"constant"
,
0
)
sampler_output
.
tree_attn_masks
=
attn_masks
return
model_outputs
,
False
def
_prepare_input_tensors
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
if
not
seq_group_metadata_list
:
return
[],
[]
seq_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seq_data_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_data_len
,
context_len
+
seq_group_metadata
.
token_chunk_size
)
seq_lens
.
append
(
seq_len
)
query_lens
.
append
(
seq_len
-
context_len
)
else
:
# first step of tree decoding need to ignore first token
if
self
.
medusa_buffers
is
not
None
and
seq_data
.
get_first_step_flag
():
seq_data_len
-=
1
seq_lens
.
append
(
seq_data_len
)
query_lens
.
append
(
1
)
return
seq_lens
,
query_lens
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_spec_proposals
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
def
_raise_if_unsupported
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if
execute_model_req
is
None
:
return
None
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MedusaWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MedusaWorker does not support beam search."
)
def
pad_path
(
self
,
path
,
length
,
pad_value
=-
2
):
"""
Pad the given path list with a specific value up to a specified length.
Parameters:
- path (list): The original list that needs padding.
- length (int): The desired length of the padded list.
- pad_value (optional, default=-2): The value to use for padding.
Returns:
- list: A new list based on the original path but padded to the desired length.
Example:
>>> pad_path([1,2,3], 5)
[1, 2, 3, -2, -2]
Note:
If the given path is already longer than the specified length,
then no padding occurs, and the original path is returned.
"""
# Calculate the number of padding values needed by subtracting the length
# of the path from the desired length.
# Append the padding values to the original path and return the new list.
return
path
+
[
pad_value
]
*
(
length
-
len
(
path
))
def
generate_medusa_buffers
(
self
,
medusa_choices
,
device
=
"cuda"
):
"""
Generate buffers for the Medusa structure based on the provided choices.
Parameters:
- medusa_choices (list): A nested list representing tree in the Medusa structure.
- device (str): Device to which the tensors should be moved. Default is "cuda".
Returns:
- dict: A dictionary containing buffers related to the Medusa structure.
"""
# Sort the medusa_choices based on their lengths and then their values
sorted_medusa_choices
=
sorted
(
medusa_choices
,
key
=
lambda
x
:
(
len
(
x
),
x
))
medusa_len
=
len
(
sorted_medusa_choices
)
+
1
# Initialize depth_counts to keep track of how many choices have a particular depth
depth_counts
=
[]
prev_depth
=
0
for
path
in
sorted_medusa_choices
:
depth
=
len
(
path
)
if
depth
!=
prev_depth
:
depth_counts
.
append
(
0
)
depth_counts
[
depth
-
1
]
+=
1
prev_depth
=
depth
# Create the attention mask for Medusa
medusa_attn_mask
=
torch
.
eye
(
medusa_len
,
medusa_len
)
medusa_attn_mask
[:,
0
]
=
1
start
=
0
for
i
in
range
(
len
(
depth_counts
)):
for
j
in
range
(
depth_counts
[
i
]):
cur_medusa_choice
=
sorted_medusa_choices
[
start
+
j
]
# retrieve ancestor position
if
len
(
cur_medusa_choice
)
==
1
:
continue
ancestor_idx
=
[]
for
c
in
range
(
len
(
cur_medusa_choice
)
-
1
):
ancestor_idx
.
append
(
sorted_medusa_choices
.
index
(
cur_medusa_choice
[:
c
+
1
])
+
1
)
medusa_attn_mask
[
j
+
start
+
1
,
ancestor_idx
]
=
1
start
+=
depth_counts
[
i
]
# Generate tree indices for the Medusa structure
medusa_tree_indices
=
torch
.
zeros
(
medusa_len
,
dtype
=
torch
.
long
)
medusa_tree_indices
[
0
]
=
0
start
=
0
for
i
in
range
(
len
(
depth_counts
)):
for
j
in
range
(
depth_counts
[
i
]):
cur_medusa_choice
=
sorted_medusa_choices
[
start
+
j
]
medusa_tree_indices
[
start
+
j
+
1
]
=
cur_medusa_choice
[
-
1
]
+
TOPK
*
i
+
1
start
+=
depth_counts
[
i
]
# Generate position IDs for the Medusa structure
medusa_position_ids
=
torch
.
zeros
(
medusa_len
,
dtype
=
torch
.
long
)
start
=
0
for
i
in
range
(
len
(
depth_counts
)):
medusa_position_ids
[
start
+
1
:
start
+
depth_counts
[
i
]
+
1
]
=
i
+
1
start
+=
depth_counts
[
i
]
# Generate retrieval indices for Medusa structure verification
retrieve_indices_nest
=
[]
retrieve_paths
=
[]
for
i
in
range
(
len
(
sorted_medusa_choices
)):
cur_medusa_choice
=
sorted_medusa_choices
[
-
i
-
1
]
retrieve_indice
=
[]
if
cur_medusa_choice
in
retrieve_paths
:
continue
else
:
for
c
in
range
(
len
(
cur_medusa_choice
)):
retrieve_indice
.
append
(
sorted_medusa_choices
.
index
(
cur_medusa_choice
[:
c
+
1
]))
retrieve_paths
.
append
(
cur_medusa_choice
[:
c
+
1
])
retrieve_indices_nest
.
append
(
retrieve_indice
)
max_length
=
max
([
len
(
x
)
for
x
in
retrieve_indices_nest
])
retrieve_indices
=
[
self
.
pad_path
(
path
,
max_length
)
for
path
in
retrieve_indices_nest
]
retrieve_indices
=
torch
.
tensor
(
retrieve_indices
,
dtype
=
torch
.
long
)
retrieve_indices
=
retrieve_indices
+
1
retrieve_indices
=
torch
.
cat
([
torch
.
zeros
((
retrieve_indices
.
shape
[
0
],
1
),
dtype
=
torch
.
long
),
retrieve_indices
],
dim
=
1
)
# Aggregate the generated buffers into a dictionary
medusa_buffers
=
{
"tree_attn_masks"
:
medusa_attn_mask
.
int
(),
"tree_indices"
:
medusa_tree_indices
,
"tree_position_ids"
:
medusa_position_ids
,
"retrieve_indices"
:
retrieve_indices
,
}
# Move the tensors in the dictionary to the specified device
medusa_buffers
=
{
k
:
v
.
clone
().
to
(
device
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
torch
.
tensor
(
v
,
device
=
device
)
for
k
,
v
in
medusa_buffers
.
items
()
}
return
medusa_buffers
vllm/spec_decode/metrics.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
typing
import
Callable
,
Optional
,
Union
import
msgspec
import
torch
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_pin_memory_available
class
SpecDecodeWorkerMetrics
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
# The empirical acceptance rate of the proposal method on a per-token basis.
# This is useful for evaluating how well the proposal method aligns with the
# scoring method.
draft_acceptance_rate
:
float
# The empirical efficiency, measured as the number of tokens emitted by the
# system divided by the number of tokens that could be emitted by the system
# if the proposal method were perfect.
system_efficiency
:
float
# The number of speculative tokens produced by the proposal method.
draft_tokens
:
int
# The number of tokens emitted by the entire system.
emitted_tokens
:
int
# The number of tokens accepted by the scoring model and verification
# routine, e.g. Llama2-70B and lossless rejection sampling.
#
# NOTE: Any token accepted by the verification routine is considered
# accepted (regardless of if the speculative prefix is also accepted). The
# user will usually see less accepted tokens. This metric is helpful when
# evaluating alignment of the proposal method with the scoring model.
accepted_tokens
:
int
# The number of speculative tokens per sequence.
num_spec_tokens
:
int
Timer
=
Callable
[[],
float
]
class
AsyncMetricsCollector
:
"""Class which copies rejection/typical-acceptance sampler metrics
from the device to CPU on a non-default Torch stream.
"""
def
__init__
(
self
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
timer
:
Optional
[
Timer
]
=
None
,
collect_interval_s
:
float
=
5.0
):
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_timer
=
time
.
time
if
timer
is
None
else
timer
self
.
_rank
:
Optional
[
int
]
=
None
# We don't have a device set yet.
self
.
_copy_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
self
.
_in_flight_copy
:
Optional
[
torch
.
cuda
.
Event
]
=
None
self
.
_aggregate_num_draft_tokens
=
0
self
.
_rejsample_metrics_collect_interval_s
=
collect_interval_s
self
.
_last_metrics_collect_time
=
self
.
_timer
()
def
init_gpu_tensors
(
self
,
rank
:
int
)
->
None
:
self
.
_rank
=
rank
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
def
init_tensors
(
self
,
rank
:
int
,
device_type
:
Union
[
torch
.
device
,
str
]
=
'cuda'
)
->
None
:
self
.
_rank
=
rank
if
isinstance
(
device_type
,
torch
.
device
):
torch
.
cuda
.
set_device
(
device_type
)
device_type
=
device_type
.
type
# stream = current_platform.Stream
# if stream is not None:
# self._copy_stream = stream()
if
device_type
==
'cuda'
:
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
pin_memory
=
is_pin_memory_available
()
self
.
_aggregate_num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
_aggregate_num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
def
maybe_collect_rejsample_metrics
(
self
,
k
:
int
)
->
Optional
[
SpecDecodeWorkerMetrics
]:
# Skip for any platform that doesn't have device Event
# if current_platform.Event is None:
# return None
# If a copy was initiated in the previous call, collect and return.
if
self
.
_in_flight_copy
is
not
None
:
ready_event
=
self
.
_in_flight_copy
self
.
_in_flight_copy
=
None
return
self
.
_collect_rejsample_metrics
(
k
,
ready_event
)
# Otherwise, check if we should start a new copy.
if
self
.
_should_collect_rejsample_metrics
(
self
.
_timer
()):
assert
self
.
_in_flight_copy
is
None
self
.
_in_flight_copy
=
self
.
_copy_rejsample_metrics_async
()
return
None
def
_should_collect_rejsample_metrics
(
self
,
now
:
float
)
->
bool
:
"""Return whether or not this iteration should print sampling
metrics.
"""
if
self
.
_rank
!=
0
:
return
False
return
now
-
self
.
_last_metrics_collect_time
>=
self
.
_rejsample_metrics_collect_interval_s
# noqa: E501
def
_copy_rejsample_metrics_async
(
self
)
->
torch
.
cuda
.
Event
:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a device event recording when the copy is complete.
"""
assert
self
.
_copy_stream
is
not
None
self
.
_copy_stream
.
wait_stream
(
current_platform
.
current_stream
())
with
current_platform
.
stream
(
self
.
_copy_stream
):
self
.
_aggregate_num_accepted_tokens
.
copy_
(
self
.
spec_decode_sampler
.
num_accepted_tokens
,
non_blocking
=
True
)
self
.
_aggregate_num_emitted_tokens
.
copy_
(
self
.
spec_decode_sampler
.
num_emitted_tokens
,
non_blocking
=
True
)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self
.
_aggregate_num_draft_tokens
=
(
self
.
spec_decode_sampler
.
num_draft_tokens
)
aggregate_metrics_ready
=
current_platform
.
Event
()
aggregate_metrics_ready
.
record
(
self
.
_copy_stream
)
return
aggregate_metrics_ready
def
_collect_rejsample_metrics
(
self
,
k
:
int
,
ready_event
:
torch
.
cuda
.
Event
)
->
SpecDecodeWorkerMetrics
:
"""Create metrics object from statistics copied asynchronously.
Args:
k: int. The number of speculative tokens; used to determine system
efficiency.
ready_event: torch.cuda.Event. The CUDA event recording when the
async GPU->CPU copy is complete.
"""
ready_event
.
synchronize
()
# update time of last collection
self
.
_last_metrics_collect_time
=
self
.
_timer
()
accepted_tokens
=
self
.
_aggregate_num_accepted_tokens
.
item
()
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
draft_tokens
=
self
.
_aggregate_num_draft_tokens
max_num_emitted_tokens
=
self
.
get_max_num_emitted_tokens
(
draft_tokens
,
k
)
if
draft_tokens
>
0
:
draft_acceptance_rate
=
accepted_tokens
/
draft_tokens
else
:
draft_acceptance_rate
=
float
(
"nan"
)
if
max_num_emitted_tokens
>
0
:
system_efficiency
=
emitted_tokens
/
max_num_emitted_tokens
else
:
system_efficiency
=
float
(
"nan"
)
return
SpecDecodeWorkerMetrics
(
num_spec_tokens
=
k
,
draft_acceptance_rate
=
draft_acceptance_rate
,
system_efficiency
=
system_efficiency
,
accepted_tokens
=
accepted_tokens
,
draft_tokens
=
draft_tokens
,
emitted_tokens
=
emitted_tokens
,
)
@
staticmethod
def
get_max_num_emitted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert
draft_tokens
%
k
==
0
total_num_spec_seqs
=
draft_tokens
//
k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted
=
k
+
1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return
total_num_spec_seqs
*
num_emitted_per_seq_if_all_accepted
vllm/spec_decode/mlp_speculator_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.distributed
import
broadcast_tensor_dict
class
MLPSpeculatorWorker
(
NonLLMProposerWorkerBase
,
MultiStepWorker
):
"""Worker for MLPSpeculator models.
Not currently compatible with LoRA or chunked prefill.
"""
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
index
:
int
,
last_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
sampling_metadata
:
Optional
[
SamplingMetadata
]
=
None
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
sampling_metadata
is
None
and
execute_model_req
is
not
None
:
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
# b x 1
last_tokens
=
input_tokens
.
unsqueeze
(
1
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
,
generators
)
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
# b x 1 x d
previous_hidden_states
=
previous_hidden_states
.
unsqueeze
(
1
)
tensor_dict
=
{
"input_tokens"
:
last_tokens
,
"previous_hidden_states"
:
previous_hidden_states
,
"sample_len"
:
sample_len
,
"head_index"
:
index
}
if
self
.
do_metadata_broadcast
:
broadcast_tensor_dict
(
tensor_dict
,
src
=
0
)
return
tensor_dict
,
sampling_metadata
def
_get_worker_input_from_broadcast
(
self
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
""" Get the worker input from the broadcasted tensor dict. """
assert
self
.
do_metadata_broadcast
assert
not
self
.
is_driver_worker
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
return
broadcast_data
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For mlp spec worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
model_outputs
=
[]
last_tokens
=
None
previous_hidden_states
=
None
sampling_metadata
=
None
for
index
in
range
(
sample_len
):
if
self
.
is_driver_worker
:
tensor_dict
,
sampling_metadata
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
,
sample_len
,
index
,
last_tokens
,
previous_hidden_states
,
sampling_metadata
)
assert
sampling_metadata
is
not
None
output
,
previous_hidden_states
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
tensor_dict
[
"input_tokens"
],
previous_hidden_states
=
tensor_dict
[
"previous_hidden_states"
],
num_predict_tokens
=
tensor_dict
[
"sample_len"
],
sampling_metadata
=
sampling_metadata
,
head_index
=
index
)
last_tokens
=
output
.
sampled_token_ids
model_outputs
.
append
(
output
)
else
:
tensor_dict
=
self
.
_get_worker_input_from_broadcast
()
if
tensor_dict
is
None
:
raise
ValueError
(
"Can not get inputs of mlp_speculator worker!!!"
)
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
tensor_dict
[
"input_tokens"
],
previous_hidden_states
=
tensor_dict
[
"previous_hidden_states"
],
num_predict_tokens
=
tensor_dict
[
"sample_len"
],
sampling_metadata
=
None
,
head_index
=
tensor_dict
[
"head_index"
])
if
self
.
is_driver_worker
:
assert
len
(
model_outputs
)
==
sample_len
return
model_outputs
,
True
def
_prepare_input_tensors
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
List
[
int
],
List
[
int
]]:
if
not
seq_group_metadata_list
:
return
torch
.
empty
(
0
,
device
=
self
.
device
),
[],
[]
input_tokens
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seq_data_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_data_len
,
context_len
+
seq_group_metadata
.
token_chunk_size
)
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
seq_lens
.
append
(
seq_len
)
input_tokens
.
extend
(
tokens
)
query_lens
.
append
(
seq_len
-
context_len
)
else
:
seq_lens
.
append
(
seq_data_len
)
input_tokens
.
append
(
seq_data
.
get_last_token_id
())
query_lens
.
append
(
1
)
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
input_tokens_tensor
,
seq_lens
,
query_lens
vllm/spec_decode/multi_step_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
DelegateWorkerBase
class
MultiStepWorker
(
ProposerWorkerBase
,
DelegateWorkerBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
DelegateWorkerBase
.
__init__
(
self
,
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
def
init_device
(
self
)
->
None
:
self
.
worker
.
init_device
()
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
# Need include_gpu_probs_tensor for MultiStepWorker
self
.
model_runner
.
sampler
.
include_gpu_probs_tensor
=
True
if
hasattr
(
self
.
model_runner
.
model
,
"sampler"
):
(
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
self
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
=
True
if
hasattr
(
self
.
model_runner
.
model
,
"sampler"
):
(
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
self
.
model_runner
.
set_indices_of_seq_with_bonus_tokens
(
indices_of_seq_with_bonus_tokens
)
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
else
:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
if
expanded_request
.
previous_hidden_states
is
not
None
:
self
.
worker
.
model_runner
.
return_hidden_states
=
True
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
self
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_maybe_update_previous_hidden_states
(
model_output
,
expanded_request
)
self
.
_append_new_tokens
(
model_output
,
expanded_request
.
seq_group_metadata_list
,
indices_of_seq_with_bonus_tokens
)
model_outputs
.
append
(
model_output
)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens
=
torch
.
tensor
(
indices_of_seq_with_bonus_tokens
,
device
=
self
.
device
)
filtered_model_outputs
=
self
.
_filter_model_output
(
model_outputs
,
indices_of_seq_with_bonus_tokens
)
return
filtered_model_outputs
,
True
@
staticmethod
def
_maybe_update_previous_hidden_states
(
model_output
:
SamplerOutput
,
expanded_request
:
ExecuteModelRequest
)
->
None
:
"""
Updates the previous hidden states in an expanded request
in-place with the hidden states from the model output.
"""
if
expanded_request
.
previous_hidden_states
is
not
None
:
expanded_request
.
previous_hidden_states
=
HiddenStates
(
model_output
.
hidden_states
,
expanded_request
.
seq_group_metadata_list
)
@
staticmethod
def
_expand_execute_model_request
(
execute_model_req
:
ExecuteModelRequest
,
seq_with_bonus_token_in_last_step
:
set
,
)
->
Tuple
[
ExecuteModelRequest
,
List
[
int
]]:
"""
Expands the execute model request based on sequences with bonus
tokens.
For each sequence with a bonus token, this method creates a new
sequence without the bonus token and adds it to the execute model
request. The original sequence groups are also retained. The indices
of the original sequence groups are returned for further processing.
Args:
execute_model_req (ExecuteModelRequest): The original execute
model request.
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
contain bonus tokens.
Returns:
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
request with expanded sequences and a list of indices corresponding
to the original sequence groups.
"""
updated_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
updated_execute_model_req
=
execute_model_req
.
clone
(
updated_seq_group_metadata_list
)
indices_of_original_sequence_groups
=
[]
for
seq_group
in
execute_model_req
.
seq_group_metadata_list
:
seq_group_has_bonus_tokens
=
False
for
seq_id
,
_
in
seq_group
.
seq_data
.
items
():
# Identify sequences with bonus tokens in the sequence group.
if
seq_id
in
seq_with_bonus_token_in_last_step
:
seq_group_has_bonus_tokens
=
True
break
if
seq_group_has_bonus_tokens
:
#Create new sequences without the last bonus token. These new
# sequence have the same sequence id as the original sequence.
# We create a new sequence group and add them there.
updated_seq_group_without_bonus_token
=
\
MultiStepWorker
.
_copy_seq_metadata_excluding_last_token
(
seq_group
,
seq_with_bonus_token_in_last_step
)
updated_seq_group_metadata_list
.
append
(
updated_seq_group_without_bonus_token
)
# Add the original sequence group.
updated_seq_group_metadata_list
.
append
(
MultiStepWorker
.
_shallow_copy_seq_group_metadata
(
seq_group
))
# Record the index of the original sequence group.
indices_of_original_sequence_groups
.
append
(
len
(
updated_seq_group_metadata_list
)
-
1
)
updated_execute_model_req
.
seq_group_metadata_list
=
\
updated_seq_group_metadata_list
if
isinstance
(
updated_execute_model_req
.
previous_hidden_states
,
HiddenStates
):
updated_execute_model_req
.
previous_hidden_states
\
.
expand_with_bonus_tokens
(
seq_with_bonus_token_in_last_step
)
return
updated_execute_model_req
,
indices_of_original_sequence_groups
@
staticmethod
def
_filter_model_output
(
expanded_batch_outputs
:
List
[
SamplerOutput
],
output_indices_to_retain
:
torch
.
Tensor
)
->
List
[
SamplerOutput
]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
return
[
SamplerOutput
(
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[],
sampled_token_probs
=
(
expanded_batch_output
.
sampled_token_probs
[
output_indices_to_retain
]
if
expanded_batch_output
.
sampled_token_probs
is
not
None
else
None
),
logprobs
=
(
expanded_batch_output
.
logprobs
[
output_indices_to_retain
]
if
expanded_batch_output
.
logprobs
is
not
None
else
None
),
sampled_token_ids
=
(
expanded_batch_output
.
sampled_token_ids
[
output_indices_to_retain
]
if
expanded_batch_output
.
sampled_token_ids
is
not
None
else
None
))
for
expanded_batch_output
in
expanded_batch_outputs
]
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
set
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_spec_proposals
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
@
staticmethod
def
_append_new_tokens
(
model_output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
indices_of_seq_with_bonus_tokens
:
List
[
int
])
->
None
:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
count
=
0
for
index
,
(
seq_group_metadata
,
sequence_group_outputs
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
model_output
)):
seq_group_metadata
.
is_prompt
=
False
for
seq_output
in
sequence_group_outputs
.
samples
:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq
=
seq_group_metadata
.
seq_data
[
seq_output
.
parent_seq_id
]
token_id
=
seq_output
.
output_token
token_logprob
=
seq_output
.
logprobs
[
token_id
]
# Determine the actual token ID to be generated,
# considering bonus tokens
if
index
!=
indices_of_seq_with_bonus_tokens
[
count
]:
bonus_seq_metadata
=
seq_group_metadata_list
[
indices_of_seq_with_bonus_tokens
[
count
]]
_
,
bonus_token_seq_data
=
next
(
iter
(
bonus_seq_metadata
.
seq_data
.
items
()))
token_id
=
bonus_token_seq_data
.
output_token_ids
[
-
1
]
else
:
count
+=
1
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
,
seq_output
.
output_embed
)
seq
.
update_num_computed_tokens
(
1
)
@
staticmethod
def
_shallow_copy_seq_group_metadata
(
seq_group_metadata
:
SequenceGroupMetadata
,
)
->
SequenceGroupMetadata
:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
# We must shallow-copy seq_group_metadata as is_prompt could change.
new_seq_group_metadata
=
copy
.
copy
(
seq_group_metadata
)
# We must shallow-copy seq_data as we will append token ids
new_seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
new_seq_data
[
seq_id
].
output_token_ids
=
\
old_seq_data
.
output_token_ids
[:]
new_seq_group_metadata
.
seq_data
=
new_seq_data
return
new_seq_group_metadata
@
staticmethod
def
_copy_seq_metadata_excluding_last_token
(
seq_group_metadata
:
SequenceGroupMetadata
,
seq_ids_to_copy
:
Set
[
int
],
)
->
SequenceGroupMetadata
:
"""
Creates a shallow copy of the given SequenceGroupMetadata, retaining
only the sequence IDs specified in seq_ids_to_copy. For each of these
sequence IDs, all output_token_ids except the last one are copied.
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
Parameters:
seq_group_metadata (SequenceGroupMetadata): The original sequence
group metadata.
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
copy.
Returns:
SequenceGroupMetadata: A shallow copy of the sequence group metadata
with the specified modifications.
"""
# Shallow-copy the SequenceGroupMetadata.
new_seq_group_metadata
=
copy
.
copy
(
seq_group_metadata
)
# Shallow-copy seq_data and modify the output_token_ids.
new_seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
if
(
seq_id
in
seq_ids_to_copy
):
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
# Copy all the output token ids except the last.
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.
# NOTE: num_computed_tokens is not directly used by the
# speculative decoding workers, as it is only relevant for
# chunked prefill, which is disabled for speculative decoding.
# However, to maintain consistency in num_computed_tokens,
# we update it here.
new_seq_data
[
seq_id
].
output_token_ids
=
\
old_seq_data
.
output_token_ids
[:
-
1
]
new_seq_data
[
seq_id
].
update_num_computed_tokens
(
-
1
)
new_seq_group_metadata
.
seq_data
=
new_seq_data
return
new_seq_group_metadata
def
_assert_enough_kv_space
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
num_steps
:
int
)
->
None
:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert
self
.
model_runner
.
block_size
is
not
None
for
seq_group_metadata
in
seq_group_metadata_list
:
# Only one seq_id is guaranteed because there is no beam search.
seq_id
=
list
(
seq_group_metadata
.
seq_data
.
keys
())[
0
]
seq
=
seq_group_metadata
.
seq_data
[
seq_id
]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len
=
seq
.
get_len
()
+
num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots
=
final_seq_len
-
1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks
=
len
(
seq_group_metadata
.
block_tables
[
seq_id
])
allocated_kv_slots
=
(
number_physical_blocks
*
self
.
model_runner
.
block_size
)
if
required_num_kv_slots
>
allocated_kv_slots
:
request_id
=
seq_group_metadata
.
request_id
raise
ValueError
(
"The worker attempted to run "
f
"
{
num_steps
}
times but found insufficient KV space for "
f
"
{
request_id
=
}
{
seq_id
=
}
. (
{
allocated_kv_slots
=
}
"
f
"
{
required_num_kv_slots
=
}
)."
)
def
_raise_if_unsupported
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if
execute_model_req
is
None
:
return
None
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MultiStepWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
def
maybe_load_lm_head_weight
(
self
,
lm_head_weight
:
torch
.
Tensor
,
)
->
None
:
weight_loader
=
getattr
(
self
.
worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
weight
,
lm_head_weight
)
vllm/spec_decode/spec_decode_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
copy
from
collections
import
defaultdict
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.distributed.communication_op
import
(
broadcast_tensor_dict
,
get_tp_group
,
tensor_model_parallel_gather
)
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
,
Logits
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTreeStyleScorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
Timer
,
create_logprobs_output
,
create_sequence_group_output
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
def
create_spec_worker
(
*
args
,
**
kwargs
)
->
"SpecDecodeWorker"
:
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
vllm_config
:
VllmConfig
=
kwargs
.
get
(
"vllm_config"
)
speculative_config
:
SpeculativeConfig
=
vllm_config
.
speculative_config
assert
speculative_config
is
not
None
if
vllm_config
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
"Speculative decoding is currently "
"incompatible with pipeline parallelism"
)
draft_worker_kwargs
=
kwargs
.
copy
()
kwargs
[
"model_runner_cls"
]
=
TargetModelRunner
target_worker_config
=
copy
.
deepcopy
(
vllm_config
)
target_worker_config
.
parallel_config
.
worker_cls
=
\
target_worker_config
.
parallel_config
.
sd_worker_cls
cls
=
resolve_obj_by_qualname
(
target_worker_config
.
parallel_config
.
worker_cls
)
target_worker
=
cls
(
*
args
,
**
kwargs
)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker
.
model_runner
.
disable_logprobs
=
\
speculative_config
.
disable_logprobs
draft_worker_config
=
copy
.
deepcopy
(
vllm_config
)
draft_worker_config
.
model_config
=
speculative_config
.
draft_model_config
# draft_worker_config.quant_config = VllmConfig._get_quantization_config(
# draft_worker_config.model_config,
# vllm_config.load_config,
# )
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
# TODO allow draft-model specific load config.
# Override draft-model specific worker args.
draft_worker_kwargs
.
update
(
vllm_config
=
draft_worker_config
,
ngram_prompt_lookup_max
=
speculative_config
.
prompt_lookup_max
,
ngram_prompt_lookup_min
=
speculative_config
.
prompt_lookup_min
,
)
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_mqa_scorer
=
speculative_config
.
disable_mqa_scorer
,
disable_by_batch_size
=
speculative_config
.
disable_by_batch_size
,
draft_token_acceptance_method
=
speculative_config
.
acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
speculative_config
.
posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
num_speculative_tokens
=
speculative_config
.
num_speculative_tokens
,
)
return
spec_decode_worker
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
class
SpecDecodeWorker
(
LoRANotSupportedWorkerBase
):
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
method, such as a small draft model, to speculate ahead of a larger LLM. The
probabilities of the speculative tokens are then determined by the larger
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
* Only draft-model proposal is implemented (contributions for more forms are
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
suboptimal especially as the batch size, proposal length, and sequence
lengths grow. Contributions to add a MQA scoring are welcome once
correctness tests pass.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
@
classmethod
def
create_worker
(
cls
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_mqa_scorer
:
bool
,
disable_by_batch_size
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
num_speculative_tokens
:
int
,
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
enable_lm_head_weight_load
=
False
num_spec_prefill_steps
=
1
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
draft_model_config
=
draft_worker_kwargs
[
"vllm_config"
].
model_config
draft_parallel_config
:
ParallelConfig
=
draft_worker_kwargs
[
'vllm_config'
].
parallel_config
if
ngram_prompt_lookup_max
>
0
:
assert
draft_parallel_config
.
tensor_parallel_size
==
1
draft_worker_kwargs
[
"device_type"
]
=
scorer_worker
.
device_config
.
device
.
type
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
else
:
draft_tp
=
draft_parallel_config
.
tensor_parallel_size
target_tp
=
scorer_worker
.
parallel_config
.
tensor_parallel_size
if
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
:
proposer_worker
=
MLPSpeculatorWorker
(
**
draft_worker_kwargs
)
elif
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
else
:
if
draft_tp
==
1
:
if
current_platform
.
is_cuda_alike
():
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
raise
NotImplementedError
(
f
"
{
draft_model_config
.
hf_config
.
model_type
}
"
"does not support TP > 1 yet"
)
allow_zero_draft_token_step
=
False
# Load lm_head weight for eagle in init_device
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
enable_lm_head_weight_load
=
True
if
envs
.
VLLM_ZERO_OVERHEAD
:
assert
False
,
(
"speculative decoding not support zero overhead scheduler yet"
)
from
vllm.zero_overhead.spec_decode.muti_step_worker
import
ZeroOverheadMultiStepWorker
proposer_worker
=
ZeroOverheadMultiStepWorker
(
**
draft_worker_kwargs
)
else
:
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
num_spec_prefill_steps
=
\
draft_model_config
.
hf_config
.
n_predict
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
,
draft_tp
,
target_tp
)
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
if
draft_token_acceptance_method
==
"rejection_sampler"
:
spec_decode_sampler
=
RejectionSampler
()
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
spec_decode_sampler
=
TypicalAcceptanceSampler
(
posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
)
logger
.
info
(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
if
not
disable_mqa_scorer
:
if
scorer_worker
.
model_runner
.
attn_backend
.
get_name
(
)
!=
"FLASH_ATTN"
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend."
)
if
draft_model_config
and
\
draft_model_config
.
max_model_len
<
\
scorer_worker
.
model_config
.
max_model_len
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len."
)
if
not
scorer_worker
.
model_runner
.
model_config
.
enforce_eager
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode."
)
if
envs
.
VLLM_ZERO_OVERHEAD
:
from
vllm.zero_overhead.spec_decode.spec_decode_worker
import
ZeroOverheadSpecDecodeWorker
return
ZeroOverheadSpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
else
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
def
__init__
(
self
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
disable_mqa_scorer
:
bool
=
False
,
disable_logprobs
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
enable_lm_head_weight_load
:
Optional
[
bool
]
=
False
,
num_spec_prefill_steps
:
int
=
1
,
):
"""
Create a SpecDecodeWorker.
Args:
proposer_worker: A worker that can produce speculative tokens for
sequences.
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
spec_decode_sampler: A Torch module used to perform acceptance
sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
num_spec_prefill_steps: number of speculative prefill steps to run
before the speculative decoding starts. This is only used when
the draft model is a deepseek_mtp model that requires prefill
kv cache separately for each MTP layer.
"""
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
scorer_runner
=
getattr
(
self
.
scorer_worker
,
"model_runner"
,
None
)
self
.
generators
=
scorer_runner
.
get_generators
(
)
if
scorer_runner
else
None
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
self
.
_enable_lm_head_weight_load
=
enable_lm_head_weight_load
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
spec_decode_sampler
)
if
metrics_collector
is
None
else
metrics_collector
# Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being
# used for token generation such as in the case of MultiStepWorker.
self
.
_seq_with_bonus_token_in_last_step
:
Set
[
int
]
=
set
()
# Tracks the currently active request ids and the sequence IDs
# corresponding to them
self
.
_request_id_seq_id_mapping
:
Dict
[
str
,
Set
[
int
]]
=
defaultdict
(
set
)
# Tracks if the proposer worker uses the KV cache or not.
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initialization.
self
.
scorer
:
BatchExpansionTop1Scorer
self
.
disable_mqa_scorer
=
disable_mqa_scorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
previous_logits
:
Optional
[
Logits
]
=
None
self
.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
_num_spec_prefill_steps
=
num_spec_prefill_steps
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self
.
scorer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
if
self
.
_enable_lm_head_weight_load
:
# NOTE(Shangming): gather lm_head weight when tp enabled
target_lm_head_weight
:
torch
.
Tensor
=
tensor_model_parallel_gather
(
self
.
scorer_worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
\
weight
.
data
,
dim
=
0
,
)
self
.
proposer_worker
.
maybe_load_lm_head_weight
(
target_lm_head_weight
)
self
.
_metrics
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
if
model_parallel_is_initialized
():
self
.
spec_decode_sampler
.
init_tensors
(
get_tp_group
().
local_rank
,
device_type
=
self
.
device
)
else
:
self
.
spec_decode_sampler
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
scorer_cls
:
Type
[
SpeculativeScorer
]
if
self
.
disable_mqa_scorer
:
scorer_cls
=
BatchExpansionTop1Scorer
logger
.
info
(
"[Speculative Decoding] Use batch "
"expansion for scoring proposals."
)
else
:
scorer_cls
=
MQAScorer
logger
.
info
(
"[Speculative Decoding] Use MQA scorer for scoring proposals."
)
if
not
self
.
tree_decoding
:
self
.
scorer
=
scorer_cls
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
else
:
self
.
scorer
=
BatchExpansionTreeStyleScorer
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
self
.
_configure_model_sampler_for_spec_decode
()
def
load_model
(
self
,
*
args
,
**
kwargs
):
pass
def
_configure_model_sampler_for_spec_decode
(
self
):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of sampling during verification.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(
self
.
scorer_worker
.
model_runner
.
sampler
.
include_gpu_probs_tensor
)
=
True
# tree_style decoding modify probs in _verify_tokens
if
not
self
.
tree_decoding
:
(
self
.
scorer_worker
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
larger of the two). Then the total memory which would be used by the
scorer cache is divided evenly between the proposer and scorer model KV,
such that the number of blocks is equal in both KV caches.
"""
num_gpu_blocks
,
num_cpu_blocks
=
(
self
.
scorer_worker
.
determine_num_available_blocks
())
scorer_cache_block_size_bytes
=
(
self
.
scorer_worker
.
get_cache_block_size_bytes
())
proposer_cache_block_size_bytes
=
(
self
.
proposer_worker
.
get_cache_block_size_bytes
())
new_num_gpu_blocks
=
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
,
proposer_cache_block_size_bytes
,
num_gpu_blocks
)
return
new_num_gpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the cache engine of the scorer and proposer workers.
"""
self
.
scorer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
self
.
proposer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
scorer_worker
.
get_model
()
@
torch
.
inference_mode
()
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""
if
self
.
rank
!=
self
.
_driver_rank
:
self
.
_run_non_driver_rank
()
return
[]
if
execute_model_req
is
None
:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict
({},
src
=
0
)
return
[]
self
.
_track_finished_requests
(
execute_model_req
)
disable_all_speculation
=
self
.
_should_disable_all_speculation
(
execute_model_req
)
num_lookahead_slots
=
execute_model_req
.
num_lookahead_slots
all_prompt
=
True
atleast_one_prompt
=
False
all_zero_spec_tokens
=
True
for
sgm
in
execute_model_req
.
seq_group_metadata_list
:
all_prompt
=
all_prompt
and
sgm
.
is_prompt
atleast_one_prompt
=
atleast_one_prompt
or
sgm
.
is_prompt
all_zero_spec_tokens
=
all_zero_spec_tokens
and
(
sgm
.
num_speculative_tokens
==
0
)
if
all_prompt
and
execute_model_req
.
seq_group_metadata_list
:
assert
num_lookahead_slots
==
0
,
(
"Prompt only runs should have num_lookahead_slots equal to 0. "
"This should never happen, please file a bug at "
"https://github.com/vllm-project/vllm/issues"
)
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec
=
(
num_lookahead_slots
==
0
or
disable_all_speculation
or
all_zero_spec_tokens
)
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform them.
# no_spec is used to signal non-driver worker about prefill vs decode
# stage. This is needed to ensure that order of execution of proposer
# and scorer is same in both driver and non-driver workers (i.e.,
# scorer -> proposer for prefill and proposer -> scorer in decode). This
# order is needed to support models like EAGLE that take scorer states
# as inputs.
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
no_spec
=
no_spec
,
disable_all_speculation
=
disable_all_speculation
,
# When both chunked prefill and speculative decoding are enabled
# it is possible that the same batch contains both prefill
# and decodes. If that happens in the scorer we run the batch
# as one single forward pass. However, in the proposer we
# run them as 2 different batches - one for prefill and
# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill
=
atleast_one_prompt
,
)
broadcast_tensor_dict
(
broadcast_dict
,
src
=
self
.
_driver_rank
)
assert
execute_model_req
.
seq_group_metadata_list
is
not
None
,
(
"speculative decoding requires non-None seq_group_metadata_list"
)
self
.
_maybe_disable_speculative_tokens
(
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
if
no_spec
:
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all_speculation
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
num_lookahead_slots
)
@
torch
.
inference_mode
()
def
start_worker_execution_loop
(
self
)
->
None
:
"""Execute model loop to perform speculative decoding
in parallel worker."""
while
self
.
_run_non_driver_rank
():
pass
def
_should_disable_all_speculation
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
bool
:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
return
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
def
_maybe_disable_speculative_tokens
(
self
,
disable_all_speculation
:
bool
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
if
not
disable_all_speculation
:
return
for
seq_group_metadata
in
seq_group_metadata_list
:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata
.
num_speculative_tokens
=
0
def
_serialize_sampler_output_no_logprobs
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sampler_output
:
SamplerOutput
)
->
List
[
SamplerOutput
]:
"""
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only token IDs
populated.
"""
seq_output_prompt_logprobs
=
[
seq
.
is_prompt
and
seq
.
sampling_params
.
prompt_logprobs
is
not
None
and
seq
.
sampling_params
.
prompt_logprobs
>
0
for
seq
in
execute_model_req
.
seq_group_metadata_list
]
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
sampled_token_ids_list
=
(
sampler_output
.
sampled_token_ids
[
torch
.
where
(
# subtracting is faster than testing for equality
sampler_output
.
sampled_token_ids
-
VLLM_INVALID_TOKEN_ID
)[
0
]]
\
if
any
(
seq_output_prompt_logprobs
)
else
\
sampler_output
.
sampled_token_ids
).
tolist
()
seq_data_entries
=
[
(
seq_id
,
seq_data
)
for
sg
in
\
execute_model_req
.
seq_group_metadata_list
\
for
seq_id
,
seq_data
in
sg
.
seq_data
.
items
()
]
completion_seq_group_output_list
:
List
[
CompletionSequenceGroupOutput
]
=
[]
output_index
=
0
# Make sure the non-terminal prefill chunks are still aligned with
# their own empty output.
for
idx
,
seq_group_meta
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
needs_prompt_logprobs
=
seq_output_prompt_logprobs
[
idx
]
seq_id
,
seq_data
=
seq_data_entries
[
idx
]
if
needs_prompt_logprobs
:
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
# Some of these sequences may belong to non-terminal chunks,
# which may still have to report logprobs for prompts.
start
=
1
if
seq_data
.
_num_computed_tokens
==
0
\
else
seq_data
.
_num_computed_tokens
end
=
(
seq_data
.
_num_computed_tokens
+
\
seq_group_meta
.
token_chunk_size
)
prompt_token_ids
=
prompt_token_ids
[
start
:
end
]
prompt_logprobs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
for
p_token_id
in
prompt_token_ids
]
else
:
prompt_logprobs
=
None
# Since we can get chunks here, we dont always have a sampled token
# (only on last chunk) but we still have to provide an output.
if
not
seq_group_meta
.
do_sample
:
completion_seq_group_output_list
.
append
(
CompletionSequenceGroupOutput
(
samples
=
[],
prompt_logprobs
=
prompt_logprobs
))
continue
# Sequence with output.
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
output_index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
prompt_logprobs
=
prompt_logprobs
))
output_index
+=
1
return
[
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)]
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
if
self
.
tree_decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
self
.
kvcache_slot_to_be_moved
=
None
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Store hidden states from target model execution, BxD.
hidden_states
=
sampler_output
.
hidden_states
if
hidden_states
is
not
None
:
# Only decodes and prefill terminal chunks need a hidden state.
seq_group_meta_with_hidden
=
[
sg
for
sg
in
execute_model_req
.
seq_group_metadata_list
if
sg
.
do_sample
]
if
any
(
seq
.
is_prompt
for
seq
in
seq_group_meta_with_hidden
):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states
=
hidden_states
[
torch
.
where
(
sampler_output
.
sampled_token_ids
-
VLLM_INVALID_TOKEN_ID
)[
0
]]
# if not skip_proposer:
# if self.previous_hidden_states is None and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states = HiddenStates(
# hidden_states, seq_group_meta_with_hidden)
# elif self.previous_hidden_states and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states.update(hidden_states,
# seq_group_meta_with_hidden)
if
self
.
previous_hidden_states
is
None
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_meta_with_hidden
)
elif
self
.
previous_hidden_states
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
.
update
(
hidden_states
,
seq_group_meta_with_hidden
)
#self.previous_hidden_states.prune(seq_group_meta_with_hidden)
# Store logits from target model execution.
if
self
.
tree_decoding
:
logits
=
sampler_output
.
logits
if
logits
is
not
None
:
if
self
.
previous_logits
is
None
:
self
.
previous_logits
=
Logits
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
else
:
self
.
previous_logits
.
update
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
if
not
skip_proposer
:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
sampler_output
.
prefill_hidden_states
)
for
i
in
range
(
self
.
_num_spec_prefill_steps
):
execute_model_req
.
spec_step_idx
=
i
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
[
sampler_output
])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
logprobs
=
None
return
sampler_output_to_return
def
_run_non_driver_rank
(
self
)
->
bool
:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
Returns True if there are remaining sequences to process.
"""
assert
self
.
rank
!=
self
.
_driver_rank
data
=
broadcast_tensor_dict
(
src
=
self
.
_driver_rank
)
if
not
data
:
return
False
num_lookahead_slots
=
data
[
"num_lookahead_slots"
]
# In case of prefill, scorer_worker has to be run before proposer so
# that the hidden states can be propagated to proposer when needed.
if
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
# if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
# # We run the proposer once per lookahead slot. In the future we
# # should delegate how many times it runs to the proposer.
# for _ in range(max(num_lookahead_slots, 1)):
# self.proposer_worker.execute_model()
# else:
# if not data["no_spec"]:
# self.proposer_worker.sampler_output(None, None, None)
if
issubclass
(
type
(
self
.
proposer_worker
),
NonLLMProposerWorkerBase
):
if
not
data
[
"no_spec"
]:
self
.
proposer_worker
.
sampler_output
(
None
,
num_lookahead_slots
,
None
)
else
:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we
# should delegate how many times it runs to the proposer.
for
_
in
range
(
max
(
num_lookahead_slots
,
1
)):
self
.
proposer_worker
.
execute_model
()
if
not
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
if
data
[
"run_spec_proposer_for_prefill"
]:
self
.
proposer_worker
.
execute_model
()
return
True
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
def
_run_speculative_decoding_step
(
self
,
execute_model_req
:
ExecuteModelRequest
,
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert
num_lookahead_slots
==
execute_model_req
.
num_lookahead_slots
# Pass last hidden states from target model to proposer
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
self
.
previous_hidden_states
=
None
# Pass last logits from target model to proposer
execute_model_req
.
previous_logits
=
self
.
previous_logits
self
.
previous_logits
=
None
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
self
.
kvcache_slot_to_be_moved
=
None
with
Timer
()
as
proposal_timer
:
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
#TODO: Fix it #5814
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
# Pass tree attention mask and postions to target model
if
self
.
tree_decoding
:
execute_model_req
.
tree_attn_masks
=
proposals
.
tree_attn_masks
execute_model_req
.
tree_position_ids
=
proposals
.
tree_position_ids
execute_model_req
.
previous_hidden_states
=
None
with
Timer
()
as
scoring_timer
:
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
proposals
,
)
_
,
(
non_spec_seqs
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
execute_model_req
.
seq_group_metadata_list
,
proposals
.
proposal_lens
)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
execute_model_req
.
seq_group_metadata_list
[
idx
].
is_prompt
]
if
len
(
non_spec_indices
):
all_hidden_states
=
proposal_scores
.
hidden_states
if
all_hidden_states
is
not
None
:
prefill_hidden_states
=
all_hidden_states
[
non_spec_indices
]
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
prefill_hidden_states
)
# Sync proposer KV cache for prefills.
prefill_req
=
execute_model_req
.
clone
(
non_spec_seqs
)
# TODO avoid sampling here?
self
.
proposer_worker
.
execute_model
(
prefill_req
)
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
,
select_indices_list
,
accept_lengths
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
# move kv_caches of selected tokens to right positions
if
self
.
tree_decoding
:
self
.
move_caches
(
execute_model_req
,
select_indices_list
,
accept_lengths
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
scoring_timer
.
elapsed_time_ms
,
verification_timer
.
elapsed_time_ms
)
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
prompt_logprobs
=
proposal_scores
.
prompt_logprobs
if
not
self
.
_disable_logprobs
else
None
,
k
=
execute_model_req
.
num_lookahead_slots
,
stage_times
=
stage_times
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
List
[
int
]],
List
[
int
]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(
_
,
spec_indices
),
(
_
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, including bonus tokens.
if
non_spec_indices
:
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
if
self
.
tree_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[:,
-
1
:]
if
non_spec_indices
:
bonus_token_ids
=
bonus_token_ids
[
spec_indices
,
:]
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
if
proposals
.
proposal_probs
is
not
None
else
None
if
proposal_probs
is
not
None
and
non_spec_indices
:
proposal_probs
=
proposal_probs
[
spec_indices
]
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
if
non_spec_indices
:
proposal_token_ids
=
proposal_token_ids
[
spec_indices
]
# Get tree buffers.
cart_candidates
=
proposals
.
cart_candidates
if
proposals
.
cart_candidates
is
not
None
else
None
if
cart_candidates
is
not
None
and
non_spec_indices
:
cart_candidates
=
cart_candidates
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
sampler_extra_kwargs
[
"seeded_seqs"
]
=
{
idx
:
self
.
generators
[
sgm
.
request_id
]
for
idx
,
sgm
in
enumerate
(
seq_group_metadata_list
)
if
sgm
.
sampling_params
.
seed
is
not
None
}
if
isinstance
(
self
.
spec_decode_sampler
,
TypicalAcceptanceSampler
):
sampler_extra_kwargs
[
"cart_candidates"
]
=
cart_candidates
sampler_extra_kwargs
[
"best_candidates"
]
=
[]
sampler_extra_kwargs
[
"accept_lengths"
]
=
[]
first_step_flags
=
[]
for
i
,
sgm
in
enumerate
(
seq_group_metadata_list
):
seq
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
first_step_flags
.
append
(
True
if
seq
.
get_first_step_flag
()
else
False
)
sampler_extra_kwargs
[
"first_step_flags"
]
=
first_step_flags
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_with_bonus_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
draft_token_ids
=
proposal_token_ids
,
**
sampler_extra_kwargs
,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
if
not
self
.
tree_decoding
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
else
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
).
clone
()
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
# B x K+1 x D
hidden_states
=
proposal_scores
.
hidden_states
select_indices
=
None
accept_lengths
=
None
select_indices_list
=
[]
if
cart_candidates
is
None
:
if
hidden_states
is
not
None
:
# Only get terminal hidden states for next step
terminal_metadata
=
[
sg
for
sg
in
seq_group_metadata_list
if
sg
.
do_sample
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
# b
# Drop non-terminal prefill chunks hidden states.
hidden_states
=
hidden_states
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
accepted_index
=
accepted_index
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
assert
len
(
accepted_index
)
==
hidden_states
.
shape
[
0
]
==
len
(
terminal_metadata
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
# b x 1 x d
second_last_token_hidden_states
=
hidden_states
[:,
-
2
]
# b x d
hidden_states
=
hidden_states
.
gather
(
1
,
index
).
squeeze
(
1
)
# b x d
# Store hidden states from target model for subsequent decode step
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
terminal_metadata
,
second_last_token_hidden_states
)
else
:
retrieve_indices
=
proposals
.
retrieve_indices
batch_size
=
len
(
seq_group_metadata_list
)
best_candidates
=
sampler_extra_kwargs
[
"best_candidates"
]
accept_lengths
=
sampler_extra_kwargs
[
"accept_lengths"
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
batch_size
,
-
1
,
hs_size
)
# Store logits from target model for subsequent proposal
logits
=
proposal_scores
.
logits
logits
=
logits
.
view
(
batch_size
,
-
1
,
logits
.
shape
[
-
1
])
logits
=
logits
[:,
retrieve_indices
]
# [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list
=
[]
previous_hidden_state_list
=
[]
retrieve_indices
=
retrieve_indices
.
cpu
()
for
i
in
range
(
batch_size
):
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
previous_logits_list
.
append
(
logit
)
select_indices
=
retrieve_indices
[
best_candidates
[
i
],
:
accept_lengths
[
i
]
+
1
]
hidden_state
=
hidden_states
[
i
,
select_indices
[
-
1
]].
unsqueeze
(
0
)
select_indices_list
.
append
(
select_indices
)
previous_hidden_state_list
.
append
(
hidden_state
)
logits
=
torch
.
cat
(
previous_logits_list
,
dim
=
0
)
self
.
previous_logits
=
Logits
(
logits
,
seq_group_metadata_list
)
hidden_states
=
torch
.
cat
(
previous_hidden_state_list
,
dim
=
0
)
# [batch_size, 1, vocab_size]
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,)
return
accepted_token_ids
,
logprobs
,
select_indices_list
,
accept_lengths
def
move_caches
(
self
,
execute_model_req
:
ExecuteModelRequest
,
select_indices_list
:
List
[
torch
.
Tensor
],
accept_lengths
:
List
[
int
]):
"""Given selected output tokens and accept length,
move kv_caches of selected tokens to right positions.
"""
seq_lens
=
[]
for
sg
in
execute_model_req
.
seq_group_metadata_list
:
seq_ids
=
list
(
sg
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
sg
.
seq_data
[
seq_id
]
seq_len
=
seq_data
.
get_len
()
token_chunk_size
=
sg
.
token_chunk_size
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# first step of tree-style decoding need to ignore first generated token
if
seq_data
.
get_first_step_flag
():
seq_len
-=
1
# move cache is the last step of tree decoding, so set first_step_flag to false
seq_data
.
set_first_step_flag
(
False
)
seq_lens
.
append
(
seq_len
)
model_input
=
self
.
scorer
.
_scorer_worker
.
model_input
block_tables
=
None
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables_list'
):
block_tables
=
model_input
.
attn_metadata
.
block_tables_list
if
block_tables
is
None
:
raise
RuntimeError
(
"Can not get block_tables from model_input."
)
cache_engine
=
self
.
scorer
.
_scorer_worker
.
cache_engines
[
execute_model_req
.
virtual_engine
]
block_size
=
cache_engine
.
block_size
batch_size
=
len
(
select_indices_list
)
block_table_stride
=
len
(
block_tables
)
//
batch_size
select_indices_slot_mapping
=
[]
target_slot_mapping
=
[]
for
i
in
range
(
batch_size
):
accept_legth
=
accept_lengths
[
i
]
if
accept_legth
>
0
:
select_indices
=
select_indices_list
[
i
][
1
:]
+
seq_lens
[
i
]
select_indices
=
select_indices
.
tolist
()
self
.
compute_slot_mapping
(
select_indices_slot_mapping
,
i
*
block_table_stride
,
select_indices
,
block_size
,
block_tables
)
target_indices
=
torch
.
arange
(
accept_legth
+
1
)[
1
:]
+
seq_lens
[
i
]
target_indices
=
target_indices
.
tolist
()
self
.
compute_slot_mapping
(
target_slot_mapping
,
i
*
block_table_stride
,
target_indices
,
block_size
,
block_tables
)
if
len
(
select_indices_slot_mapping
)
>
0
:
select_indices_slot_tensor
=
torch
.
tensor
(
select_indices_slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
target_slot_mapping_tensor
=
torch
.
tensor
(
target_slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
src_dst_tensor
=
torch
.
cat
([
select_indices_slot_tensor
,
target_slot_mapping_tensor
],
dim
=-
1
)
#[batch_size*T, 2]
self
.
kvcache_slot_to_be_moved
=
src_dst_tensor
def
_create_output_sampler_list
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
prompt_logprobs
:
Optional
[
torch
.
Tensor
],
# shape: [nprompt_tokens, vocab_size]
k
:
int
,
stage_times
:
Tuple
[
float
,
float
,
float
],
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size
,
num_steps
=
accepted_token_ids
.
shape
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
if
self
.
_disable_logprobs
:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_dummy_logprob_lists
(
batch_size
,
num_steps
,
self
.
scorer_worker
.
model_config
.
max_logprobs
)
else
:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
# Serialize all tensors into Python lists.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_logprob_lists_from_tensors
(
target_logprobs_by_step
,
accepted_token_ids_by_step
,
self
.
scorer_worker
.
model_config
.
max_logprobs
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids
,
request_ids_seq_ids_mapping
=
get_all_seq_ids_and_request_ids
(
seq_group_metadata_list
)
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
# Prefills are not multi-step (return at most 1 token), in order to
# avoid padding or repetition to fit decodes, we separate them.
for
i
,
sg
in
enumerate
(
seq_group_metadata_list
):
if
not
sg
.
is_prompt
:
# Requests are ordered as prefills|decodes=>no more prefills.
break
num_logprobs
=
num_logprobs_per_seq
[
i
]
seq_kwargs
=
dict
(
token_id
=-
1
,
token_id_logprob_rank
=
0
,
token_id_logprob
=-
float
(
'inf'
),
topk_token_ids
=
[
-
1
]
*
num_logprobs
,
topk_logprobs
=
[
-
float
(
'inf'
)]
*
num_logprobs
,
seq_id
=
seq_ids
[
i
])
# Terminal chunk, has token.
if
sg
.
do_sample
:
seq_kwargs
.
update
(
dict
(
token_id
=
accepted_token_ids
[
i
][
0
].
item
(),
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
0
][
i
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
0
]
[
i
],
topk_token_ids
=
topk_indices_by_step
[
0
][
i
]
[:
num_logprobs
],
# output only so step is 0
topk_logprobs
=
topk_logprobs_by_step
[
0
][
i
]
[:
num_logprobs
],
))
needs_plogs
=
(
sg
.
sampling_params
.
prompt_logprobs
and
sg
.
sampling_params
.
prompt_logprobs
>
0
)
plogs
=
None
if
prompt_logprobs
is
not
None
:
# Even non-terminal prompt chunks can have logprobs here.
plogs
=
prompt_logprobs
[
i
]
elif
needs_plogs
:
# Prompt logprobs are requested but `_disable_logprobs` is set.
seq_data
=
next
(
iter
(
sg
.
seq_data
.
values
()))
# Get only the tokens in this chunk!
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_token_ids
=
prompt_token_ids
[
seq_data
.
_num_computed_tokens
:
seq_data
.
_num_computed_tokens
+
sg
.
token_chunk_size
]
is_first_chunk
=
seq_data
.
_num_computed_tokens
==
0
# There's no prob generated for the first token in a sequence.
if
is_first_chunk
:
prompt_token_ids
=
prompt_token_ids
[
1
:]
plogs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
for
p_token_id
in
prompt_token_ids
]
seq_kwargs
.
update
(
dict
(
prompt_logprobs
=
plogs
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
[
create_sequence_group_output
(
**
seq_kwargs
)]))
# type: ignore
# Decodes, create one SamplerOutput per-step (at most K+1).
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
for
sg
,
token_id
in
zip
(
seq_group_metadata_list
,
accepted_token_ids_by_step
[
step_index
])
if
not
sg
.
is_prompt
):
break
step_output_token_ids
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
sequence_index
in
range
(
batch_size
):
seq_meta
=
seq_group_metadata_list
[
sequence_index
]
# Prompts already processed above.
if
seq_meta
.
is_prompt
:
continue
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
create_sequence_group_output
(
token_id
=
accepted_token_ids_by_step
[
step_index
]
[
sequence_index
],
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
step_index
][
sequence_index
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
step_index
][
sequence_index
],
seq_id
=
seq_ids
[
sequence_index
],
topk_token_ids
=
topk_indices_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
topk_logprobs
=
topk_logprobs_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
step_index
=
step_index
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
step_output_token_ids
))
# Populate the data structures needed to keep track of sequences with
# bonus tokens.
self
.
_track_sequences_with_bonus_tokens
(
seq_ids
,
request_ids_seq_ids_mapping
,
accepted_token_ids_by_step
)
maybe_rejsample_metrics
=
(
self
.
_metrics
.
maybe_collect_rejsample_metrics
(
k
))
if
maybe_rejsample_metrics
is
not
None
and
sampler_output_list
:
sampler_output_list
[
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self
.
_maybe_log_stage_times
(
*
stage_times
)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
scoring_time_ms
:
float
,
verification_time_ms
:
float
)
->
None
:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if
self
.
_disable_log_stats
:
return
logger
.
info
(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f"
,
average_time_per_proposal_tok_ms
,
scoring_time_ms
,
verification_time_ms
)
def
_create_dummy_logprob_lists
(
self
,
batch_size
:
int
,
num_steps
:
int
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step
=
[[
-
1
]
*
batch_size
for
_
in
range
(
num_steps
)]
accepted_token_id_logprobs_by_step
=
[[
0.0
]
*
batch_size
for
_
in
range
(
num_steps
)]
topk_logprobs_by_step
:
List
[
List
[
List
[
Optional
[
float
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
topk_indices_by_step
:
List
[
List
[
List
[
Optional
[
int
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_create_logprob_lists_from_tensors
(
self
,
target_logprobs_by_step
:
torch
.
Tensor
,
accepted_token_ids_by_step
:
torch
.
Tensor
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step_tensor
,
accepted_token_id_logprobs_by_step_tensor
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(
topk_logprobs_by_step_tensor
,
topk_indices_by_step_tensor
)
=
target_logprobs_by_step
.
topk
(
k
=
num_top_k
,
dim
=-
1
,
)
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step_tensor
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step_tensor
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step_tensor
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step_tensor
.
tolist
()
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_track_finished_requests
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""
Removes the finished requests and their associated sequence ids from
internal book keeping data structures.
"""
for
finished_request
in
execute_model_req
.
finished_requests_ids
:
for
seq_id
in
self
.
_request_id_seq_id_mapping
[
finished_request
]:
self
.
_seq_with_bonus_token_in_last_step
.
discard
(
seq_id
)
del
self
.
_request_id_seq_id_mapping
[
finished_request
]
def
_track_sequences_with_bonus_tokens
(
self
,
seq_ids
:
List
[
int
],
request_ids_seq_ids_mapping
:
Dict
[
str
,
Set
[
int
]],
accepted_token_ids_by_step
:
List
[
List
[
int
]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for
seq_index
,
seq_id
in
enumerate
(
seq_ids
):
last_token_id
=
accepted_token_ids_by_step
[
-
1
][
seq_index
]
if
last_token_id
==
-
1
:
self
.
_seq_with_bonus_token_in_last_step
.
discard
(
seq_id
)
else
:
self
.
_seq_with_bonus_token_in_last_step
.
add
(
seq_id
)
for
request_id
,
sequences
in
request_ids_seq_ids_mapping
.
items
():
self
.
_request_id_seq_id_mapping
[
request_id
].
update
(
sequences
)
@
cached_property
def
_vocab_size
(
self
)
->
int
:
"""Get the vocab size of the model and make sure it's consistent between
draft and target workers.
"""
vocab_sizes
=
[
worker
.
vocab_size
for
worker
in
[
self
.
proposer_worker
,
self
.
scorer_worker
]
]
assert
all
(
vocab_sizes
[
0
]
==
vocab_size
for
vocab_size
in
vocab_sizes
)
return
vocab_sizes
[
0
]
@
property
def
rank
(
self
):
return
self
.
scorer_worker
.
rank
@
property
def
device
(
self
):
return
self
.
scorer_worker
.
device
@
property
def
_driver_rank
(
self
)
->
int
:
return
0
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes.
This function is only used to compose workers within a SpecDecodeWorker.
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
undefined for now, although it could be implemented in the future.
See https://arxiv.org/abs/2308.04623.
"""
raise
NotImplementedError
def
start_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
WorkerBase
):
self
.
scorer_worker
.
start_profile
()
def
stop_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
WorkerBase
):
self
.
scorer_worker
.
stop_profile
()
def
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
:
int
,
proposer_cache_block_size_bytes
:
int
,
total_num_gpu_blocks
:
int
)
->
int
:
"""Given total_num_gpu_blocks, the number of GPU blocks that could be
allocate to the target model, this function calculates how many blocks
should be given to the draft and target model.
Note that usually the block size, in bytes, of each model is different,
as it's a function of number of KV/layer, number of heads, and hidden
dimension size.
Since the target and draft models allocate the same number of blocks, we
simply calculate the number of blocks where if allocated by both models,
the total memory usage from KV cache is no larger than the number of
blocks allocatable by the target model alone.
"""
new_num_gpu_blocks
=
int
(
total_num_gpu_blocks
*
scorer_cache_block_size_bytes
/
(
proposer_cache_block_size_bytes
+
scorer_cache_block_size_bytes
))
return
new_num_gpu_blocks
def
prepare_prefill_hidden_states
(
prefill_hidden_states
:
torch
.
Tensor
)
->
HiddenStates
:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
return
HiddenStates
(
prefill_hidden_states
.
roll
(
shifts
=
1
,
dims
=
0
))
if
prefill_hidden_states
is
not
None
else
None
vllm/spec_decode/tree_style_proposer.py
deleted
100644 → 0
View file @
5ad884ee
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Any
,
Dict
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
ExecuteModelRequest
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
class
TreeStyleProposer
(
SpeculativeProposer
):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def
__init__
(
self
,
worker
:
ProposerWorkerBase
,
device
:
str
,
vocab_size
:
int
,
tree_buffers
:
Dict
[
str
,
Any
],
max_proposal_len
:
Optional
[
int
]
=
None
,
):
self
.
_worker
=
worker
self
.
_device
=
device
self
.
tree_buffers
=
tree_buffers
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
#proposal_len = execute_model_req.num_lookahead_slots
proposal_len
=
self
.
tree_buffers
[
"tree_indices"
].
shape
[
0
]
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
=
self
.
_split_by_proposal_len
(
seq_group_metadata_list
,
proposal_len
)
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
hidden_states
=
execute_model_req
.
previous_hidden_states
if
hidden_states
is
not
None
:
hidden_states
.
prune
(
nonzero_proposal_len_seqs
)
logits
=
execute_model_req
.
previous_logits
if
logits
is
not
None
:
logits
.
prune
(
nonzero_proposal_len_seqs
)
nonzero_execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
num_lookahead_slots
=
proposal_len
,
previous_hidden_states
=
hidden_states
,
previous_logits
=
logits
,
)
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_output
(
execute_model_req
=
nonzero_execute_model_req
,
sample_len
=
proposal_len
,
seq_ids_with_bonus_token_in_last_step
=
\
seq_ids_with_bonus_token_in_last_step
,
)
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
)
=
self
.
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
)
else
:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output
=
None
transposed
=
False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens
,
proposal_probs
,
proposal_lens
,
cart_candidates
,
tree_attn_masks
=
self
.
_merge_outputs
(
batch_size
=
len
(
seq_group_metadata_list
),
proposal_len
=
proposal_len
,
maybe_sampler_output
=
maybe_sampler_output
,
proposal_lens
=
proposal_lens
,
nonzero_proposal_len_indices
=
nonzero_proposal_len_indices
,
sampler_transposed
=
transposed
,
)
tree_position_ids_list
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
if
seq_data
.
get_first_step_flag
():
seq_len
=
seq_data
.
get_len
()
-
1
else
:
seq_len
=
seq_data
.
get_len
()
tree_position_ids
=
self
.
tree_buffers
[
'tree_position_ids'
]
+
seq_len
tree_position_ids_list
.
append
(
tree_position_ids
)
tree_position_ids
=
torch
.
stack
(
tree_position_ids_list
,
dim
=
0
).
reshape
(
-
1
,
1
)
proposals
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
no_proposals
=
maybe_sampler_output
is
None
,
cart_candidates
=
cart_candidates
,
retrieve_indices
=
self
.
tree_buffers
[
'retrieve_indices'
],
tree_attn_masks
=
tree_attn_masks
,
tree_position_ids
=
tree_position_ids
)
return
proposals
def
_split_by_proposal_len
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if
seq_group_metadata
.
num_speculative_tokens
==
0
:
proposal_lens
.
append
(
0
)
continue
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# quota for nonzero_proposal
new_k
=
0
if
(
self
.
max_proposal_len
is
None
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
new_k
=
proposal_len
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
proposal_lens
.
append
(
new_k
)
seq_group_metadata
.
num_speculative_tokens
=
new_k
return
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
@
staticmethod
def
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
# If maybe_sampler_output is None, then the draft worker did not
# provide a proposal for any sequence and thus no action needed.
# Also we do not support transposed maybe_sampler_output for now
# because it seems not straightforward for draft workers outputting
# transposed sampler outputs to handle the case of no proposal.
if
maybe_sampler_output
is
None
or
transposed
:
return
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
)
new_proposal_lens
:
List
[
int
]
=
[]
new_nonzero_proposal_len_indices
:
List
[
int
]
=
[]
new_maybe_sampler_output
:
List
[
SamplerOutput
]
=
[]
nonzero_proposal_len_idx_ptr
=
0
seq_idx
=
0
while
seq_idx
<
len
(
proposal_lens
)
and
nonzero_proposal_len_idx_ptr
<
len
(
nonzero_proposal_len_indices
):
if
seq_idx
<
nonzero_proposal_len_indices
[
nonzero_proposal_len_idx_ptr
]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert
proposal_lens
[
seq_idx
]
==
0
new_proposal_lens
.
append
(
0
)
else
:
# Sequence is in the original nonzero_proposal_len_indices
if
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
]
is
None
:
# but does not have a proposal from the draft worker.
new_proposal_lens
.
append
(
0
)
else
:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens
.
append
(
proposal_lens
[
seq_idx
])
new_nonzero_proposal_len_indices
.
append
(
seq_idx
)
new_maybe_sampler_output
.
append
(
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
])
nonzero_proposal_len_idx_ptr
+=
1
seq_idx
+=
1
# The remaining sequences should have proposal length of 0.
new_proposal_lens
.
extend
(
proposal_lens
[
seq_idx
:])
# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert
new_maybe_sampler_output
return
(
new_proposal_lens
,
new_maybe_sampler_output
,
new_nonzero_proposal_len_indices
)
def
_merge_outputs
(
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
retrieve_indices
=
self
.
tree_buffers
[
"retrieve_indices"
]
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
tensor
(
-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
)
proposal_probs
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
,
self
.
_vocab_size
)
proposal_lens_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
len
(
proposal_lens
))
cart_candidates_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
retrieve_indices
.
shape
[
0
],
retrieve_indices
.
shape
[
1
])
tree_attn_masks_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int32
,
device
=
self
.
_device
).
expand
(
batch_size
,
self
.
tree_buffers
[
"tree_attn_masks"
].
shape
[
0
],
self
.
tree_buffers
[
"tree_attn_masks"
].
shape
[
1
])
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
,
cart_candidates_tensor
,
tree_attn_masks_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
,
_
,
_
,
cart_candidates
,
tree_attn_masks
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
proposal_tokens
.
new_full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
None
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
,
)
proposal_lens_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_lens_tensor
[
nonzero_proposal_len_indices
]
=
proposal_len
entire_cart_candidates
=
cart_candidates
.
new_zeros
(
batch_size
,
*
cart_candidates
.
shape
[
1
:],
)
entire_cart_candidates
[
nonzero_proposal_len_indices
]
=
cart_candidates
entire_tree_attn_masks
=
tree_attn_masks
.
new_zeros
(
batch_size
,
*
tree_attn_masks
.
shape
[
1
:],
)
entire_tree_attn_masks
[
nonzero_proposal_len_indices
]
=
tree_attn_masks
entire_tree_attn_masks
=
entire_tree_attn_masks
.
reshape
(
-
1
,
tree_attn_masks
.
shape
[
-
1
])
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
,
entire_cart_candidates
,
entire_tree_attn_masks
vllm/spec_decode/util.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SequenceGroupMetadata
,
SequenceOutput
)
SeqId
=
int
def
get_all_num_logprobs
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
List
[
int
]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
num_logprobs
=
seq_group_metadata
.
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
all_num_logprobs
.
append
(
num_logprobs
)
return
all_num_logprobs
def
get_sampled_token_logprobs
(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
# shape [num_steps, batch_size]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps
,
batch_size
,
vocab_size
=
logprob_tensor
.
shape
selected_logprobs
=
logprob_tensor
[
torch
.
arange
(
num_steps
).
unsqueeze
(
1
),
torch
.
arange
(
batch_size
),
sampled_token_ids
,
]
expanded_selected_logprobs
=
selected_logprobs
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
vocab_size
)
sampled_token_ids_ranks
=
(
logprob_tensor
>
expanded_selected_logprobs
).
sum
(
-
1
).
add_
(
1
)
return
sampled_token_ids_ranks
,
selected_logprobs
def
create_logprobs_output
(
token_id
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
topk_token_ids
:
List
[
Optional
[
int
]],
topk_logprobs
:
List
[
Optional
[
float
]],
)
->
Dict
[
int
,
Logprob
]:
"""Create a Logprob Dict for a token given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs
:
Dict
[
int
,
Logprob
]
=
{
token_id
:
Logprob
(
logprob
=
token_id_logprob
,
rank
=
token_id_logprob_rank
,
),
}
logprobs
.
update
({
topk_token_id
:
Logprob
(
logprob
=
topk_logprob
if
topk_logprob
is
not
None
else
0.0
,
rank
=
topk_index
+
1
,
)
for
topk_index
,
(
topk_token_id
,
topk_logprob
)
\
in
enumerate
(
zip
(
topk_token_ids
,
topk_logprobs
))
\
if
topk_token_id
is
not
None
})
return
logprobs
def
create_sequence_group_output
(
token_id
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
Optional
[
int
]],
topk_logprobs
:
List
[
Optional
[
float
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
,
step_index
:
Optional
[
int
]
=
0
)
->
CompletionSequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index: (Optional[int]): The index of the speculative token.
"""
logprobs
=
create_logprobs_output
(
token_id
,
token_id_logprob_rank
,
token_id_logprob
,
topk_token_ids
,
topk_logprobs
,
)
return
CompletionSequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
logprobs
=
logprobs
)
],
prompt_logprobs
=
prompt_logprobs
,
step_index
=
step_index
)
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
)
->
Tuple
[
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]],
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
nonzero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
zero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
for
i
,
(
seq_group
,
proposal_len
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
proposal_lens
)):
seq_groups
,
indices
=
nonzero_lists
if
proposal_len
else
zero_lists
seq_groups
.
append
(
seq_group
)
indices
.
append
(
i
)
return
nonzero_lists
,
zero_lists
def
sampler_output_to_torch
(
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
sampled_token_probs: torch.Tensor
shape: [batch_size, len(sampler_output_list), vocab_size]
"""
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs
=
None
if
sampler_output_list
[
0
].
sampled_token_probs
is
not
None
:
sampled_token_probs
=
torch
.
stack
(
[
sampler_output
.
sampled_token_probs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
None
if
sampler_output_list
[
0
].
logprobs
is
not
None
:
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
[
sampler_output
.
sampled_token_ids
.
flatten
()
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states
=
torch
.
stack
(
[
sampler_output
.
hidden_states
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_hidden_states
=
sampled_hidden_states
.
transpose
(
0
,
1
)
else
:
sampled_hidden_states
=
None
sampled_cart_candidates
=
None
if
sampler_output_list
[
0
].
cart_candidates
is
not
None
:
sampled_cart_candidates
=
torch
.
cat
(
[
sampler_output
.
cart_candidates
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_cart_candidates
=
sampled_cart_candidates
.
transpose
(
0
,
1
)
sampled_tree_attn_masks
=
None
if
sampler_output_list
[
0
].
tree_attn_masks
is
not
None
:
sampled_tree_attn_masks
=
torch
.
stack
(
[
sampler_output
.
tree_attn_masks
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_tree_attn_masks
=
sampled_tree_attn_masks
.
transpose
(
0
,
1
)
return
(
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
,
sampled_hidden_states
,
sampled_cart_candidates
,
sampled_tree_attn_masks
)
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
str
)
->
None
:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values
=
[
sampler_output
.
sampled_token_probs
,
sampler_output
.
sampled_token_ids
]
assert
all
(
v
is
None
for
v
in
values
)
or
not
any
(
v
is
None
for
v
in
values
)
if
not
any
(
v
is
None
for
v
in
values
):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output
.
sampled_token_probs
=
torch
.
nn
.
functional
.
softmax
(
torch
.
rand
(
batch_size
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
device
),
dim
=-
1
)
sampler_output
.
sampled_token_ids
=
torch
.
randint
(
low
=
10
,
high
=
100
,
size
=
(
batch_size
,
),
dtype
=
torch
.
long
,
device
=
device
)
@
contextmanager
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
"""
Context manager / decorator that pushes an NVTX range at the beginning
of its scope, and pops it at the end. If extra arguments are given,
they are passed as arguments to msg.format().
If running with cuda graphs, you must enable nsys cuda graph profiling.
Arguments:
msg (string): message to associate with the range
"""
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
nvtx
.
range_push
(
msg
.
format
(
*
args
,
**
kwargs
))
try
:
yield
finally
:
torch
.
cuda
.
nvtx
.
range_pop
()
else
:
yield
class
Timer
:
"""Basic timer context manager for measuring CPU time.
"""
def
__enter__
(
self
):
self
.
start_time
=
time
.
time
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
end_time
=
time
.
time
()
self
.
elapsed_time_s
=
self
.
end_time
-
self
.
start_time
self
.
elapsed_time_ms
=
self
.
elapsed_time_s
*
1000
vllm/triton_utils/custom_cache_manager.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
import
os
from
triton.runtime.cache
import
(
FileCacheManager
,
default_cache_dir
,
default_dump_dir
,
default_override_dir
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
maybe_set_triton_cache_manager
()
->
None
:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger
=
os
.
environ
.
get
(
"TRITON_CACHE_MANAGER"
,
None
)
if
cache_manger
is
None
:
manager
=
"vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger
.
info
(
"Setting Triton cache manager to: %s"
,
manager
)
os
.
environ
[
"TRITON_CACHE_MANAGER"
]
=
manager
class
CustomCacheManager
(
FileCacheManager
):
"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""
def
__init__
(
self
,
key
,
override
=
False
,
dump
=
False
):
self
.
key
=
key
self
.
lock_path
=
None
if
dump
:
self
.
cache_dir
=
default_dump_dir
()
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
elif
override
:
self
.
cache_dir
=
default_override_dir
()
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
else
:
# create cache directory if it doesn't exist
self
.
cache_dir
=
os
.
getenv
(
"TRITON_CACHE_DIR"
,
""
).
strip
()
or
default_cache_dir
()
if
self
.
cache_dir
:
# self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
else
:
raise
RuntimeError
(
"Could not create or locate cache dir"
)
vllm/v1/attention/backends/mla/common.py
View file @
3de379de
...
...
@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
_table
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
...
...
vllm/v1/attention/backends/utils.py
View file @
3de379de
...
...
@@ -45,10 +45,8 @@ class CommonAttentionMetadata:
seq_lens_cpu
:
torch
.
Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_computed_tokens_cpu
:
torch
.
Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs
:
int
"""Number of requests"""
num_actual_tokens
:
int
...
...
vllm/worker/cpu_enc_dec_model_runner.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunnerBase
,
ModelInputForCPUBuilder
,
ModelInputForCPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
EncoderDecoderModelInputForCPU
(
ModelInputForCPUWithSamplingMetadata
):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens
:
Optional
[
torch
.
Tensor
]
=
None
encoder_input_positions
:
Optional
[
torch
.
Tensor
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
"encoder_input_tokens"
:
self
.
encoder_input_tokens
,
"encoder_input_positions"
:
self
.
encoder_input_positions
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
self
.
sampling_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
,
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"EncoderDecoderModelInputForCPU"
:
return
cast
(
EncoderDecoderModelInputForCPU
,
super
().
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
))
class
CPUEncoderDecoderModelRunner
(
CPUModelRunnerBase
[
EncoderDecoderModelInputForCPU
]):
_model_input_cls
:
Type
[
EncoderDecoderModelInputForCPU
]
=
(
EncoderDecoderModelInputForCPU
)
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
=
ModelInputForCPUBuilder
def
_list_to_int32_tensor
(
self
,
_list
:
List
[
int
],
)
->
torch
.
Tensor
:
return
torch
.
tensor
(
_list
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
_list_to_long_tensor
(
self
,
_list
:
List
[
int
],
)
->
torch
.
Tensor
:
return
torch
.
tensor
(
_list
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
def
_empty_int32_tensor
(
self
)
->
torch
.
Tensor
:
return
self
.
_list_to_int32_tensor
([])
def
_empty_long_tensor
(
self
)
->
torch
.
Tensor
:
return
self
.
_list_to_long_tensor
([])
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
EncoderDecoderModelInputForCPU
:
return
EncoderDecoderModelInputForCPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
EncoderDecoderModelInputForCPU
:
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
(
attn_metadata
,
encoder_input_tokens_tensor
,
encoder_input_positions_tensor
,
)
=
self
.
_prepare_encoder_model_input_tensors
(
seq_group_metadata_list
,
model_input
)
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
generators
)
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
attn_metadata
=
attn_metadata
,
encoder_input_tokens
=
encoder_input_tokens_tensor
,
encoder_input_positions
=
encoder_input_positions_tensor
,
virtual_engine
=
virtual_engine
,
)
def
_prepare_encoder_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
model_input
:
EncoderDecoderModelInputForCPU
,
)
->
Tuple
[
AttentionMetadata
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if
len
(
seq_group_metadata_list
)
==
0
:
return
(
model_input
.
attn_metadata
,
None
,
None
)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Build encoder inputs
encoder_seq_lens
:
List
[
int
]
=
[]
if
is_prompt
:
# Prefill phase.
cross_block_tables
=
self
.
_empty_int32_tensor
().
view
(
len
(
seq_group_metadata_list
),
-
1
)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens
,
encoder_input_positions
,
cross_slot_mapping
,
)
=
(
[],
[],
[],
)
for
seq_group_metadata
in
seq_group_metadata_list
:
# Build seq lens
seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
token_ids
=
seq_group_metadata
.
encoder_seq_data
.
get_token_ids
()
encoder_seq_lens
.
append
(
seq_len
)
# Build slot mapping
for
i
in
range
(
0
,
seq_len
):
block_number
=
seq_group_metadata
.
cross_block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
cross_slot_mapping
.
append
(
slot
)
# Build encoder input tokens
encoder_input_tokens
.
extend
(
token_ids
)
encoder_input_positions
.
extend
(
list
(
range
(
0
,
seq_len
)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor
=
self
.
_list_to_long_tensor
(
encoder_input_tokens
)
encoder_input_positions_tensor
=
self
.
_list_to_long_tensor
(
encoder_input_positions
)
cross_slot_mapping_tensor
=
self
.
_list_to_long_tensor
(
cross_slot_mapping
)
else
:
# Decode phase.
encoder_input_tokens_tensor
=
self
.
_empty_long_tensor
()
encoder_input_positions_tensor
=
self
.
_empty_long_tensor
()
cross_slot_mapping_tensor
=
self
.
_empty_long_tensor
()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
)):
encoder_seq_lens
.
append
(
seq_group_metadata
.
encoder_seq_data
.
get_len
())
cross_block_table
=
seq_group_metadata
.
cross_block_table
cross_block_tables
.
append
([]
if
(
cross_block_table
is
None
)
else
cross_block_table
)
max_len_of_block_table
=
max
(
len
(
block_table
)
for
block_table
in
cross_block_tables
)
cross_block_tables
=
make_tensor_with_pad
(
cross_block_tables
,
max_len
=
max_len_of_block_table
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len
=
max
(
encoder_seq_lens
,
default
=
0
)
encoder_seq_lens_tensor
=
self
.
_list_to_int32_tensor
(
encoder_seq_lens
)
encoder_seq_start_loc
=
torch
.
zeros
(
encoder_seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
encoder_seq_lens_tensor
,
dim
=
0
,
dtype
=
encoder_seq_start_loc
.
dtype
,
out
=
encoder_seq_start_loc
[
1
:])
# Update attention metadata with encoder-oriented attributes
attn_metadata
=
model_input
.
attn_metadata
assert
attn_metadata
is
not
None
(
attn_metadata
.
num_encoder_tokens
,
attn_metadata
.
encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
cross_slot_mapping
,
attn_metadata
.
cross_block_tables
,
)
=
(
sum
(
encoder_seq_lens
),
encoder_seq_lens
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
,
cross_slot_mapping_tensor
,
cross_block_tables
,
)
return
(
attn_metadata
,
encoder_input_tokens_tensor
,
encoder_input_positions_tensor
)
@
torch
.
no_grad
()
def
execute_model
(
self
,
model_input
:
EncoderDecoderModelInputForCPU
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
model_input
.
input_tokens
,
"positions"
:
model_input
.
input_positions
,
"encoder_input_ids"
:
model_input
.
encoder_input_tokens
,
"encoder_positions"
:
model_input
.
encoder_input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
,
),
"intermediate_tensors"
:
intermediate_tensors
,
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
model_input
.
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
# Sample the next token.
output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
return
[
output
]
vllm/worker/cpu_model_runner.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
weakref
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
,
TypeVar
,
Union
)
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.multimodal
import
(
BatchedTensorInputs
,
MultiModalKwargs
,
MultiModalPlaceholderMap
)
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
TModelInputForCPU
=
TypeVar
(
'TModelInputForCPU'
,
bound
=
"ModelInputForCPU"
)
_PAD_SLOT_ID
=
-
1
@
dataclass
(
frozen
=
True
)
class
ModelInputForCPU
(
ModelRunnerInputBase
):
"""
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens
:
Optional
[
torch
.
Tensor
]
=
None
input_positions
:
Optional
[
torch
.
Tensor
]
=
None
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
attn_metadata
:
Optional
[
"AttentionMetadata"
]
=
None
multi_modal_kwargs
:
Optional
[
BatchedTensorInputs
]
=
None
virtual_engine
:
Optional
[
int
]
=
None
seq_lens
:
Optional
[
List
[
int
]]
=
None
query_lens
:
Optional
[
List
[
int
]]
=
None
lora_mapping
:
Optional
[
"LoRAMapping"
]
=
None
lora_requests
:
Optional
[
Set
[
LoRARequest
]]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
"token_type_ids"
:
self
.
token_type_ids
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
"lora_requests"
:
self
.
lora_requests
,
"lora_mapping"
:
self
.
lora_mapping
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
TModelInputForCPU
],
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
)
->
TModelInputForCPU
:
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForCPUWithSamplingMetadata
(
ModelInputForCPU
):
"""
Used by the ModelRunner.
"""
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
is_prompt
:
Optional
[
bool
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
"token_type_ids"
:
self
.
token_type_ids
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
self
.
sampling_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
,
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForCPUWithSamplingMetadata"
:
tensor_dict
=
_init_sampling_metadata_from_tensor_dict
(
tensor_dict
)
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
ModelInputForCPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForCPU
]):
class
ModelInputData
:
def
__init__
(
self
,
use_mrope
:
bool
):
self
.
use_mrope
=
use_mrope
self
.
input_tokens
:
List
[
int
]
=
[]
self
.
input_positions
:
List
[
int
]
=
[]
self
.
token_type_ids
:
Optional
[
List
[
int
]]
=
[]
self
.
seq_lens
:
List
[
int
]
=
[]
self
.
query_lens
:
List
[
int
]
=
[]
self
.
prefill_block_tables
:
List
[
List
[
int
]]
=
[]
self
.
decode_block_tables
:
List
[
List
[
int
]]
=
[]
self
.
max_decode_seq_len
:
int
=
0
self
.
num_prefills
:
int
=
0
self
.
num_prefill_tokens
:
int
=
0
self
.
num_decode_tokens
:
int
=
0
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
multi_modal_inputs_list
:
List
[
MultiModalKwargs
]
=
[]
self
.
multi_modal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
input_mrope_positions
:
List
[
List
[
int
]]
=
[[]
for
_
in
range
(
3
)]
def
__init__
(
self
,
runner
:
"CPUModelRunner"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
runner
=
runner
self
.
chunked_prefill
=
(
runner
.
scheduler_config
.
chunked_prefill_enabled
or
runner
.
cache_config
.
enable_prefix_caching
)
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
sliding_window
=
self
.
runner
.
sliding_window
self
.
block_size
=
self
.
runner
.
block_size
self
.
device
=
self
.
runner
.
device
self
.
enable_lora
=
self
.
runner
.
lora_config
is
not
None
if
self
.
runner
.
attn_backend
is
not
None
:
# spec decode (e.g. Medusa) does not have atten backend
attn_backend
=
self
.
runner
.
attn_backend
self
.
att_metadata_builder
=
attn_backend
.
get_builder_cls
()(
self
)
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
self
.
input_data
=
ModelInputForCPUBuilder
.
ModelInputData
(
self
.
runner
.
model_config
.
uses_mrope
)
self
.
att_metadata_builder
.
prepare
()
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
self
.
seq_group_metadata_list
.
append
(
seq_group_metadata
)
def
set_seq_group_list
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]):
self
.
seq_group_metadata_list
=
seq_group_metadata_list
def
build
(
self
)
->
ModelInputForCPU
:
self
.
_build_input_data
()
input_data
=
self
.
input_data
input_tokens
=
torch
.
tensor
(
input_data
.
input_tokens
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_data
.
input_positions
if
not
any
(
input_data
.
input_mrope_positions
)
else
input_data
.
input_mrope_positions
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
token_type_ids
=
torch
.
tensor
(
input_data
.
token_type_ids
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
\
if
input_data
.
token_type_ids
else
None
# For multi-modal models
multi_modal_kwargs
=
None
if
len
(
input_data
.
multi_modal_inputs_list
)
!=
0
:
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
input_data
.
multi_modal_inputs_list
)
attn_metadata
=
self
.
att_metadata_builder
.
build
(
input_data
.
seq_lens
,
input_data
.
query_lens
,
-
1
,
-
1
)
is_prompt
=
(
self
.
seq_group_metadata_list
[
0
].
is_prompt
if
self
.
seq_group_metadata_list
else
None
)
# LoRA data.
lora_requests
=
set
()
lora_mapping
=
None
if
self
.
enable_lora
:
lora_requests
=
set
(
seq
.
lora_request
for
seq
in
self
.
seq_group_metadata_list
if
seq
.
lora_request
is
not
None
)
lora_mapping
=
self
.
_prepare_lora_input
(
self
.
seq_group_metadata_list
,
is_prompt
)
return
self
.
model_input_cls
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
token_type_ids
=
token_type_ids
,
seq_lens
=
input_data
.
seq_lens
,
query_lens
=
input_data
.
query_lens
,
attn_metadata
=
attn_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
)
def
_build_input_data
(
self
):
for
seq_group_metadata
in
self
.
seq_group_metadata_list
:
for
seq_id
,
seq_data
in
seq_group_metadata
.
seq_data
.
items
():
if
seq_group_metadata
.
is_prompt
:
self
.
_compute_prompt_input_tokens
(
self
.
input_data
,
seq_group_metadata
,
seq_data
,
seq_id
)
if
seq_group_metadata
.
multi_modal_data
:
self
.
_compute_multi_modal_input
(
seq_group_metadata
,
seq_data
)
else
:
self
.
_compute_decode_input_tokens
(
self
.
input_data
,
seq_group_metadata
,
seq_data
,
seq_id
)
def
_compute_decode_input_tokens
(
self
,
data
:
ModelInputData
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_data
:
SequenceData
,
seq_id
:
int
):
"""
Compute decode input tokens, positions, block table and slot mapping.
"""
block_size
=
self
.
runner
.
block_size
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
seq_len
=
seq_data
.
get_len
()
context_len
=
seq_data
.
get_num_computed_tokens
()
tokens
=
seq_data
.
get_last_token_id
()
token_positions
=
seq_len
-
1
block_number
=
block_table
[
token_positions
//
block_size
]
block_offset
=
token_positions
%
block_size
slot
=
block_number
*
block_size
+
block_offset
# For paged_attention kernel
if
self
.
runner
.
sliding_window
:
start_idx
=
max
(
0
,
seq_len
-
self
.
runner
.
sliding_window
)
start_block
=
start_idx
//
block_size
start_idx
=
start_block
*
block_size
seq_len
=
seq_len
-
start_idx
block_table
=
block_table
[
start_block
:]
# For MRotaryEmbedding
if
seq_data
.
mrope_position_delta
is
not
None
:
next_pos
=
MRotaryEmbedding
.
get_next_input_positions
(
seq_data
.
mrope_position_delta
,
context_len
,
seq_len
,
)
for
idx
in
range
(
3
):
data
.
input_mrope_positions
[
idx
].
extend
(
# type: ignore
next_pos
[
idx
])
else
:
data
.
input_positions
.
append
(
token_positions
)
# type: ignore
# Update fields
data
.
input_tokens
.
append
(
tokens
)
data
.
max_decode_seq_len
=
max
(
data
.
max_decode_seq_len
,
seq_len
)
data
.
num_decode_tokens
+=
1
data
.
slot_mapping
.
append
(
slot
)
data
.
decode_block_tables
.
append
(
block_table
)
data
.
query_lens
.
append
(
1
)
data
.
seq_lens
.
append
(
seq_len
)
def
_compute_prompt_input_tokens
(
self
,
data
:
ModelInputData
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_data
:
SequenceData
,
seq_id
:
int
):
"""
Compute prompt input tokens, positions, block table and slot mapping.
"""
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
block_size
=
self
.
runner
.
block_size
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
seq_len
=
seq_data
.
get_len
()
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# For prefix caching
prefix_cache_block_num
=
len
(
seq_group_metadata
.
computed_block_nums
)
if
prefix_cache_block_num
>
0
:
prefix_cache_len
=
(
prefix_cache_block_num
*
self
.
runner
.
block_size
)
if
prefix_cache_len
<=
context_len
:
# We already passed the cache hit region,
# so do normal computation.
pass
elif
context_len
<
prefix_cache_len
<
seq_len
:
# Partial hit. Compute the missing part.
context_len
=
prefix_cache_len
token_chunk_size
=
seq_len
-
context_len
elif
seq_len
<=
prefix_cache_len
:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
context_len
=
seq_len
-
1
token_chunk_size
=
1
tokens
=
seq_data
.
get_token_ids
()
tokens
=
tokens
[
context_len
:
seq_len
]
token_positions
=
range
(
context_len
,
seq_len
)
token_types
=
seq_group_metadata
.
token_type_ids
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if
block_table
is
not
None
:
slot_mapping
=
[
_PAD_SLOT_ID
]
*
len
(
token_positions
)
for
i
,
pos
in
enumerate
(
token_positions
):
block_number
=
block_table
[
pos
//
block_size
]
block_offset
=
pos
%
block_size
slot
=
block_number
*
block_size
+
block_offset
slot_mapping
[
i
]
=
slot
data
.
slot_mapping
.
extend
(
slot_mapping
)
# The MROPE positions are prepared in _compute_multi_modal_input
data
.
input_positions
.
extend
(
token_positions
)
if
data
.
token_type_ids
is
not
None
:
data
.
token_type_ids
.
extend
(
token_types
if
token_types
else
[])
# Update fields
data
.
input_tokens
.
extend
(
tokens
)
data
.
num_prefills
+=
1
data
.
num_prefill_tokens
+=
len
(
tokens
)
data
.
query_lens
.
append
(
len
(
tokens
))
data
.
prefill_block_tables
.
append
(
block_table
)
data
.
seq_lens
.
append
(
seq_len
)
def
_compute_multi_modal_input
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_data
:
SequenceData
):
computed_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
self
.
input_data
.
seq_lens
[
-
1
]
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_kwargs
,
placeholder_maps
=
MultiModalPlaceholderMap
.
from_seq_group
(
seq_group_metadata
,
range
(
computed_len
,
seq_len
))
if
not
mm_kwargs
:
return
# special processing for mrope position deltas.
if
self
.
runner
.
model_config
.
uses_mrope
:
assert
not
self
.
chunked_prefill
,
\
"MROPE on CPU does not support chunked-prefill."
image_grid_thw
=
mm_kwargs
.
get
(
"image_grid_thw"
,
None
)
video_grid_thw
=
mm_kwargs
.
get
(
"video_grid_thw"
,
None
)
audio_feature_lengths
=
mm_kwargs
.
get
(
"audio_feature_lengths"
,
None
)
assert
(
image_grid_thw
is
not
None
or
video_grid_thw
is
not
None
or
audio_feature_lengths
is
not
None
),
(
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'."
)
second_per_grid_ts
=
mm_kwargs
.
get
(
"second_per_grid_ts"
,
None
)
use_audio_in_video
=
mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
hf_config
=
self
.
runner
.
model_config
.
hf_config
token_ids
=
seq_data
.
get_token_ids
()
mrope_positions
,
mrope_position_delta
=
\
MRotaryEmbedding
.
get_input_positions
(
token_ids
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
computed_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
seq_data
.
mrope_position_delta
=
mrope_position_delta
for
i
in
range
(
3
):
self
.
input_data
.
input_mrope_positions
[
# type: ignore
i
].
extend
(
mrope_positions
[
i
])
self
.
input_data
.
multi_modal_inputs_list
.
append
(
mm_kwargs
)
for
modality
,
placeholder_map
in
placeholder_maps
.
items
():
self
.
input_data
.
multi_modal_placeholder_maps
[
modality
].
extend
(
placeholder_map
)
def
_prepare_lora_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
is_prefill
:
bool
)
->
LoRAMapping
:
index_mapping
=
[]
prompt_mapping
=
[]
for
seq
in
seq_group_metadata_list
:
lora_id
=
seq
.
lora_int_id
query_len
=
seq
.
token_chunk_size
index_mapping
+=
[
lora_id
]
*
query_len
prompt_mapping
+=
[
lora_id
]
*
(
query_len
if
seq
.
sampling_params
and
seq
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
)
return
LoRAMapping
(
index_mapping
=
tuple
(
index_mapping
),
prompt_mapping
=
tuple
(
prompt_mapping
),
is_prefill
=
is_prefill
)
class
CPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForCPU
]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls
:
Type
[
TModelInputForCPU
]
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
builder
:
ModelInputForCPUBuilder
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
*
args
,
**
kwargs
,
):
ModelRunnerBase
.
__init__
(
self
,
vllm_config
)
model_config
=
self
.
model_config
cache_config
=
self
.
cache_config
self
.
is_driver_worker
=
is_driver_worker
self
.
return_hidden_states
=
return_hidden_states
self
.
device
=
self
.
device_config
.
device
self
.
pin_memory
=
False
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
needs_attn_backend
=
(
num_attn_heads
!=
0
or
self
.
model_config
.
is_attention_free
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
needs_attn_backend
else
None
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
# Set after load_model.
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
self
.
sampler
=
get_sampler
()
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
if
self
.
lora_config
:
assert
supports_lora
(
self
.
model
),
f
"
{
self
.
model
.
__class__
.
__name__
}
does not support LoRA yet."
if
supports_multimodal
(
self
.
model
):
logger
.
warning
(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)
# Use get_text_config() in case of multimodal models
text_config
=
self
.
model_config
.
hf_config
.
get_text_config
()
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
,
self
.
vocab_size
,
self
.
lora_config
,
self
.
device
,
self
.
model
.
embedding_modules
,
self
.
model
.
embedding_padding_modules
,
max_position_embeddings
=
text_config
.
max_position_embeddings
,
)
self
.
model
=
self
.
lora_manager
.
create_lora_manager
(
self
.
model
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
TModelInputForCPU
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
self
.
builder
.
prepare
(
finished_requests_ids
)
self
.
builder
.
set_seq_group_list
(
seq_group_metadata_list
)
return
self
.
builder
.
build
()
# type: ignore
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
def
remove_all_loras
(
self
):
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
self
.
lora_manager
.
remove_all_adapters
()
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
self
.
lora_manager
.
set_active_adapters
(
lora_requests
,
lora_mapping
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
add_adapter
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_adapter
(
lora_id
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
pin_adapter
(
lora_id
)
def
list_loras
(
self
)
->
Set
[
int
]:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
list_adapters
()
class
CPUModelRunner
(
CPUModelRunnerBase
[
ModelInputForCPUWithSamplingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForCPUWithSamplingMetadata
]
=
(
ModelInputForCPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
=
ModelInputForCPUBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
],
)
->
ModelInputForCPUWithSamplingMetadata
:
return
ModelInputForCPUWithSamplingMetadata
.
from_broadcasted_tensor_dict
(
# noqa: E501
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForCPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
generators
)
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
if
seq_group_metadata_list
else
None
)
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
virtual_engine
=
virtual_engine
,
is_prompt
=
is_prompt
)
@
torch
.
no_grad
()
def
execute_model
(
self
,
model_input
:
ModelInputForCPUWithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
model_executable
=
self
.
model
multimodal_kwargs
=
{}
if
model_input
.
multi_modal_kwargs
is
not
None
:
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
,
device
=
self
.
device
,
)
execute_model_kwargs
=
{}
if
previous_hidden_states
is
not
None
:
execute_model_kwargs
.
update
(
{
"previous_hidden_states"
:
previous_hidden_states
})
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
execute_model_kwargs
,
**
multimodal_kwargs
,
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
model_input
.
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
# Sample the next token.
output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
if
self
.
return_hidden_states
:
# we only need to pass hidden states of most recent token
if
model_input
.
is_prompt
:
output
.
prefill_hidden_states
=
hidden_states
output
.
hidden_states
=
hidden_states
return
[
output
]
def
generate_proposals
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
generate_proposals
(
*
args
,
**
kwargs
)
vllm/worker/cpu_pooling_model_runner.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunnerBase
,
ModelInputForCPU
,
ModelInputForCPUBuilder
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
ModelInputForCPUWithPoolingMetadata
(
ModelInputForCPU
):
"""
Used by the CPUPoolingModelRunner.
"""
pooling_metadata
:
Optional
[
"PoolingMetadata"
]
=
None
class
CPUPoolingModelRunner
(
CPUModelRunnerBase
[
ModelInputForCPUWithPoolingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForCPUWithPoolingMetadata
]
=
(
ModelInputForCPUWithPoolingMetadata
)
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
=
ModelInputForCPUBuilder
@
torch
.
inference_mode
()
def
execute_model
(
self
,
model_input
:
ModelInputForCPUWithPoolingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
Optional
[
Union
[
List
[
PoolerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
model_executable
=
self
.
model
cross_enc_kwargs
=
{}
if
model_input
.
token_type_ids
is
not
None
:
cross_enc_kwargs
[
"token_type_ids"
]
=
model_input
.
token_type_ids
execute_model_kwargs
=
{
"input_ids"
:
model_input
.
input_tokens
,
"positions"
:
model_input
.
input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
,
),
**
cross_enc_kwargs
,
"intermediate_tensors"
:
intermediate_tensors
,
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Only perform pooling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
return
[
self
.
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
model_input
.
pooling_metadata
)
]
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForCPUWithPoolingMetadata
:
return
ModelInputForCPUWithPoolingMetadata
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForCPUWithPoolingMetadata
:
assert
seq_group_metadata_list
is
not
None
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
# Prepare PoolingMetadata.
assert
model_input
.
seq_lens
is
not
None
pooling_metadata
=
self
.
_prepare_pooling
(
seq_group_metadata_list
,
model_input
.
seq_lens
)
return
dataclasses
.
replace
(
model_input
,
virtual_engine
=
virtual_engine
,
pooling_metadata
=
pooling_metadata
)
def
_prepare_pooling
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
PoolingMetadata
:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
pooling_params
=
seq_group_metadata
.
pooling_params
seq_groups
.
append
((
seq_ids
,
pooling_params
))
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
)
return
pooling_metadata
vllm/worker/cpu_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import
os
from
importlib
import
util
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.distributed
import
vllm.envs
as
envs
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
VllmConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
WorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
class
CPUCacheEngine
:
"""Manages the KV cache for CPU backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def
__init__
(
self
,
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
device_config
:
DeviceConfig
)
->
None
:
assert
device_config
.
device_type
==
"cpu"
self
.
cache_config
=
cache_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
head_size
=
model_config
.
get_head_size
()
self
.
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
self
.
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
block_size
=
cache_config
.
block_size
# Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
# for CPU backend, because we want to reuse KV cache management
# in the scheduler.
self
.
num_cpu_blocks
=
cache_config
.
num_gpu_blocks
self
.
dtype
=
CPUCacheEngine
.
get_kv_cache_dtype
(
cache_config
,
model_config
)
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
# Initialize the cache.
self
.
cpu_cache
=
self
.
_allocate_kv_cache
(
self
.
num_cpu_blocks
)
def
_allocate_kv_cache
(
self
,
num_blocks
:
int
,
)
->
List
[
torch
.
Tensor
]:
"""Allocates KV cache on CPU."""
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_heads
,
self
.
head_size
)
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
self
.
num_layers
):
kv_cache
.
append
(
torch
.
empty
(
kv_cache_shape
,
dtype
=
self
.
dtype
,
device
=
"cpu"
))
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
raise
NotImplementedError
(
"Swap is not supported in CPUCacheEngine."
)
def
swap_out
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
raise
NotImplementedError
(
"Swap is not supported in CPUCacheEngine."
)
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
cpu_cache
,
src_to_dsts
)
@
staticmethod
def
get_kv_cache_dtype
(
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
):
if
cache_config
.
cache_dtype
==
"auto"
:
return
model_config
.
dtype
elif
cache_config
.
cache_dtype
in
[
"fp8"
,
"fp8_e5m2"
]:
return
torch
.
float8_e5m2
else
:
raise
NotImplementedError
(
f
"Unsupported KV cache type "
f
"
{
cache_config
.
cache_dtype
}
."
)
@
staticmethod
def
get_cache_block_size
(
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
)
->
int
:
head_size
=
model_config
.
get_head_size
()
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
key_cache_block
=
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
if
not
model_config
.
use_mla
else
0
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
dtype
=
CPUCacheEngine
.
get_kv_cache_dtype
(
cache_config
,
model_config
)
dtype_size
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
dtype_size
*
total
class
CPUWorker
(
LocalOrDistributedWorkerBase
):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
model_runner_cls
:
Optional
[
Type
[
CPUModelRunner
]]
=
None
,
)
->
None
:
WorkerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
self
.
local_rank
=
local_rank
self
.
rank
=
rank
vllm_config
.
parallel_config
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
# Setup OpenMP threads affinity.
omp_cpuids
=
envs
.
VLLM_CPU_OMP_THREADS_BIND
self
.
local_omp_cpuid
=
"all"
if
omp_cpuids
==
"auto"
:
self
.
local_omp_cpuid
=
self
.
get_cpus_id_binding_based_on_numa_nodes
(
)
else
:
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config
=
self
.
speculative_config
model_config
=
self
.
model_config
speculative_args
=
{}
if
speculative_config
is
None
\
or
(
speculative_config
.
draft_model_config
.
model
==
model_config
.
model
)
\
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
not
in
[
"medusa"
,
"mlp_speculator"
,
"eagle"
])
\
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
CPUModelRunnerBase
]
=
CPUModelRunner
if
self
.
model_config
.
runner_type
==
"pooling"
:
ModelRunnerClass
=
CPUPoolingModelRunner
elif
self
.
model_config
.
is_encoder_decoder
:
ModelRunnerClass
=
CPUEncoderDecoderModelRunner
self
.
model_runner
:
CPUModelRunnerBase
=
ModelRunnerClass
(
vllm_config
=
vllm_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
**
speculative_args
,
)
if
model_runner_cls
is
not
None
:
self
.
model_runner
=
model_runner_cls
(
self
.
model_runner
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self
.
cache_engine
:
List
[
CPUCacheEngine
]
# Initialize cpu_cache as pooling models don't initialize kv_caches
self
.
cpu_cache
:
Optional
[
List
[
List
[
torch
.
Tensor
]]]
=
None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
torch_profiler_trace_dir
=
envs
.
VLLM_TORCH_PROFILER_DIR
logger
.
info
(
"Profiling enabled. Traces will be saved to: %s"
,
torch_profiler_trace_dir
)
self
.
profiler
=
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
],
with_stack
=
True
,
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
torch_profiler_trace_dir
,
use_gzip
=
True
))
else
:
self
.
profiler
=
None
def
start_profile
(
self
):
if
self
.
profiler
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
start
()
def
stop_profile
(
self
):
if
self
.
profiler
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
self
.
profiler
.
stop
()
def
init_device
(
self
)
->
None
:
if
self
.
local_omp_cpuid
!=
"all"
:
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
if
ret
:
logger
.
info
(
ret
)
# Note: unique identifier for creating allreduce shared memory
os
.
environ
[
"VLLM_DIST_IDENT"
]
=
self
.
distributed_init_method
.
split
(
":"
)[
-
1
]
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
init_distributed_environment
()
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size
=
self
.
get_cache_block_size_bytes
()
num_cpu_blocks
=
int
(
self
.
cache_config
.
cpu_kvcache_space_bytes
//
cache_block_size
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks
=
num_cpu_blocks
num_cpu_blocks
=
0
return
num_gpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert
(
num_cpu_blocks
==
0
),
f
"
{
type
(
self
)
}
does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks
=
num_gpu_blocks
self
.
_validate_num_cpu_blocks
(
num_cpu_blocks
)
self
.
cache_config
.
num_gpu_blocks
=
num_cpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
0
# Initialize the cache.
self
.
_init_cache_engine
()
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
remove_lora
(
lora_id
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
pin_lora
(
lora_id
)
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
model_runner
.
list_loras
()
def
_validate_num_cpu_blocks
(
self
,
num_cpu_blocks
:
int
)
->
None
:
"""Raise errors if the num_cpu_blocks is invalid.
"""
if
num_cpu_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine."
)
max_seq_len
=
self
.
cache_config
.
block_size
*
num_cpu_blocks
if
self
.
model_config
.
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
self
.
model_config
.
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine."
)
def
_init_cache_engine
(
self
)
->
None
:
self
.
cache_engine
=
[
CPUCacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
,
self
.
device_config
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
cpu_cache
=
[
self
.
cache_engine
[
ve
].
cpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
cpu_cache
)
self
.
model_runner
.
block_size
=
self
.
cache_engine
[
0
].
block_size
assert
all
(
self
.
cpu_cache
[
ve
]
is
not
None
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
))
# Populate the cache to warmup the memory
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
for
layer_cache
in
self
.
cpu_cache
[
ve
]:
layer_cache
.
fill_
(
0
)
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
return
self
.
parallel_config
.
tensor_parallel_size
>
1
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
return
self
.
cpu_cache
@
property
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
return
None
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_runner
.
vocab_size
@
property
def
max_model_len
(
self
)
->
int
:
return
self
.
model_config
.
max_model_len
def
execute_worker
(
self
,
worker_input
:
WorkerInput
,
)
->
None
:
if
(
worker_input
.
blocks_to_copy
is
not
None
and
worker_input
.
blocks_to_copy
.
numel
()
>
0
):
self
.
cache_engine
[
worker_input
.
virtual_engine
].
copy
(
worker_input
.
blocks_to_copy
)
@
torch
.
inference_mode
()
def
prepare_worker_input
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
WorkerInput
:
assert
execute_model_req
is
not
None
virtual_engine
:
int
=
execute_model_req
.
virtual_engine
num_seq_groups
:
int
=
len
(
execute_model_req
.
seq_group_metadata_list
)
blocks_to_copy
=
torch
.
tensor
(
execute_model_req
.
blocks_to_copy
,
device
=
"cpu"
,
dtype
=
torch
.
int64
).
view
(
-
1
,
2
)
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
def
init_distributed_environment
(
self
)
->
None
:
"""Initialize the distributed environment."""
parallel_config
=
self
.
parallel_config
rank
=
self
.
rank
distributed_init_method
=
self
.
distributed_init_method
init_distributed_environment
(
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
backend
=
"gloo"
,
)
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cpu
())
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size in bytes of a single KV cache block.
"""
return
CPUCacheEngine
.
get_cache_block_size
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
def
get_cpus_id_binding_based_on_numa_nodes
(
self
)
->
str
:
"""Return CPUs id binding based on NUMA nodes.
"""
rank_to_cpus
=
self
.
local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size
=
self
.
vllm_config
.
parallel_config
.
world_size
libnuma_found
=
util
.
find_spec
(
"numa"
)
is
not
None
psutil_found
=
util
.
find_spec
(
"psutil"
)
is
not
None
if
libnuma_found
and
psutil_found
:
import
psutil
from
numa
import
info
cpu_count
=
psutil
.
cpu_count
(
logical
=
False
)
cpus_allow_list
=
psutil
.
Process
().
cpu_affinity
()
numa_size
=
info
.
get_num_configured_nodes
()
cpu_count_per_numa
=
cpu_count
//
numa_size
num_of_reserved_cpu
=
min
(
envs
.
VLLM_CPU_NUM_OF_RESERVED_CPU
,
cpu_count_per_numa
//
2
)
# check allow node_to_cpus list
node_to_cpus
=
[]
for
i
in
range
(
numa_size
):
node_intersect
=
set
(
info
.
node_to_cpus
(
i
)).
intersection
(
cpus_allow_list
)
if
bool
(
node_intersect
):
node_to_cpus
.
append
(
list
(
node_intersect
))
if
world_size
>
len
(
node_to_cpus
):
logger
.
error
(
"Auto thread-binding failed due to "
"world size: %d is larger than "
"allowed NUMA nodes number: %d."
"Please try to bind threads manually."
,
world_size
,
len
(
node_to_cpus
))
else
:
end
=
cpu_count_per_numa
-
num_of_reserved_cpu
rank_to_cpus_list
=
node_to_cpus
[
self
.
rank
][:
end
]
rank_to_cpus
=
','
.
join
(
str
(
x
)
for
x
in
rank_to_cpus_list
)
logger
.
info
(
"auto thread-binding list: %s"
,
rank_to_cpus
)
else
:
logger
.
warning
(
"Auto thread-binding is not supported due to "
"the lack of package numa and psutil,"
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads."
)
return
rank_to_cpus
vllm/worker/multi_step_tpu_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.tpu_model_runner
import
ModelInputForTPU
from
vllm.worker.tpu_worker
import
TPUWorker
from
vllm.worker.worker_base
import
WorkerInput
class
MultiStepTPUWorker
(
TPUWorker
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
cached_model_input
:
Optional
[
ModelInputForTPU
]
=
None
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Tuple
[
ModelInputForTPU
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]:
assert
self
.
is_driver_worker
assert
execute_model_req
.
virtual_engine
==
0
is_first_multi_step
=
execute_model_req
.
is_first_multi_step
is_last_step
=
execute_model_req
.
is_last_step
if
is_first_multi_step
:
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
worker_input
=
dataclasses
.
replace
(
worker_input
,
num_steps
=
execute_model_req
.
num_lookahead_slots
+
1
)
model_input
:
ModelInputForTPU
=
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
if
execute_model_req
.
async_callback
:
model_input
=
dataclasses
.
replace
(
model_input
,
async_callback
=
execute_model_req
.
async_callback
)
else
:
assert
self
.
cached_model_input
is
not
None
model_input
=
self
.
cached_model_input
worker_input
=
WorkerInput
()
model_input
=
dataclasses
.
replace
(
model_input
,
is_first_multi_step
=
is_first_multi_step
,
is_last_step
=
is_last_step
)
if
self
.
do_metadata_broadcast
:
if
is_first_multi_step
:
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
else
:
broadcast_data
=
{
"is_first_multi_step"
:
is_first_multi_step
,
"is_last_step"
:
is_last_step
,
}
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
# Retuning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return
model_input
,
worker_input
,
{}
def
prepare_input
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
Optional
[
Tuple
[
ModelInputForTPU
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]]:
if
self
.
is_driver_worker
:
if
execute_model_req
is
None
:
if
self
.
do_metadata_broadcast
:
broadcast_tensor_dict
({},
src
=
0
)
return
None
model_input
,
worker_input
,
_
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
)
if
model_input
.
is_first_multi_step
:
self
.
cached_model_input
=
model_input
return
model_input
,
worker_input
,
{}
else
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
not
broadcast_data
:
return
None
if
len
(
broadcast_data
)
==
2
:
assert
self
.
cached_model_input
is
not
None
self
.
cached_model_input
=
dataclasses
.
replace
(
self
.
cached_model_input
,
is_first_multi_step
=
broadcast_data
[
"is_first_multi_step"
],
is_last_step
=
broadcast_data
[
"is_last_step"
])
empty_worker_input
=
WorkerInput
()
return
self
.
cached_model_input
,
empty_worker_input
,
{}
worker_input
=
WorkerInput
.
from_broadcasted_tensor_dict
(
broadcast_data
)
model_input
=
(
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
broadcast_data
))
self
.
cached_model_input
=
model_input
return
model_input
,
worker_input
,
{}
vllm/worker/tpu_model_runner.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
import
time
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
)
from
unittest.mock
import
patch
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID
=
1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P
=
False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES
=
128
class
ExecutionMode
(
enum
.
Enum
):
PREFILL
=
enum
.
auto
()
DECODE
=
enum
.
auto
()
PREFIX_PREFILL
=
enum
.
auto
()
def
is_prefill
(
self
)
->
bool
:
return
self
in
(
ExecutionMode
.
PREFILL
,
ExecutionMode
.
PREFIX_PREFILL
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForTPU
(
ModelRunnerInputBase
):
token_ids
:
torch
.
Tensor
position_ids
:
torch
.
Tensor
attn_metadata
:
AttentionMetadata
input_lens
:
torch
.
Tensor
t
:
torch
.
Tensor
p
:
torch
.
Tensor
num_samples
:
int
n
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
tensor_dict
=
{
"token_ids"
:
self
.
token_ids
,
"position_ids"
:
self
.
position_ids
,
"input_lens"
:
self
.
input_lens
,
"t"
:
self
.
t
,
"p"
:
self
.
p
,
"num_samples"
:
self
.
num_samples
,
"n"
:
self
.
n
,
"seq_groups"
:
self
.
seq_groups
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
"virtual_engine"
:
self
.
virtual_engine
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
"ModelInputForTPU"
],
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForTPU"
:
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
TPUModelRunner
(
ModelRunnerBase
[
ModelInputForTPU
]):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
is_driver_worker
:
bool
=
False
,
):
ModelRunnerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
self
.
is_driver_worker
=
is_driver_worker
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
max_num_blocks_per_seq
=
(
self
.
model_config
.
max_model_len
//
self
.
block_size
)
self
.
block_tables
=
np
.
zeros
(
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
max_num_blocks_per_seq
),
dtype
=
np
.
int32
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
False
,
)
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
smem_size
=
512
*
1024
block_table_size
=
4
*
self
.
block_tables
.
size
if
block_table_size
>=
smem_size
:
logger
.
warning
(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value, like %d."
,
self
.
model_config
.
max_model_len
,
self
.
model_config
.
max_model_len
/
(
block_table_size
/
smem_size
))
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank
=
xr
.
global_ordinal
()
with
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
.
model
def
_dummy_run
(
self
,
batch_size
:
int
,
seq_len
:
int
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
exec_mode
:
ExecutionMode
,
)
->
None
:
exec_mode
=
ExecutionMode
(
exec_mode
)
if
exec_mode
.
is_prefill
():
seq_len
=
(
seq_len
+
15
)
//
16
*
16
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
input_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
if
exec_mode
==
ExecutionMode
.
PREFILL
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
block_tables
=
None
,
context_lens
=
None
,
effective_query_lens
=
None
,
)
else
:
context_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
effective_query_lens
=
torch
.
ones_like
(
context_lens
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
effective_query_lens
=
effective_query_lens
,
)
else
:
assert
seq_len
==
1
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
block_tables
=
torch
.
zeros
(
(
batch_size
,
self
.
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
context_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
*
seq_len
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
)
t
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
num_samples
=
_MAX_NUM_SAMPLES
if
exec_mode
.
is_prefill
()
else
1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if
exec_mode
.
is_prefill
():
# Prefll
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
1
)
else
:
# Decode
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
input_lens
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
context_lens
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
block_tables
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
# Dummy run.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
self
.
model
(
token_ids
,
position_ids
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
)
def
warmup_model
(
self
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
None
:
# Prefill
logger
.
info
(
"Compiling the model with different input shapes..."
)
start
=
time
.
time
()
for
batch_size
in
[
1
]:
seq_len
=
16
while
seq_len
<=
self
.
model_config
.
max_model_len
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
PREFILL
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
num_tokens
=
batch_size
*
seq_len
if
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
:
break
seq_len
=
seq_len
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for prefill done in %.2f s."
,
end
-
start
)
# Prefix prefill
if
self
.
cache_config
.
enable_prefix_caching
:
logger
.
info
(
"Compiling the model with different input shapes for "
"prefix prefill..."
)
start
=
time
.
time
()
for
batch_size
in
[
1
]:
seq_len
=
16
while
seq_len
<=
self
.
model_config
.
max_model_len
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
PREFIX_PREFILL
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
num_tokens
=
batch_size
*
seq_len
if
(
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
):
break
seq_len
=
seq_len
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for prefix prefill done in %.2f s."
,
end
-
start
)
# Decode
start
=
time
.
time
()
seq_len
=
1
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
DECODE
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
if
batch_size
>=
self
.
scheduler_config
.
max_num_seqs
:
break
batch_size
=
batch_size
+
16
if
batch_size
>=
16
else
batch_size
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for decode done in %.2f s."
,
end
-
start
)
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
for
batch_idx
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
# Could include output tokens when a request is preempted.
prompt_tokens
=
seq_data
.
get_token_ids
()
seq_len
=
len
(
prompt_tokens
)
num_computed_blocks
=
len
(
seq_group_metadata
.
computed_block_nums
)
num_computed_tokens
=
num_computed_blocks
*
self
.
block_size
if
num_computed_tokens
>
0
:
prompt_tokens
=
prompt_tokens
[
num_computed_tokens
:]
context_lens
.
append
(
seq_len
)
else
:
context_lens
.
append
(
0
)
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
extend
(
prompt_tokens
)
input_positions
.
extend
(
range
(
num_computed_tokens
,
seq_len
))
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
for
i
in
range
(
num_computed_tokens
,
seq_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
num_computed_tokens
>
0
:
self
.
block_tables
[
batch_idx
,
:
len
(
block_table
)]
=
block_table
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len
=
_get_padded_prefill_len
(
prompt_len
)
num_paddings
=
padded_prompt_len
-
prompt_len
input_tokens
+=
[
0
]
*
num_paddings
input_positions
+=
[
0
]
*
num_paddings
slot_mapping
+=
[
_PAD_SLOT_ID
]
*
num_paddings
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
num_prefills
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
effective_query_lens
=
prompt_lens
,
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
batch_idx
=
0
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
([
position
])
context_lens
.
append
(
seq_len
)
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
self
.
block_tables
[
batch_idx
,
:
len
(
block_table
)]
=
block_table
batch_idx
+=
1
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
batch_size
=
_get_padded_batch_size
(
batch_idx
)
num_paddings
=
batch_size
-
batch_idx
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
input_positions
=
input_positions
+
[[
0
]]
*
num_paddings
slot_mapping
=
slot_mapping
+
[[
_PAD_SLOT_ID
]]
*
num_paddings
context_lens
=
context_lens
+
[
0
]
*
num_paddings
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
)
return
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
padded_batch_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
int
]]:
assert
len
(
seq_group_metadata_list
)
>
0
t
=
[]
p
=
[]
n
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampling_params
=
seq_group_metadata
.
sampling_params
t
.
append
(
sampling_params
.
temperature
)
if
sampling_params
.
top_p
!=
1
and
not
_ENABLE_TOP_P
:
raise
NotImplementedError
(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues."
)
p
.
append
(
sampling_params
.
top_p
)
if
sampling_params
.
top_k
>
0
:
raise
NotImplementedError
(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues."
)
if
sampling_params
.
n
>
_MAX_NUM_SAMPLES
:
raise
NotImplementedError
(
f
"Best of >
{
_MAX_NUM_SAMPLES
}
is not supported by the TPU "
"backend."
)
n
.
append
(
sampling_params
.
n
)
if
sampling_params
.
logprobs
is
not
None
:
raise
NotImplementedError
(
"logprobs is not currently supported by the TPU backend."
)
if
sampling_params
.
prompt_logprobs
is
not
None
:
raise
NotImplementedError
(
"prompt_logprobs is not currently supported by the TPU "
"backend."
)
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs
=
len
(
seq_group_metadata
.
seq_data
)
t
+=
[
t
[
-
1
]]
*
(
num_seqs
-
1
)
p
+=
[
p
[
-
1
]]
*
(
num_seqs
-
1
)
n
+=
[
n
[
-
1
]]
*
(
num_seqs
-
1
)
num_paddings
=
padded_batch_size
-
len
(
t
)
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
t
,
p
,
n
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
ModelInputForTPU
:
del
finished_requests_ids
# Unused.
assert
virtual_engine
==
0
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
if
is_prompt
:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
=
inputs
padded_batch_size
=
input_tokens
.
shape
[
0
]
t
,
p
,
n
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
seq_groups
=
[
list
(
metadata
.
seq_data
.
keys
())
for
metadata
in
seq_group_metadata_list
]
return
ModelInputForTPU
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
n
,
seq_groups
)
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForTPU
:
model_input
=
ModelInputForTPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
)
return
model_input
@
torch
.
no_grad
()
def
execute_model
(
self
,
model_input
:
ModelInputForTPU
,
kv_caches
:
Optional
[
List
[
Any
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
assert
intermediate_tensors
is
None
if
not
model_input
.
is_first_multi_step
:
if
not
model_input
.
is_last_step
:
return
[]
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
_make_decode_output
(
next_token_ids
,
model_input
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
,
is_first_step_output
=
i
==
0
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
return
sampler_outputs
is_prompt
=
model_input
.
attn_metadata
.
num_prefills
>
0
if
is_prompt
:
assert
num_steps
==
1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
orig_block_tables
=
model_input
.
attn_metadata
.
block_tables
orig_context_lens
=
model_input
.
attn_metadata
.
context_lens
orig_effective_query_lens
=
\
model_input
.
attn_metadata
.
effective_query_lens
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
next_token_ids
=
[]
for
i
in
range
(
batch_size
):
# Get the actual prefill_len.
prefill_len
=
model_input
.
input_lens
[
i
:
i
+
1
].
item
()
prefill_len
=
_get_padded_prefill_len
(
prefill_len
)
end_idx
=
start_idx
+
prefill_len
token_ids
=
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
position_ids
=
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
num_prefills
=
1
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
if
orig_context_lens
[
i
].
item
()
>
0
:
attn_metadata
.
context_lens
=
orig_context_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
attn_metadata
.
block_tables
=
orig_block_tables
[
i
].
unsqueeze
(
0
).
to
(
self
.
device
)
attn_metadata
.
effective_query_lens
=
\
orig_effective_query_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
else
:
attn_metadata
.
context_lens
=
None
attn_metadata
.
block_tables
=
None
attn_metadata
.
effective_query_lens
=
None
input_lens
=
model_input
.
input_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
t
=
model_input
.
t
[
i
:
i
+
1
].
to
(
self
.
device
)
p
=
model_input
.
p
[
i
:
i
+
1
].
to
(
self
.
device
)
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
)
next_token_ids
.
append
(
output_token_ids
[
0
])
start_idx
=
end_idx
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
next_token_ids
=
[
output_token_ids
.
cpu
().
tolist
()
for
output_token_ids
in
next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob
=
Logprob
(
0.0
)
sampler_outputs
=
[]
for
i
,
seq_group
in
enumerate
(
model_input
.
seq_groups
):
seq_ids
=
seq_group
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_outputs
=
[]
for
j
in
range
(
model_input
.
n
[
i
]):
next_token_id
=
next_token_ids
[
i
][
j
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
[
SamplerOutput
(
sampler_outputs
)]
else
:
token_ids
=
model_input
.
token_ids
.
to
(
self
.
device
)
position_ids
=
model_input
.
position_ids
.
to
(
self
.
device
)
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
slot_mapping
=
attn_metadata
.
slot_mapping
.
to
(
self
.
device
)
attn_metadata
.
block_tables
=
attn_metadata
.
block_tables
.
to
(
self
.
device
)
attn_metadata
.
context_lens
=
attn_metadata
.
context_lens
.
to
(
self
.
device
)
t
=
model_input
.
t
.
to
(
self
.
device
)
p
=
model_input
.
p
.
to
(
self
.
device
)
input_lens
=
model_input
.
input_lens
.
to
(
self
.
device
)
for
i
in
range
(
num_steps
):
slot_mapping
=
attn_metadata
.
slot_mapping
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
)
self
.
cached_step_outputs
.
append
(
output_token_ids
)
if
i
<
num_steps
-
1
:
# Prepare the inputs for the next step.
token_ids
=
output_token_ids
.
unsqueeze
(
dim
=
1
).
int
()
position_ids
=
position_ids
+
1
attn_metadata
.
context_lens
=
attn_metadata
.
context_lens
+
1
block_tables
=
attn_metadata
.
block_tables
block_number
=
block_tables
.
gather
(
1
,
position_ids
.
long
()
//
self
.
block_size
)
block_offset
=
position_ids
%
self
.
block_size
is_padding
=
slot_mapping
==
_PAD_SLOT_ID
slot_mapping
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
=
slot_mapping
.
long
()
slot_mapping
=
torch
.
where
(
is_padding
,
_PAD_SLOT_ID
,
slot_mapping
)
attn_metadata
.
slot_mapping
=
slot_mapping
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
if
num_steps
>
1
:
return
[]
# Retrieve the outputs to CPU.
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
_make_decode_output
(
next_token_ids
,
model_input
.
seq_groups
)
return
[
sampler_output
]
class
ModelWrapper
(
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
def
forward
(
self
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
input_lens
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size
,
seq_len
=
token_ids
.
shape
# Calculate the positions to sample from.
start_indices
=
torch
.
arange
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
logits_indices
=
start_indices
+
input_lens
-
1
attn_metadata
=
get_forward_context
().
attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
[],
selected_token_indices
=
logits_indices
,
categorized_sample_indices
=
{},
num_prompts
=
attn_metadata
.
num_prefills
,
)
# Skip this in memory profiling at initialization.
if
kv_caches
[
0
][
0
].
numel
()
>
0
:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads
,
num_blocks
,
block_size
,
_
=
kv_caches
[
0
][
0
].
shape
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
slot_mapping
.
flatten
()
head_indices
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
slot_mapping
.
device
,
dtype
=
slot_mapping
.
dtype
)
head_indices
*=
block_size
*
num_blocks
slot_mapping
=
slot_mapping
.
repeat_interleave
(
num_kv_heads
).
view
(
-
1
,
num_kv_heads
)
slot_mapping
=
slot_mapping
+
head_indices
.
view
(
1
,
-
1
)
slot_mapping
=
slot_mapping
.
flatten
()
attn_metadata
.
slot_mapping
=
slot_mapping
hidden_states
=
self
.
model
(
token_ids
,
position_ids
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Argmax sampling.
argmax_token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
argmax_token_ids
=
argmax_token_ids
.
repeat
(
1
,
num_samples
)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t
=
torch
.
where
(
t
!=
0
,
t
,
1.0
)
logits
=
logits
/
nonzero_t
.
unsqueeze
(
dim
=
1
)
if
_ENABLE_TOP_P
:
logits
=
_apply_top_p
(
logits
,
p
.
unsqueeze
(
dim
=
1
))
# Random sampling.
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
sampled_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
,
replacement
=
True
)
if
num_samples
==
1
:
argmax_token_ids
=
argmax_token_ids
.
squeeze
(
dim
=-
1
)
sampled_token_ids
=
sampled_token_ids
.
squeeze
(
dim
=-
1
)
next_token_ids
=
torch
.
where
(
t
!=
0
,
sampled_token_ids
,
argmax_token_ids
)
return
next_token_ids
def
_get_padded_prefill_len
(
x
:
int
)
->
int
:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if
x
<=
16
:
return
16
return
1
<<
(
x
-
1
).
bit_length
()
def
_get_padded_batch_size
(
batch_size
:
int
)
->
int
:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if
batch_size
<=
8
:
return
8
else
:
return
((
batch_size
+
15
)
//
16
)
*
16
def
_apply_top_p
(
logits
:
torch
.
Tensor
,
p
:
torch
.
Tensor
)
->
torch
.
Tensor
:
logits_sorted
=
torch
.
sort
(
logits
,
dim
=-
1
,
descending
=
True
).
values
sorted_cum_probs
=
torch
.
cumsum
(
logits_sorted
.
softmax
(
dim
=-
1
),
dim
=-
1
)
cutoff_index
=
torch
.
sum
(
sorted_cum_probs
<
p
,
dim
=-
1
,
keepdim
=
True
)
cutoff_logit
=
torch
.
gather
(
logits_sorted
,
-
1
,
cutoff_index
)
logits
=
logits
.
masked_fill_
(
logits
<
cutoff_logit
,
-
float
(
"inf"
))
return
logits
def
_make_decode_output
(
next_token_ids
:
List
[
int
],
seq_groups
:
List
[
List
[
int
]],
)
->
SamplerOutput
:
zero_logprob
=
Logprob
(
0.0
)
sampler_outputs
=
[]
batch_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
seq_outputs
=
[]
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
batch_idx
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
SamplerOutput
(
sampler_outputs
)
vllm/worker/tpu_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.debug.profiler
as
xp
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
ExecutionMode
,
TPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoRANotSupportedWorkerBase
,
WorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoRANotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
,
)
->
None
:
WorkerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
self
.
parallel_config
.
rank
=
rank
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
assert
self
.
device_config
.
device_type
==
"tpu"
if
self
.
cache_config
.
cache_dtype
==
"auto"
:
self
.
cache_dtype
=
self
.
model_config
.
dtype
else
:
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
:
TPUModelRunner
=
TPUModelRunner
(
vllm_config
=
vllm_config
,
is_driver_worker
=
is_driver_worker
)
if
self
.
model_config
.
seed
is
None
:
self
.
model_config
.
seed
=
0
if
vllm_config
.
lora_config
is
not
None
:
raise
NotImplementedError
(
"The V0 TPU backend doesn't support LoRA serving"
)
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment
(
world_size
=
self
.
parallel_config
.
world_size
,
rank
=
self
.
rank
,
local_rank
=
self
.
local_rank
,
distributed_init_method
=
self
.
distributed_init_method
,
backend
=
"gloo"
,
)
ensure_model_parallel_initialized
(
self
.
parallel_config
.
tensor_parallel_size
,
self
.
parallel_config
.
pipeline_parallel_size
)
# Device initialization should happen after initializing the distributed
# runtime.
self
.
device
=
xm
.
xla_device
()
self
.
device_config
.
device
=
self
.
device
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
xm
.
set_rng_state
(
self
.
model_config
.
seed
,
self
.
device
)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch
.
_dynamo
.
config
.
cache_size_limit
=
128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size
=
self
.
parallel_config
.
world_size
rank
=
xr
.
global_ordinal
()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if
envs
.
VLLM_XLA_CACHE_PATH
:
per_rank_path
=
os
.
path
.
join
(
envs
.
VLLM_XLA_CACHE_PATH
,
f
"tp
{
world_size
}
_rank
{
rank
}
"
)
xr
.
initialize_cache
(
per_rank_path
,
readonly
=
False
)
self
.
profiler
=
None
if
envs
.
VLLM_TORCH_PROFILER_DIR
and
self
.
rank
<
1
:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self
.
profile_dir
=
envs
.
VLLM_TORCH_PROFILER_DIR
logger
.
info
(
"Profiling enabled. Traces will be saved to: %s"
,
self
.
profile_dir
)
self
.
profiler
=
xp
.
start_server
(
9012
)
def
start_profile
(
self
):
if
self
.
rank
<
1
:
if
self
.
profiler
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
xp
.
start_trace
(
self
.
profile_dir
)
def
stop_profile
(
self
):
if
self
.
rank
<
1
:
if
self
.
profiler
is
None
:
raise
RuntimeError
(
"Profiler is not enabled."
)
xp
.
stop_trace
()
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches
=
[(
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
))
for
_
in
range
(
num_layers
)]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[
kv_caches
])
self
.
model_runner
.
_dummy_run
(
batch_size
=
1
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
kv_caches
=
kv_caches
,
exec_mode
=
ExecutionMode
.
PREFILL
,
)
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
total_memory_size
=
m
[
"bytes_limit"
]
profiled
=
m
[
"peak_bytes_used"
]
# Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size
=
int
(
total_memory_size
*
self
.
cache_config
.
gpu_memory_utilization
)
tpu_kv_cache_bytes
=
max
(
usable_memory_size
-
profiled
,
0
)
dtype_bytes
=
get_dtype_size
(
self
.
cache_dtype
)
block_size_bytes
=
(
dtype_bytes
*
self
.
cache_config
.
block_size
*
num_layers
*
2
*
head_size
*
num_kv_heads
)
num_tpu_blocks
=
tpu_kv_cache_bytes
//
block_size_bytes
num_tpu_blocks
=
(
num_tpu_blocks
//
8
)
*
8
# Round down to 8.
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks
=
int
(
self
.
cache_config
.
swap_space_bytes
//
block_size_bytes
)
num_cpu_blocks
=
(
num_cpu_blocks
//
8
)
*
8
# Round down to 8.
return
num_tpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
,
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
block_size
=
self
.
cache_config
.
block_size
dtype
=
self
.
cache_dtype
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
self
.
cpu_cache
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
self
.
tpu_cache
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
tpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_gpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
cpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_cpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
for
_
in
range
(
num_layers
):
tpu_k_cache
=
torch
.
zeros
(
tpu_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
self
.
tpu_cache
.
append
((
tpu_k_cache
,
tpu_v_cache
))
cpu_k_cache
=
torch
.
zeros
(
cpu_cache_shape
,
dtype
=
dtype
,
device
=
"cpu"
)
cpu_v_cache
=
torch
.
zeros_like
(
cpu_k_cache
)
self
.
cpu_cache
.
append
((
cpu_k_cache
,
cpu_v_cache
))
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
[
self
.
tpu_cache
])
self
.
_warmup_model
()
def
_warmup_model
(
self
)
->
None
:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if
not
self
.
model_config
.
enforce_eager
:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self
.
model_runner
.
warmup_model
(
self
.
tpu_cache
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
head_size
=
self
.
model_config
.
get_head_size
()
num_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
key_cache_block
=
self
.
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
return
self
.
parallel_config
.
tensor_parallel_size
>
1
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return
[
self
.
tpu_cache
]
@
property
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
return
None
def
prepare_worker_input
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
WorkerInput
:
virtual_engine
=
execute_model_req
.
virtual_engine
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
blocks_to_swap_in
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
blocks_to_swap_out
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
blocks_to_copy
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_copy
,
self
.
device
,
self
.
device
)
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
def
execute_worker
(
self
,
worker_input
:
WorkerInput
)
->
None
:
virtual_engine
=
worker_input
.
virtual_engine
assert
virtual_engine
==
0
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
# Issue cache operations.
if
worker_input
.
blocks_to_swap_in
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_in
if
src_indices
.
numel
()
>
0
:
# Swap from CPU to TPU.
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
worker_input
.
blocks_to_swap_out
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_out
if
src_indices
.
numel
()
>
0
:
# Swap from TPU to CPU.
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
]
if
worker_input
.
blocks_to_copy
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_copy
if
src_indices
.
numel
()
>
0
:
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
(
src_indices
,
dst_indices
))
def
_make_src_to_dst
(
mapping
:
List
[
Tuple
[
int
,
int
]],
src_device
:
Union
[
torch
.
device
,
str
],
dst_device
:
Union
[
torch
.
device
,
str
],
)
->
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
not
mapping
:
return
None
src_indices
=
[
i
for
i
,
_
in
mapping
]
dst_indices
=
[
i
for
_
,
i
in
mapping
]
src_indices
=
torch
.
tensor
(
src_indices
,
device
=
src_device
,
dtype
=
torch
.
int64
)
dst_indices
=
torch
.
tensor
(
dst_indices
,
device
=
dst_device
,
dtype
=
torch
.
int64
)
return
src_indices
,
dst_indices
@
torch
.
compile
(
backend
=
"openxla"
)
def
_insert_kv
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
tpu_k_cache
:
torch
.
Tensor
,
tpu_v_cache
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
tpu_k_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
tpu_v_cache
,
True
)
tpu_k_cache
[:,
indices
]
=
k
tpu_v_cache
[:,
indices
]
=
v
vllm/worker/xpu_model_runner.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
time
import
weakref
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
)
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadataCache
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
,
MultiModalPlaceholderMap
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
DeviceMemoryProfiler
,
GiB_bytes
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
AttentionMetadata
,
SamplingMetadata
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
TModelInputForXPU
=
TypeVar
(
'TModelInputForXPU'
,
bound
=
"ModelInputForXPU"
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForXPU
(
ModelRunnerInputBase
):
"""
Used by the NeuronModelRunner.
"""
input_tokens
:
Optional
[
torch
.
Tensor
]
=
None
input_positions
:
Optional
[
torch
.
Tensor
]
=
None
attn_metadata
:
Optional
[
"AttentionMetadata"
]
=
None
multi_modal_kwargs
:
Optional
[
BatchedTensorInputs
]
=
None
virtual_engine
:
Optional
[
int
]
=
None
seq_lens
:
Optional
[
List
[
int
]]
=
None
query_lens
:
Optional
[
List
[
int
]]
=
None
async_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
TModelInputForXPU
],
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
TModelInputForXPU
:
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForXPUWithSamplingMetadata
(
ModelInputForXPU
):
"""
Used by the ModelRunner.
"""
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
self
.
sampling_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
,
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForXPUWithSamplingMetadata"
:
tensor_dict
=
_init_sampling_metadata_from_tensor_dict
(
tensor_dict
)
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
ModelInputForXPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForXPU
]):
def
__init__
(
self
,
runner
:
"XPUModelRunner"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
super
().
__init__
()
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
sliding_window
=
self
.
runner
.
sliding_window
self
.
block_size
=
self
.
runner
.
block_size
self
.
device
=
self
.
runner
.
device
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
self
.
seq_group_metadata_list
.
append
(
seq_group_metadata
)
def
build
(
self
)
->
ModelInputForXPU
:
is_prompt
=
self
.
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
=
self
.
_prepare_prompt
(
self
.
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
self
.
seq_group_metadata_list
)
seq_lens
=
None
multi_modal_kwargs
=
None
return
self
.
model_input_cls
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
,
seq_lens
=
seq_lens
,
query_lens
=
seq_lens
,
)
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
BatchedTensorInputs
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
multi_modal_kwargs_list
:
List
[
MultiModalKwargs
]
=
[]
multi_modal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
len
(
prompt_tokens
)
seq_lens
.
append
(
seq_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
positions_range
=
range
(
computed_len
,
seq_len
)
input_positions
.
extend
(
list
(
positions_range
))
if
seq_group_metadata
.
multi_modal_data
:
# NOTE: mm_kwargs only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_kwargs
,
placeholder_maps
=
MultiModalPlaceholderMap
\
.
from_seq_group
(
seq_group_metadata
,
positions_range
)
multi_modal_kwargs_list
.
append
(
mm_kwargs
)
for
modality
,
placeholder_map
in
placeholder_maps
.
items
():
multi_modal_placeholder_maps
[
modality
].
extend
(
placeholder_map
)
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
seq_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
seq_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
# type: ignore
block_offset
=
i
%
self
.
block_size
# type: ignore
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
num_prompt_tokens
=
len
(
input_tokens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
multi_modal_placeholder_maps
.
items
()
}
max_seqlen
=
max
(
seq_lens
)
tmp
=
[
0
]
tmp
.
extend
(
seq_lens
)
seqlen
=
torch
.
tensor
(
tmp
)
seqlen_q
=
torch
.
cumsum
(
seqlen
,
dim
=
0
).
to
(
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
False
,
seq_lens
=
seq_lens
,
seqlen_q
=
seqlen_q
,
max_seqlen
=
max_seqlen
,
seq_lens_tensor
=
torch
.
tensor
([]),
max_decode_seq_len
=
0
,
num_prefills
=
len
(
seq_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
block_tables
=
torch
.
tensor
([],
device
=
self
.
device
,
dtype
=
torch
.
int
),
)
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
multi_modal_kwargs_list
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
(
generation_token
)
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
(
position
)
seq_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
max_decode_seq_len
=
max
(
seq_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
seq_lens
=
seq_lens
,
seqlen_q
=
torch
.
tensor
([]),
max_seqlen
=
0
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_seq_len
=
max_decode_seq_len
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
num_prefills
=
0
,
block_tables
=
block_tables
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
)
class
XPUModelRunner
(
ModelRunnerBase
[
ModelInputForXPUWithSamplingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForXPUWithSamplingMetadata
]
=
(
ModelInputForXPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForXPUBuilder
]
=
ModelInputForXPUBuilder
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
ModelRunnerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
model_config
=
self
.
model_config
cache_config
=
self
.
cache_config
self
.
is_driver_worker
=
is_driver_worker
self
.
return_hidden_states
=
return_hidden_states
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
)
# Multi-modal data support
self
.
input_registry
=
input_registry
self
.
mm_registry
=
mm_registry
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
sampler
=
get_sampler
()
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
\
if
self
.
parallel_config
.
pipeline_parallel_size
==
1
else
None
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
with
DeviceMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Loading model weights took %.4f GiB"
,
self
.
model_memory_usage
/
GiB_bytes
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens
=
self
.
mm_registry
.
get_max_multimodal_tokens
(
self
.
model_config
)
if
max_mm_tokens
>
0
:
max_num_seqs_orig
=
max_num_seqs
max_num_seqs
=
min
(
max_num_seqs
,
max_num_batched_tokens
//
max_mm_tokens
)
if
max_num_seqs
<
1
:
expr
=
(
f
"min(
{
max_num_seqs_orig
}
, "
f
"
{
max_num_batched_tokens
}
//
{
max_mm_tokens
}
)"
)
logger
.
warning
(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1."
,
expr
)
max_num_seqs
=
1
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
dummy_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
self
.
model_config
,
seq_len
,
self
.
mm_registry
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
dummy_data
.
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
lora_request
=
None
,
multi_modal_data
=
dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
dummy_data
.
multi_modal_placeholders
)
seqs
.
append
(
seq
)
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
self
.
execute_model
(
model_input
,
None
,
intermediate_tensors
)
torch
.
xpu
.
synchronize
()
return
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForXPUWithSamplingMetadata
:
return
(
ModelInputForXPUWithSamplingMetadata
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
))
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForXPUWithSamplingMetadata
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder
=
self
.
builder
builder
.
prepare
(
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
builder
.
add_seq_group
(
seq_group_metadata
)
return
builder
.
build
()
# type: ignore
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForXPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
generators
,
cache
=
self
.
sampling_metadata_cache
)
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
virtual_engine
=
virtual_engine
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
model_input
:
ModelInputForXPUWithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"XPUModelRunner does not support multi-step execution."
)
model_executable
=
self
.
model
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start_time
=
time
.
time
()
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
,
),
)
# Compute the logits in the last pipeline stage.
if
not
get_pp_group
().
is_last_rank
:
return
hidden_or_intermediate_states
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_end_time
=
time
.
time
()
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_or_intermediate_states
,
model_input
.
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
output
:
SamplerOutput
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
and
output
is
not
None
):
model_forward_time
=
(
model_forward_end_time
-
model_forward_start_time
)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output
.
model_forward_time
=
model_forward_time
return
[
output
]
vllm/worker/xpu_worker.py
deleted
100644 → 0
View file @
5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A XPU worker class."""
import
gc
import
os
from
typing
import
List
,
Optional
,
Tuple
import
intel_extension_for_pytorch
# noqa: F401
import
oneccl_bindings_for_pytorch
# noqa: F401
import
torch
import
torch.distributed
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.xpu_model_runner
import
XPUModelRunner
logger
=
init_logger
(
__name__
)
class
XPUWorker
(
LoRANotSupportedWorkerBase
,
Worker
):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
is_driver_worker
:
bool
=
False
,
)
->
None
:
WorkerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
device_config
=
self
.
device_config
parallel_config
=
self
.
parallel_config
assert
device_config
.
device_type
==
"xpu"
assert
current_platform
.
is_xpu
()
self
.
parallel_config
.
rank
=
rank
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
if
parallel_config
and
is_driver_worker
:
assert
rank
%
parallel_config
.
tensor_parallel_size
==
0
,
\
"Driver worker should be rank 0 of tensor parallel group."
self
.
model_runner
=
XPUModelRunner
(
# type: ignore
vllm_config
=
vllm_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
is_driver_worker
,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self
.
cache_engine
:
List
[
CacheEngine
]
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
Tensor
]]]
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"xpu"
and
current_platform
.
is_xpu
(
):
self
.
device
=
torch
.
device
(
f
"xpu:
{
self
.
local_rank
}
"
)
torch
.
xpu
.
set_device
(
self
.
device
)
torch
.
xpu
.
empty_cache
()
self
.
init_gpu_memory
=
torch
.
xpu
.
get_device_properties
(
self
.
local_rank
).
total_memory
else
:
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Initialize the distributed environment.
self
.
init_worker_distributed_environment
()
# Initialize the model.
set_random_seed
(
self
.
model_config
.
seed
)
# keep this method for `empty_cache` and `synchronize` api
@
torch
.
inference_mode
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch
.
xpu
.
empty_cache
()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self
.
model_runner
.
profile_run
()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch
.
xpu
.
synchronize
()
used_memory
=
torch
.
xpu
.
memory_allocated
()
total_gpu_memory
=
torch
.
xpu
.
get_device_properties
(
self
.
local_rank
).
total_memory
free_gpu_memory
=
total_gpu_memory
-
used_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory
=
self
.
init_gpu_memory
-
free_gpu_memory
assert
peak_memory
>
0
,
(
"Error in memory profiling. "
f
"Initial free memory
{
self
.
init_gpu_memory
}
, current free memory"
f
"
{
free_gpu_memory
}
. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
cache_block_size
=
self
.
get_cache_block_size_bytes
()
num_gpu_blocks
=
int
(
(
total_gpu_memory
*
self
.
cache_config
.
gpu_memory_utilization
-
peak_memory
)
//
cache_block_size
)
num_cpu_blocks
=
int
(
self
.
cache_config
.
swap_space_bytes
//
cache_block_size
)
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
gc
.
collect
()
torch
.
xpu
.
empty_cache
()
return
num_gpu_blocks
,
num_cpu_blocks
def
_warm_up_model
(
self
)
->
None
:
# IPEX don't support capture graph yet
pass
def
init_worker_distributed_environment
(
self
)
->
None
:
"""Initialize the distributed environment."""
parallel_config
=
self
.
parallel_config
rank
=
self
.
rank
distributed_init_method
=
self
.
distributed_init_method
if
torch
.
distributed
.
is_initialized
():
torch_world_size
=
torch
.
distributed
.
get_world_size
()
if
torch_world_size
!=
parallel_config
.
world_size
:
raise
RuntimeError
(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f
"(
{
torch_world_size
}
vs.
{
parallel_config
.
world_size
}
)."
)
elif
not
distributed_init_method
:
raise
ValueError
(
"distributed_init_method must be set if torch.distributed "
"is not already initialized"
)
else
:
# use sockets as default Level zero IPC exchange backend. By
# default oneccl will use `drmfd` as mechanism which need extra
# dependency (libdrm and drm headers) on your system.
ENV_CCL_ATL_TRANSPORT
=
os
.
getenv
(
"CCL_ATL_TRANSPORT"
,
"ofi"
)
ENV_LOCAL_WORLD_SIZE
=
os
.
getenv
(
"LOCAL_WORLD_SIZE"
,
str
(
parallel_config
.
world_size
))
os
.
environ
[
"CCL_ATL_TRANSPORT"
]
=
ENV_CCL_ATL_TRANSPORT
os
.
environ
[
"LOCAL_WORLD_SIZE"
]
=
ENV_LOCAL_WORLD_SIZE
os
.
environ
[
"LOCAL_RANK"
]
=
str
(
self
.
local_rank
)
init_distributed_environment
(
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
self
.
local_rank
,
backend
=
"ccl"
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
# global all_reduce needed for overall oneccl warm up
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
xpu
())
if
parallel_config
.
pipeline_parallel_size
>
1
:
# Add pp group init to avoid
# p2p communication as the first call
get_pp_group
().
all_reduce
(
torch
.
zeros
(
1
).
xpu
())
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment