Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
621ca2c0
Unverified
Commit
621ca2c0
authored
May 06, 2025
by
Jevin Jiang
Committed by
GitHub
May 06, 2025
Browse files
[TPU] Increase block size and reset block shapes (#16458)
parent
6115b115
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
35 additions
and
11 deletions
+35
-11
examples/offline_inference/tpu.py
examples/offline_inference/tpu.py
+2
-1
requirements/tpu.txt
requirements/tpu.txt
+5
-5
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+6
-4
vllm/utils.py
vllm/utils.py
+7
-0
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+15
-1
No files found.
examples/offline_inference/tpu.py
View file @
621ca2c0
...
...
@@ -22,7 +22,8 @@ def main():
# In real workloads, `enforce_eager` should be `False`.
llm
=
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
max_num_batched_tokens
=
64
,
max_num_seqs
=
4
)
max_num_seqs
=
4
,
max_model_len
=
128
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
print
(
"-"
*
50
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
...
...
requirements/tpu.txt
View file @
621ca2c0
...
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250408
torchvision==0.22.0.dev20250408
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250430
torchvision==0.22.0.dev20250430
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
vllm/platforms/tpu.py
View file @
621ca2c0
...
...
@@ -76,9 +76,9 @@ class TpuPlatform(Platform):
from
vllm.config
import
CompilationLevel
cache_config
=
vllm_config
.
cache_config
# For v0, the default block size is 16.
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
compilation_config
=
vllm_config
.
compilation_config
# TPU only supports DYNAMO_ONCE compilation level
...
...
@@ -101,16 +101,18 @@ class TpuPlatform(Platform):
if
envs
.
VLLM_USE_V1
:
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
)
cache_config
.
block_size
=
PallasAttentionBackend
.
get_page_size
(
vllm_config
)
min_page_size
=
PallasAttentionBackend
.
get_min_page_size
(
vllm_config
)
if
min_page_size
>
vllm_config
.
cache_config
.
block_size
:
if
min_page_size
>
cache_config
.
block_size
:
logger
.
warning
(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM"
,
vllm_config
.
cache_config
.
block_size
,
cache_config
.
block_size
,
min_page_size
,
)
vllm_config
.
cache_config
.
block_size
=
min_page_size
cache_config
.
block_size
=
min_page_size
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
...
...
vllm/utils.py
View file @
621ca2c0
...
...
@@ -707,6 +707,13 @@ def cdiv(a: int, b: int) -> int:
return
-
(
a
//
-
b
)
def next_power_of_2(n) -> int:
    """Return the smallest power of two that is greater than or equal to n.

    The bound is inclusive: a value that is already a power of two is
    returned unchanged. Any n below 1 yields 1.
    """
    # (n - 1).bit_length() is the number of bits needed to represent n - 1,
    # so shifting 1 by that amount gives the first power of two >= n.
    return 1 if n < 1 else 1 << (n - 1).bit_length()
def round_up(x: int, y: int) -> int:
    """Round x up to a multiple of y.

    For positive y the result is the smallest multiple of y that is not
    less than x. A y of zero raises ZeroDivisionError.
    """
    # Adding y - 1 before the floor division turns it into a ceiling
    # division for positive y.
    multiples = (x + y - 1) // y
    return multiples * y
...
...
vllm/v1/attention/backends/pallas.py
View file @
621ca2c0
...
...
@@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
,
next_power_of_2
logger
=
init_logger
(
__name__
)
...
...
@@ -65,6 +65,20 @@ class PallasAttentionBackend(AttentionBackend):
min_page_size
=
1
<<
(
min_page_size
-
1
).
bit_length
()
return
min_page_size
# TPU has limited SREGs (scalar registers). If page_size is too small, we
# can easily spill SREGs, which leads to bad performance. The strategy we
# apply here is to split max-model-len into 16 pages, which makes spills
# less likely. Meanwhile, we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
    """Derive the page size from the configured maximum model length.

    The maximum model length is rounded up to a power of two and divided
    by 16; the result is then clamped to the range [16, 256].
    """
    max_len = vllm_config.model_config.max_model_len
    candidate = next_power_of_2(max_len) // 16
    # Clamp: anything <= 16 becomes 16, anything >= 256 becomes 256.
    return min(max(candidate, 16), 256)
@
dataclass
class
PallasMetadata
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment