Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4098b722
Unverified
Commit
4098b722
authored
Mar 27, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 27, 2025
Browse files
[Bugfix][TPU][V1] Fix recompilation (#15553)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
46450b8d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
74 deletions
+15
-74
.buildkite/run-tpu-v1-test.sh
.buildkite/run-tpu-v1-test.sh
+3
-1
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+5
-64
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+1
-7
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+6
-2
No files found.
.buildkite/run-tpu-v1-test.sh
View file @
4098b722
...
...
@@ -32,7 +32,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_5
\
&& python3 /workspace/vllm/examples/offline_inference/tpu.py
\
&& echo TEST_6
\
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py"
\
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py
\
&& echo TEST_7
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
...
...
tests/v1/tpu/test_sampler.py
View file @
4098b722
# SPDX-License-Identifier: Apache-2.0
import
tempfile
from
time
import
time
import
pytest
from
vllm
import
LLM
,
envs
...
...
@@ -15,60 +12,6 @@ if not envs.VLLM_USE_V1:
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"D4nt3/Qwen2.5-two-layers"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
def
test_sampler_compilation
(
model_name
:
str
,
monkeypatch
):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
monkeypatch
.
setenv
(
"VLLM_XLA_CACHE_PATH"
,
temp_dir
)
# Compiling model init may still take some time, enforce_eager to skip.
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
16
,
max_model_len
=
1024
,
gpu_memory_utilization
=
0.5
)
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
]
# First inference should be slow
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
# top_p=0.6, # TODO too slow!
top_k
=
10
,
min_p
=
0.2
,
max_tokens
=
16
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run1
=
time
()
-
s
# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
top_k
=
12
,
min_p
=
0.8
,
max_tokens
=
24
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run2
=
time
()
-
s
# Much faster after compiling
assert
run1
*
0.1
>
run2
print
(
"TIMES"
,
run1
,
run2
)
# Third request with min_p set to "None". It will not trigger
# recompilation as a default 0 value will be used.
sampling_params
=
SamplingParams
(
max_tokens
=
24
,
temperature
=
0.0
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run3
=
time
()
-
s
assert
run1
*
0.1
>
run3
print
(
"TIMES"
,
run1
,
run3
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Qwen/Qwen2.5-1.5B-Instruct"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
...
...
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
Test significantly different sampling params to assert the model produces
different results.
"""
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
1
,
max_model_len
=
64
,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization
=
0.5
)
llm
=
LLM
(
model_name
,
enforce_eager
=
False
,
max_num_seqs
=
1
,
max_model_len
=
512
,
max_num_batched_tokens
=
512
)
prompts
=
[
"Write a short story about a robot that dreams for the first time."
]
...
...
vllm/v1/sample/tpu/metadata.py
View file @
4098b722
...
...
@@ -88,6 +88,7 @@ class TPUSupportedSamplingMetadata:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
# Pad value is the default one.
cpu_tensor
[
num_reqs
:
padded_num_reqs
]
=
fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor
[:
padded_num_reqs
]
=
cpu_tensor
[:
padded_num_reqs
]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
...
...
@@ -101,13 +102,6 @@ class TPUSupportedSamplingMetadata:
copy_slice
(
input_batch
.
min_p_cpu_tensor
,
input_batch
.
min_p
,
DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
# input_batch.frequency_penalties)
# copy_slice(input_batch.presence_penalties_cpu_tensor,
# input_batch.presence_penalties)
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
# input_batch.repetition_penalties)
xm
.
mark_step
()
xm
.
wait_device_ops
()
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
4098b722
...
...
@@ -88,6 +88,8 @@ class TPUModelRunner:
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self
.
max_num_reqs
=
max
(
scheduler_config
.
max_num_seqs
,
MIN_NUM_SEQS
)
# Model-related.
...
...
@@ -788,6 +790,7 @@ class TPUModelRunner:
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while
True
:
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
...
...
@@ -804,7 +807,9 @@ class TPUModelRunner:
out
=
out
.
cpu
()
if
num_reqs_to_sample
>=
self
.
max_num_reqs
:
break
num_reqs_to_sample
*=
2
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
...
...
@@ -897,7 +902,6 @@ class ModelWrapperV1(nn.Module):
return
hidden_states
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def
sample_from_hidden
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment