Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da461f3c
Unverified
Commit
da461f3c
authored
Mar 29, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Mar 28, 2025
Browse files
[TPU][V1][Bugfix] Fix w8a8 recompiilation with GSM8K (#15714)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
5b800f09
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
15 deletions
+16
-15
.buildkite/run-tpu-v1-test.sh
.buildkite/run-tpu-v1-test.sh
+4
-6
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
...del_executor/layers/quantization/kernels/scaled_mm/xla.py
+2
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+8
-6
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+2
-2
No files found.
.buildkite/run-tpu-v1-test.sh
View file @
da461f3c
...
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
...
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_3
\
&& echo TEST_3
\
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& echo TEST_4
\
&& echo TEST_4
\
&& pyt
hon3
/workspace/vllm/
examples/offline_inference/tpu
.py
\
&& pyt
est -s -v
/workspace/vllm/
tests/tpu/test_quantization_accuracy
.py
\
&& echo TEST_5
\
&& echo TEST_5
\
&& pyt
est -s -v
/workspace/vllm/
tests/v1/tpu/worker/test_tpu_model_runner
.py
\
&& pyt
hon3
/workspace/vllm/
examples/offline_inference/tpu
.py
\
&& echo TEST_6
\
&& echo TEST_6
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py
\
&& echo TEST_7
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py"
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
# TODO: Re-enable this after fixing recompilation in quantization.
# && echo TEST_4 \
# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
View file @
da461f3c
...
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
block_size
=-
1
,
block_size
=-
1
,
int4_weight
=
False
,
int4_weight
=
False
,
quantize_activation
=
True
)
quantize_activation
=
True
)
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
out
=
out
.
to
(
x
.
dtype
)
# Explicitly capture control flow to make dynamo happy.
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return
cond
(
bias
is
None
,
self
.
no_add_bias
,
self
.
add_bias
,
[
out
,
bias
])
return
cond
(
bias
is
None
,
self
.
no_add_bias
,
self
.
add_bias
,
[
out
,
bias
])
vllm/v1/worker/tpu_model_runner.py
View file @
da461f3c
...
@@ -80,6 +80,7 @@ class TPUModelRunner:
...
@@ -80,6 +80,7 @@ class TPUModelRunner:
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
self
.
_hidden_states_dtype
=
self
.
dtype
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
sliding_window
=
model_config
.
get_sliding_window
()
...
@@ -771,10 +772,11 @@ class TPUModelRunner:
...
@@ -771,10 +772,11 @@ class TPUModelRunner:
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
self
.
model
(
input_ids
=
input_ids
,
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
self
.
_hidden_states_dtype
=
out
.
dtype
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
"""Compile the model."""
...
@@ -800,7 +802,7 @@ class TPUModelRunner:
...
@@ -800,7 +802,7 @@ class TPUModelRunner:
num_reqs_to_sample
=
MIN_NUM_SEQS
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
device
=
device
,
dtype
=
torch
.
bfloat16
)
dtype
=
self
.
_hidden_states_dtype
)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while
True
:
while
True
:
indices
=
torch
.
zeros
(
indices
=
torch
.
zeros
(
...
@@ -823,7 +825,7 @@ class TPUModelRunner:
...
@@ -823,7 +825,7 @@ class TPUModelRunner:
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
num_reqs_to_sample
+
1
,
self
.
max_num_reqs
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in
in
%.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
# Record the number cached XLA graph after warming up, this will be
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# used for checking there is no additional graph compilation during
# runtime execution.
# runtime execution.
...
...
vllm/v1/worker/tpu_worker.py
View file @
da461f3c
...
@@ -105,8 +105,8 @@ class TPUWorker:
...
@@ -105,8 +105,8 @@ class TPUWorker:
# Increase the cache size limit, which is the maximum number of
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# dynamo graphs that can be compiled.
#
NOTE(woosuk): Usually,
we compile
10-15
graphs
for prefill and
#
TODO (NickLucche) On gsm
we compile
80+
graphs
.
#
30-40 graphs for decode. 128 is an arbitrary safe number
.
#
Re-evaluate limit, with MM we may get close to this limit
.
torch
.
_dynamo
.
config
.
cache_size_limit
=
128
torch
.
_dynamo
.
config
.
cache_size_limit
=
128
# Use persistent cache to avoid XLA recompilation.
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# NOTE(woosuk): Set per-rank cache path since different ranks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment