Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d4170fad
Unverified
Commit
d4170fad
authored
Jul 14, 2025
by
XiongfeiWei
Committed by
GitHub
Jul 15, 2025
Browse files
Use w8a8 quantized matmul Pallas kernel (#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
parent
946aadb4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
19 deletions
+50
-19
requirements/tpu.txt
requirements/tpu.txt
+5
-5
tests/tpu/test_quantization_accuracy.py
tests/tpu/test_quantization_accuracy.py
+4
-4
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+32
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
...del_executor/layers/quantization/kernels/scaled_mm/xla.py
+9
-10
No files found.
requirements/tpu.txt
View file @
d4170fad
...
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.9.0.dev20250703
torchvision==0.24.0.dev20250703
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.9.0.dev20250711
torchvision==0.24.0.dev20250711
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
tests/tpu/test_quantization_accuracy.py
View file @
d4170fad
...
...
@@ -14,7 +14,7 @@ RTOL = 0.03
@
dataclass
class
GSM8KAccuracyTestConfig
:
model_name
:
str
ex
cep
ted_value
:
float
ex
pec
ted_value
:
float
def
get_model_args
(
self
)
->
str
:
return
(
f
"pretrained=
{
self
.
model_name
}
,"
...
...
@@ -25,13 +25,13 @@ class GSM8KAccuracyTestConfig:
ACCURACY_CONFIGS
=
[
GSM8KAccuracyTestConfig
(
model_name
=
"neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
,
ex
cep
ted_value
=
0.76
),
# no bias
ex
pec
ted_value
=
0.76
),
# no bias
# NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
# so only one of these tests can run in a single call to pytest. As
# a follow up, move this into the LM-EVAL section of the CI.
# GSM8KAccuracyTestConfig(
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
# excepted_value=0.66), # bias in QKV layers
# expected_value=0.66), # bias in QKV layers
]
...
...
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
batch_size
=
"auto"
,
)
EXPECTED_VALUE
=
config
.
ex
cep
ted_value
EXPECTED_VALUE
=
config
.
ex
pec
ted_value
measured_value
=
results
[
"results"
][
TASK
][
FILTER
]
assert
(
measured_value
-
RTOL
<
EXPECTED_VALUE
and
measured_value
+
RTOL
>
EXPECTED_VALUE
...
...
tests/v1/tpu/test_basic.py
View file @
d4170fad
...
...
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
for
output
,
answer
in
zip
(
vllm_outputs
,
answers
):
generated_text
=
output
[
1
]
assert
answer
in
generated_text
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
def test_w8a8_quantization(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Smoke-test greedy generation with a w8a8-quantized Llama checkpoint.

    The prompt lists the integers 0..1023 and asks for the next numbers.
    The test passes when the generated text contains either "1024" or
    "0, 1". It is skipped unless ``current_platform.is_tpu()`` is true.
    """
    checkpoint = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
    token_budget = 5

    # Build the long counting prompt: "…sequence 0, 1, 2, …, 1023 are:".
    counted = ", ".join(map(str, range(1024)))
    prompts = [f"The next numbers of the sequence {counted} are:"]

    runner_kwargs = dict(
        max_num_batched_tokens=64,
        max_model_len=4096,
        gpu_memory_utilization=0.7,
        max_num_seqs=4,
        tensor_parallel_size=1,
    )

    with monkeypatch.context() as patched_env:
        # The env var must be set before the runner is constructed.
        patched_env.setenv("VLLM_USE_V1", "1")

        with vllm_runner(checkpoint, **runner_kwargs) as llm:
            generations = llm.generate_greedy(prompts, token_budget)

        # Each result is indexed [request][1] to obtain the generated text.
        generated_text = generations[0][1]

        assert any(marker in generated_text for marker in ("1024", "0, 1"))
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
View file @
d4170fad
...
...
@@ -90,16 +90,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
w_q
,
w_s
,
_
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
import
torch_xla.experimental.xla_quantized_matmul
# noqa: F401
out
=
torch
.
ops
.
xla
.
quantized_matmul
(
x
,
# Required to register custom ops.
import
torch_xla.experimental.custom_kernel
# noqa: F401
out
=
torch
.
ops
.
xla
.
quantized_matmul_int8
(
x
,
w_q
,
w_s
,
zero_point
=
None
,
block_size
=-
1
,
int4_weight
=
False
,
quantize_activation
=
True
)
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
out
=
out
.
to
(
x
.
dtype
)
quantize_activation
=
True
,
)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return
cond
(
bias
is
None
,
self
.
no_add_bias
,
self
.
add_bias
,
[
out
,
bias
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment