Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
97659408
Unverified
Commit
97659408
authored
May 05, 2025
by
XiongfeiWei
Committed by
GitHub
May 05, 2025
Browse files
[TPU] Enable gemma3-27b with TP>1 on multi-chips. (#17335)
Signed-off-by:
Xiongfei Wei
<
isaacwxf23@gmail.com
>
parent
5ea5c514
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
0 deletions
+44
-0
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+43
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+1
-0
No files found.
tests/v1/tpu/test_basic.py
View file @
97659408
...
...
@@ -8,6 +8,7 @@ from __future__ import annotations
from
typing
import
TYPE_CHECKING
import
pytest
from
torch_xla._internal
import
tpu
from
vllm.platforms
import
current_platform
...
...
@@ -63,3 +64,45 @@ def test_basic(
output
=
vllm_outputs
[
0
][
1
]
assert
"1024"
in
output
or
"0, 1"
in
output
TP_SIZE_8
=
8
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This is a test for TPU only"
)
@
pytest
.
mark
.
skipif
(
tpu
.
num_available_chips
()
<
TP_SIZE_8
,
reason
=
f
"This test requires
{
TP_SIZE_8
}
TPU chips."
)
def
test_gemma3_27b_with_text_input_and_tp
(
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
model
=
"google/gemma-3-27b-it"
max_tokens
=
16
tensor_parallel_size
=
TP_SIZE_8
max_num_seqs
=
4
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, through inaction, allow a human being to come to harm."
,
" what is essential is invisible to the eye."
,
" but in rising every time we fall."
,
]
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model
,
max_num_batched_tokens
=
256
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for
output
,
answer
in
zip
(
vllm_outputs
,
answers
):
generated_text
=
output
[
1
]
assert
answer
in
generated_text
vllm/platforms/tpu.py
View file @
97659408
...
...
@@ -30,6 +30,7 @@ class TpuPlatform(Platform):
dispatch_key
:
str
=
"XLA"
ray_device_key
:
str
=
"TPU"
device_control_env_var
:
str
=
"TPU_VISIBLE_CHIPS"
simple_compile_backend
:
str
=
"openxla"
supported_quantization
:
list
[
str
]
=
[
"tpu_int8"
,
"compressed-tensors"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment