Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
471fe656
Unverified
Commit
471fe656
authored
Apr 21, 2025
by
Chengji Yao
Committed by
GitHub
Apr 21, 2025
Browse files
[TPU][V1] Implicitly adjust page size when there's SMEM OOM (#16871)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
3a0fba5c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
2 deletions
+34
-2
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+5
-2
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+14
-0
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+15
-0
No files found.
tests/v1/tpu/test_basic.py
View file @
471fe656
...
...
@@ -22,6 +22,7 @@ MODELS = [
]
TENSOR_PARALLEL_SIZES
=
[
1
]
MAX_NUM_REQS
=
[
16
,
1024
]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
...
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
MAX_NUM_REQS
)
def
test_basic
(
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
model
:
str
,
max_tokens
:
int
,
tensor_parallel_size
:
int
,
max_num_seqs
:
int
,
)
->
None
:
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
...
...
@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens
=
1024
,
max_model_len
=
819
6
,
max_model_len
=
819
2
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
16
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
...
...
vllm/platforms/tpu.py
View file @
471fe656
...
...
@@ -97,6 +97,20 @@ class TpuPlatform(Platform):
"Using bfloat16 instead."
,
vllm_config
.
model_config
.
dtype
)
vllm_config
.
model_config
.
dtype
=
torch
.
bfloat16
if
envs
.
VLLM_USE_V1
:
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
)
min_page_size
=
PallasAttentionBackend
.
get_min_page_size
(
vllm_config
)
if
min_page_size
>
vllm_config
.
cache_config
.
block_size
:
logger
.
warning
(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM"
,
vllm_config
.
cache_config
.
block_size
,
min_page_size
,
)
vllm_config
.
cache_config
.
block_size
=
min_page_size
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
if
parallel_config
.
worker_cls
==
"auto"
:
...
...
vllm/v1/attention/backends/pallas.py
View file @
471fe656
...
...
@@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
logger
=
init_logger
(
__name__
)
...
...
@@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend):
)
->
None
:
raise
RuntimeError
(
"swap_blocks is not used for the TPU backend."
)
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@
staticmethod
def
get_min_page_size
(
vllm_config
:
VllmConfig
)
->
int
:
max_num_page_per_req
=
(
1024
*
1024
//
2
//
vllm_config
.
scheduler_config
.
max_num_seqs
//
4
)
min_page_size
=
cdiv
(
vllm_config
.
model_config
.
max_model_len
,
max_num_page_per_req
)
min_page_size
=
1
<<
(
min_page_size
-
1
).
bit_length
()
return
min_page_size
@
dataclass
class
PallasMetadata
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment