Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0dd7dcd
Unverified
Commit
a0dd7dcd
authored
Mar 25, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 25, 2025
Browse files
[TPU][V1] Fix Sampler recompilation (#15309)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
e977c111
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
127 deletions
+84
-127
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+71
-104
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+13
-23
No files found.
vllm/v1/sample/tpu/metadata.py
View file @
a0dd7dcd
...
@@ -5,7 +5,18 @@ from typing import Optional
...
@@ -5,7 +5,18 @@ from typing import Optional
import
torch
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
DEFAULT_SAMPLING_PARAMS
=
dict
(
temperature
=-
1.0
,
min_p
=
0.0
,
# strictly disabled for now
# top_k=-1,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
# repetition_penalties=0.0,
)
@
dataclass
@
dataclass
...
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
...
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k
:
torch
.
Tensor
=
None
top_k
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
top_p
:
torch
.
Tensor
=
None
# XLA-unfriendly control flow in Sampler
all_greedy
:
bool
=
False
all_random
:
bool
=
False
# Greedy sampling flag for compiling single xla graph.
# Greedy sampling flag for compiling single xla graph.
do_argmax
:
torch
.
Tensor
=
None
all_greedy
:
torch
.
Tensor
=
None
# speculation not supported
spec_token_ids
=
None
# Generator not supported by xla
# Generator not supported by xla
generators
:
dict
[
int
,
generators
:
dict
[
int
,
...
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
...
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids
=
None
bad_words_token_ids
=
None
indices_do_sample
:
torch
.
Tensor
=
None
indices_do_sample
:
torch
.
Tensor
=
None
def
__post_init__
(
self
):
temp
=
self
.
temperature
if
self
.
indices_do_sample
is
None
:
self
.
indices_do_sample
=
torch
.
zeros
(
temp
.
shape
[
0
],
device
=
temp
.
device
,
dtype
=
torch
.
int32
)
if
self
.
do_argmax
is
None
:
self
.
do_argmax
=
torch
.
tensor
(
0
,
dtype
=
torch
.
bool
,
device
=
temp
.
device
)
@
classmethod
@
classmethod
def
from_sampling_metadata
(
def
from_input_batch
(
cls
,
metadata
:
SamplingMetadata
,
cls
,
input_batch
:
InputBatch
,
padded_do_sample_indices
:
torch
.
Tensor
,
num_do_sample
:
int
,
indices_do_sample
:
torch
.
Tensor
)
->
"TPUSupportedSamplingMetadata"
:
device
:
torch
.
device
)
->
"TPUSupportedSamplingMetadata"
:
"""
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
Copy sampling tensors slices from `input_batch` to on device tensors.
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
recompilation is not triggered for optional values (None/torch.Tensor).
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
In order to handle different sizes for the params that range from 1 up
also reuses the on-device persistent tensors managed in `input_batch`
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
to reduce waste.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
We expect sampling params tensors to be padded to the same fixed shape.
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
"""
"""
metadata
=
cls
.
_validate_sampling_metadata
(
metadata
)
num_reqs
=
input_batch
.
num_reqs
# NOTE we have to initialize default tensor-based params first and
padded_num_reqs
=
len
(
indices_do_sample
)
# skip None values altogether to produce the same xla graph.
num_samples
=
len
(
padded_do_sample_indices
)
def
copy_slice
(
cpu_tensor
:
torch
.
Tensor
,
tpu_tensor
:
torch
.
Tensor
,
do_argmax
=
torch
.
tensor
(
metadata
.
all_greedy
,
fill_val
)
->
torch
.
Tensor
:
dtype
=
torch
.
bool
,
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
device
=
device
)
# Pad value is the default one.
new_metadata
=
cls
.
get_default_sampling_params
(
num_samples
,
device
,
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
indices_do_sample
=
\
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
padded_do_sample_indices
,
do_argmax
=
do_argmax
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
)
# consistent. We can't have flags to skip copies or we'll end up
supported_params
=
\
# recompiling.
TPUSupportedSamplingMetadata
.
_get_default_params_values
()
copy_slice
(
input_batch
.
temperature_cpu_tensor
,
input_batch
.
temperature
,
# Copy input non-None values into `new_metadata` fixed-sized tensors.
DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
for
p_name
in
supported_params
:
# TODO Temporarily disabled until sampling options are enabled
old_val
=
getattr
(
metadata
,
p_name
)
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
new_val
=
getattr
(
new_metadata
,
p_name
)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
if
isinstance
(
old_val
,
torch
.
Tensor
):
copy_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
new_val
[:
num_do_sample
]
=
old_val
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
setattr
(
new_metadata
,
p_name
,
new_val
)
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
# input_batch.frequency_penalties)
# copy_slice(input_batch.presence_penalties_cpu_tensor,
# input_batch.presence_penalties)
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
# input_batch.repetition_penalties)
xm
.
mark_step
()
xm
.
mark_step
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
return
new_metadata
@
classmethod
# Slice persistent device tensors to a fixed pre-compiled padded shape.
def
get_default_sampling_params
(
return
cls
(
cls
,
temperature
=
input_batch
.
temperature
[:
padded_num_reqs
],
num_samples
:
int
,
# Scalar tensor for xla-friendly tracing.
device
:
torch
.
device
,
all_greedy
=
torch
.
tensor
(
input_batch
.
all_greedy
,
indices_do_sample
=
None
,
dtype
=
torch
.
bool
,
do_argmax
=
None
)
->
"TPUSupportedSamplingMetadata"
:
device
=
input_batch
.
device
),
# As sampling happens on a single traced graph, options
# TODO enable more and avoid returning None values
# are "disabled" by having them evaluate to an Identity op.
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
# Note that initialization is dependent on num_samples.
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
sampling_metadata_disable_value
=
\
min_p
=
input_batch
.
min_p
[:
padded_num_reqs
],
TPUSupportedSamplingMetadata
.
_get_default_params_values
()
generators
=
input_batch
.
generators
,
init_kwargs
=
dict
()
indices_do_sample
=
indices_do_sample
)
for
p_name
,
(
default_val
,
dtype
)
in
sampling_metadata_disable_value
.
items
():
default_tensor
=
torch
.
full
((
num_samples
,
),
default_val
,
dtype
=
dtype
,
device
=
device
)
init_kwargs
[
p_name
]
=
default_tensor
return
cls
(
**
init_kwargs
,
indices_do_sample
=
indices_do_sample
,
do_argmax
=
do_argmax
)
@
staticmethod
def
_validate_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
)
->
SamplingMetadata
:
if
sampling_metadata
.
all_greedy
:
# Set to None since #13587. Make sure default isn't overruled.
assert
sampling_metadata
.
temperature
is
None
return
sampling_metadata
@
staticmethod
def
_get_default_params_values
():
return
dict
(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature
=
(
1.0
,
torch
.
float32
),
min_p
=
(
0.0
,
torch
.
float32
),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
\ No newline at end of file
vllm/v1/worker/tpu_model_runner.py
View file @
a0dd7dcd
...
@@ -279,9 +279,6 @@ class TPUModelRunner:
...
@@ -279,9 +279,6 @@ class TPUModelRunner:
req_data
.
num_computed_tokens
)
req_data
.
num_computed_tokens
)
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
req_index
)
req_index
)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed
=
len
(
removed_req_indices
)
>
0
or
len
(
req_ids_to_add
)
>
0
# Add the new or resumed requests to the persistent batch.
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
# The smaller empty indices are filled first.
...
@@ -300,9 +297,6 @@ class TPUModelRunner:
...
@@ -300,9 +297,6 @@ class TPUModelRunner:
if
removed_req_indices
:
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
self
.
input_batch
.
condense
(
removed_req_indices
)
# TODO This slices tensors to copy to device, triggering recompilation.
if
batch_changed
:
self
.
input_batch
.
refresh_sampling_metadata
()
return
len
(
unscheduled_req_ids
)
>
0
or
len
(
req_ids_to_add
)
>
0
return
len
(
unscheduled_req_ids
)
>
0
or
len
(
req_ids_to_add
)
>
0
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
...
@@ -597,14 +591,12 @@ class TPUModelRunner:
...
@@ -597,14 +591,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
input_ids
=
self
.
input_ids
inputs_embeds
=
None
inputs_embeds
=
None
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# mismatch in pre-processing, it will trigger a small recompilation
# are copied to device in chunks of pre-compiled padded shape to
# of the code thus far. Forward graph remains untouched.
# avoid recompilations.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
sampling_metadata
,
logits_indices
,
from_input_batch
(
self
.
input_batch
,
logits_indices
)
num_reqs
,
self
.
device
)
# Run the decoder
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
...
@@ -797,21 +789,19 @@ class TPUModelRunner:
...
@@ -797,21 +789,19 @@ class TPUModelRunner:
device
=
device
,
device
=
device
,
dtype
=
torch
.
bfloat16
)
dtype
=
torch
.
bfloat16
)
while
True
:
while
True
:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta
=
self
.
input_batch
.
sampling_metadata
indices
=
torch
.
zeros
(
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
xm
.
mark_step
()
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
meta
,
indices
,
from_input_batch
(
self
.
input_batch
,
indices
)
num_reqs_to_sample
,
device
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
num_reqs_to_sample
)
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
out
=
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
xm
.
mark_step
()
sampling_meta
)
out
=
out
.
cpu
()
if
num_reqs_to_sample
>=
self
.
max_num_reqs
:
if
num_reqs_to_sample
>=
self
.
max_num_reqs
:
break
break
num_reqs_to_sample
*=
2
num_reqs_to_sample
*=
2
...
@@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
...
@@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
return
hidden_states
return
hidden_states
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def
sample_from_hidden
(
def
sample_from_hidden
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
...
@@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states
=
\
sample_hidden_states
=
\
hidden_states
[
sampling_metadata
.
indices_do_sample
]
hidden_states
[
sampling_metadata
.
indices_do_sample
]
logits
=
self
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
compute_logits
(
sample_hidden_states
)
# Greedy sampling can't be run without branching the graph on Sampler.
# Optimized greedy sampling branch, tracing both paths in a single pass
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE all_greedy is a scalar, this is just an optimized if/else.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
all_greedy
,
out_tokens
=
torch
.
where
(
sampling_metadata
.
do_argmax
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
.
sampled_token_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment