Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a0c5ded
Unverified
Commit
9a0c5ded
authored
Aug 08, 2025
by
Kyuyeun Kim
Committed by
GitHub
Aug 08, 2025
Browse files
[TPU] Add support for online w8a8 quantization (#22425)
Signed-off-by:
Kyuyeun Kim
<
kyuyeunk@google.com
>
parent
10a02535
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
3 deletions
+82
-3
.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
+2
-0
tests/v1/tpu/test_tpu_int8.py
tests/v1/tpu/test_tpu_int8.py
+73
-0
vllm/model_executor/layers/quantization/tpu_int8.py
vllm/model_executor/layers/quantization/tpu_int8.py
+7
-3
No files found.
.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
View file @
9a0c5ded
...
@@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
...
@@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
run_and_track_test 7 "test_tpu_int8.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"
# After all tests have been attempted, exit with the overall status.
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
if [ "$overall_script_exit_code" -ne 0 ]; then
...
...
tests/v1/tpu/test_tpu_int8.py
0 → 100644
View file @
9a0c5ded
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.
Run `pytest tests/v1/tpu/test_tpu_int8.py`.
"""
import
pytest
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.tpu_int8
import
(
TPUInt8LinearMethod
)
from
vllm.platforms
import
current_platform
from
...models.registry
import
HF_EXAMPLE_MODELS
# Model checkpoints the TPU int8 test is parametrized over.
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    """Load a model with the ``tpu_int8`` override and sanity-check it.

    The test verifies two things:

    1. Every ``LinearBase`` module in the loaded model uses
       ``TPUInt8LinearMethod`` and its ``quantize_activation`` flag matches
       the ``activation_scheme`` requested through ``hf_overrides``.
    2. Greedy generation for a few fixed prompts contains the expected
       continuation text.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    hf_info.check_transformers_version(on_fail="skip")

    # The flag expected on each linear method is derived from the override.
    quant_cfg = hf_overrides.get('quantization_config', {})
    requested_scheme = quant_cfg.get('activation_scheme')
    expect_activation_quant = requested_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    # (prompt, substring expected in the greedy completion) pairs.
    cases = [
        ("A robot may not injure a human being",
         "or, being injured, not kill, except in"),
        ("It is only with the heart that one can see rightly;",
         "without the heart, one can only see wrongly."),
        ("The greatest glory in living lies not in never falling,",
         "but in rising every time we fall. - Nelson"),
    ]
    prompts = [prompt for prompt, _ in cases]
    answers = [expected for _, expected in cases]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            # Only linear layers carry a quantization method to inspect.
            for _name, layer in model.named_modules():
                if isinstance(layer, LinearBase):
                    method = layer.quant_method
                    assert isinstance(method, TPUInt8LinearMethod)
                    assert (method.quantize_activation ==
                            expect_activation_quant)

        vllm.apply_model(check_model)

        outputs = vllm.generate_greedy(prompts, max_tokens)
        for generation, answer in zip(outputs, answers):
            # Each result is a 2-tuple; only the generated text is checked.
            _, text = generation
            assert answer in text
vllm/model_executor/layers/quantization/tpu_int8.py
View file @
9a0c5ded
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
from
vllm.model_executor.parameter
import
ModelWeightParameter
ACTIVATION_SCHEMES
=
[
"none"
]
ACTIVATION_SCHEMES
=
[
"none"
,
"dynamic"
]
class
Int8TpuConfig
(
QuantizationConfig
):
class
Int8TpuConfig
(
QuantizationConfig
):
...
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
...
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
Int8TpuConfig
):
def
__init__
(
self
,
quant_config
:
Int8TpuConfig
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
quantize_activation
=
False
if
self
.
quant_config
.
activation_scheme
==
'dynamic'
:
self
.
quantize_activation
=
True
def
create_weights
(
self
,
layer
:
Module
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
...
@@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
...
@@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
try
:
try
:
import
torch_xla.experimental.
xla_quantized_matmu
l
# noqa: F401
import
torch_xla.experimental.
custom_kerne
l
# noqa: F401
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
raise
ImportError
(
"Please install torch_xla by following the instructions at "
"Please install torch_xla by following the instructions at "
...
@@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
...
@@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
"to run vLLM on TPU."
)
from
err
"to run vLLM on TPU."
)
from
err
weight
=
layer
.
weight
weight
=
layer
.
weight
scale
=
layer
.
scale
scale
=
layer
.
scale
out
=
torch
.
ops
.
xla
.
quantized_matmul
(
x
,
weight
,
scale
)
out
=
torch
.
ops
.
xla
.
quantized_matmul_int8
(
x
,
weight
,
scale
,
quantize_activation
=
self
.
quantize_activation
)
if
bias
is
not
None
:
if
bias
is
not
None
:
out
=
out
+
bias
out
=
out
+
bias
return
out
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment