Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a1d8940
Unverified
Commit
3a1d8940
authored
Jul 19, 2025
by
Chengji Yao
Committed by
GitHub
Jul 20, 2025
Browse files
[TPU] support fp8 kv cache quantization (#19292)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
2b504eb7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
95 additions
and
28 deletions
+95
-28
tests/entrypoints/llm/test_accuracy.py
tests/entrypoints/llm/test_accuracy.py
+30
-10
tests/v1/tpu/test_pallas.py
tests/v1/tpu/test_pallas.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-4
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+3
-1
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+50
-8
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+6
-5
No files found.
tests/entrypoints/llm/test_accuracy.py
View file @
3a1d8940
...
@@ -15,15 +15,18 @@ import pytest
...
@@ -15,15 +15,18 @@ import pytest
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MODEL_NAMES
=
[
MODEL_NAMES
=
[
"Qwen/Qwen
2
-1.
5B-Instruct
"
,
"Qwen/Qwen
3
-1.
7B
"
,
"google/gemma-3-1b-it"
,
"google/gemma-3-1b-it"
,
]
]
FP8_KV_MODEL_NAMES
=
[
"Qwen/Qwen3-1.7B"
,
]
NUM_CONCURRENT
=
500
NUM_CONCURRENT
=
500
TASK
=
"gsm8k"
TASK
=
"gsm8k"
FILTER
=
"exact_match,strict-match"
FILTER
=
"exact_match,strict-match"
RTOL
=
0.03
RTOL
=
0.03
EXPECTED_VALUES
=
{
EXPECTED_VALUES
=
{
"Qwen/Qwen
2
-1.
5B-Instruct
"
:
0.
5
8
,
"Qwen/Qwen
3
-1.
7B
"
:
0.
6
8
,
"google/gemma-3-1b-it"
:
0.25
,
"google/gemma-3-1b-it"
:
0.25
,
}
}
...
@@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
...
@@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
# Limit compilation time for TPU V1
# Limit compilation time for TPU V1
if
model
==
"google/gemma-3-1b-it"
:
# xet doesn't work well for both Qwen/Qwen3-1.7B and
# TPU +
google/gemma-3-1b-it
+ xet doesn't work well.
#
google/gemma-3-1b-it
m
.
setenv
(
"HF_HUB_DISABLE_XET"
,
"1"
)
m
.
setenv
(
"HF_HUB_DISABLE_XET"
,
"1"
)
more_args
=
"max_model_len=2048,max_num_seqs=64"
more_args
=
"max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided)
# Add TP test (if provided)
...
@@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
...
@@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
run_test
(
model
,
more_args
)
run_test
(
model
,
more_args
)
def
test_lm_eval_accuracy_v0_engine
(
monkeypatch
:
pytest
.
MonkeyPatch
):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
"""Run with the V0 Engine."""
and
not
current_platform
.
is_tpu
(),
reason
=
"V1 is currently only supported on CUDA and TPU"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
FP8_KV_MODEL_NAMES
)
def
test_lm_eval_accuracy_v1_engine_fp8_kv_cache
(
model
,
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Run with the V1 Engine."""
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
run_test
(
"Qwen/Qwen2-1.5B-Instruct"
)
more_args
=
None
if
current_platform
.
is_tpu
():
# Limit compilation time for TPU V1
# xet doesn't work well for Qwen/Qwen3-1.7B
m
.
setenv
(
"HF_HUB_DISABLE_XET"
,
"1"
)
more_args
=
"max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
# Add TP test (if provided)
if
TPU_TP_TEST_STR
:
more_args
+=
",{}"
.
format
(
TPU_TP_TEST_STR
)
run_test
(
model
,
more_args
)
tests/v1/tpu/test_pallas.py
View file @
3a1d8940
...
@@ -95,4 +95,6 @@ def test_ragged_paged_attention():
...
@@ -95,4 +95,6 @@ def test_ragged_paged_attention():
sm_scale
=
scale
,
sm_scale
=
scale
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
logits_soft_cap
,
soft_cap
=
logits_soft_cap
,
k_scale
=
1.0
,
v_scale
=
1.0
,
)
)
vllm/engine/arg_utils.py
View file @
3a1d8940
...
@@ -1358,10 +1358,10 @@ class EngineArgs:
...
@@ -1358,10 +1358,10 @@ class EngineArgs:
and
not
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
and
not
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
)
or
envs
.
VLLM_ATTENTION_BACKEND
==
"FLASH_ATTN_VLLM_V1"
)
or
envs
.
VLLM_ATTENTION_BACKEND
==
"FLASH_ATTN_VLLM_V1"
supported
=
False
supported
=
False
if
current_platform
.
is_rocm
()
or
(
if
(
current_platform
.
is_rocm
()
current_platform
.
is_cuda
()
or
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
(
100
)
)
):
# handle hpu also for OOT platform
or
current_platform
.
is_tpu
()):
supported
=
True
supported
=
True
elif
fp8_attention
and
will_use_fa
:
elif
fp8_attention
and
will_use_fa
:
from
vllm.attention.utils.fa_utils
import
(
from
vllm.attention.utils.fa_utils
import
(
...
...
vllm/platforms/tpu.py
View file @
3a1d8940
...
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
...
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
device_control_env_var
:
str
=
"TPU_VISIBLE_CHIPS"
device_control_env_var
:
str
=
"TPU_VISIBLE_CHIPS"
simple_compile_backend
:
str
=
"openxla"
simple_compile_backend
:
str
=
"openxla"
supported_quantization
:
list
[
str
]
=
[
"tpu_int8"
,
"compressed-tensors"
]
supported_quantization
:
list
[
str
]
=
[
"fp8"
,
"tpu_int8"
,
"compressed-tensors"
]
additional_env_vars
:
list
[
str
]
=
[
additional_env_vars
:
list
[
str
]
=
[
"TPU_CHIPS_PER_HOST_BOUNDS"
,
"TPU_HOST_BOUNDS"
"TPU_CHIPS_PER_HOST_BOUNDS"
,
"TPU_HOST_BOUNDS"
...
...
vllm/v1/attention/backends/pallas.py
View file @
3a1d8940
...
@@ -24,6 +24,19 @@ logger = init_logger(__name__)
...
@@ -24,6 +24,19 @@ logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT
=
128
TPU_HEAD_SIZE_ALIGNMENT
=
128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE
=
{
"half"
:
torch
.
half
,
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"fp8"
:
torch
.
float8_e4m3fn
,
"fp8_e4m3"
:
torch
.
float8_e4m3fn
,
"fp8_e5m2"
:
torch
.
float8_e5m2
,
"int8"
:
torch
.
int8
,
"uint8"
:
torch
.
uint8
,
}
class
PallasAttentionBackend
(
AttentionBackend
):
class
PallasAttentionBackend
(
AttentionBackend
):
...
@@ -152,8 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -152,8 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
...
@@ -161,6 +172,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -161,6 +172,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
"are not implemented for "
"are not implemented for "
"PallasAttentionBackendImpl"
)
"PallasAttentionBackendImpl"
)
self
.
kv_cache_quantized_dtype
=
None
if
kv_cache_dtype
!=
"auto"
:
self
.
kv_cache_quantized_dtype
=
TPU_STR_DTYPE_TO_TORCH_DTYPE
.
get
(
kv_cache_dtype
.
lower
().
strip
())
def
forward
(
def
forward
(
self
,
self
,
layer
:
AttentionLayer
,
layer
:
AttentionLayer
,
...
@@ -194,7 +210,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -194,7 +210,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
output
=
torch
.
ones_like
(
query
)
output
=
torch
.
ones_like
(
query
)
return
output
return
output
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -215,10 +230,21 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -215,10 +230,21 @@ class PallasAttentionBackendImpl(AttentionImpl):
# Skip this if sharing KV cache with an earlier attention layer.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
,
key
,
value
,
kv_cache
,
slot_mapping
,
attn_metadata
.
num_slices_per_kv_cache_update_block
,
attn_metadata
.
num_slices_per_kv_cache_update_block
,
attn_metadata
.
num_kv_update_slices
)
attn_metadata
.
num_kv_update_slices
,
self
.
kv_cache_quantized_dtype
,
layer
.
_k_scale_float
,
layer
.
_v_scale_float
,
)
if
self
.
kv_cache_quantized_dtype
is
not
None
and
(
layer
.
_k_scale_float
==
0.0
or
layer
.
_v_scale_float
==
0.0
):
raise
ValueError
(
"k_scale_float and v_scale_float must be non-zero"
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
query
,
kv_cache
,
kv_cache
,
...
@@ -236,6 +262,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -236,6 +262,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
sliding_window
=
self
.
sliding_window
,
sliding_window
=
self
.
sliding_window
,
soft_cap
=
self
.
logits_soft_cap
,
soft_cap
=
self
.
logits_soft_cap
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
if
self
.
head_size
%
TPU_HEAD_SIZE_ALIGNMENT
!=
0
:
if
self
.
head_size
%
TPU_HEAD_SIZE_ALIGNMENT
!=
0
:
...
@@ -251,18 +279,32 @@ def write_to_kv_cache(
...
@@ -251,18 +279,32 @@ def write_to_kv_cache(
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
num_slices_per_kv_cache_update_block
:
int
,
num_slices_per_kv_cache_update_block
:
int
,
num_kv_update_slices
:
torch
.
Tensor
,
num_kv_update_slices
:
torch
.
Tensor
,
kv_cache_quantized_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
)
->
None
:
)
->
None
:
""" Write the key and values to the KV cache.
""" Write the key and values to the KV cache.
Args:
Args:
key: shape = [num_tokens, num_kv_heads
*
head_size]
key: shape = [num_tokens, num_kv_heads
,
head_size]
value: shape = [num_tokens, num_kv_heads
*
head_size]
value: shape = [num_tokens, num_kv_heads
,
head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
num_slices_per_kv_cache_update_block: int
"""
"""
_
,
page_size
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
_
,
page_size
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
head_size
=
cdiv
(
head_size
,
head_size
=
cdiv
(
head_size
,
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
TPU_HEAD_SIZE_ALIGNMENT
)
*
TPU_HEAD_SIZE_ALIGNMENT
if
kv_cache_quantized_dtype
is
not
None
:
dtype_info
=
torch
.
finfo
(
kv_cache_quantized_dtype
)
key
=
key
.
to
(
torch
.
float32
)
/
k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key
=
torch
.
clamp
(
key
,
dtype_info
.
min
,
dtype_info
.
max
)
key
=
key
.
to
(
kv_cache_quantized_dtype
)
value
=
value
.
to
(
torch
.
float32
)
/
v_scale
value
=
torch
.
clamp
(
value
,
dtype_info
.
min
,
dtype_info
.
max
)
value
=
value
.
to
(
kv_cache_quantized_dtype
)
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
head_size
)
head_size
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
3a1d8940
...
@@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
...
@@ -32,9 +32,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.pooling_params
import
PoolingTask
from
vllm.pooling_params
import
PoolingTask
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
cdiv
,
from
vllm.utils
import
(
LayerBlockType
,
cdiv
,
is_pin_memory_available
,
is_pin_memory_available
,
prev_power_of_2
)
prev_power_of_2
)
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
TPU_STR_DTYPE_TO_TORCH_DTYPE
,
PallasAttentionBackend
,
PallasMetadata
,
PallasMetadata
,
get_page_size_bytes
)
get_page_size_bytes
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
@@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -142,11 +143,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
if
cache_config
.
cache_dtype
==
"auto"
:
if
cache_config
.
cache_dtype
==
"auto"
:
model_dtype
=
self
.
dtype
model_dtype
=
self
.
dtype
if
isinstance
(
model_dtype
,
str
):
if
isinstance
(
model_dtype
,
str
):
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
self
.
kv_cache_dtype
=
TPU_
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
else
:
else
:
self
.
kv_cache_dtype
=
model_dtype
self
.
kv_cache_dtype
=
model_dtype
else
:
else
:
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
kv_cache_dtype
=
TPU_
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
cache_config
.
cache_dtype
]
self
.
_hidden_states_dtype
=
self
.
dtype
self
.
_hidden_states_dtype
=
self
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment