Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b61dc5f9
Unverified
Commit
b61dc5f9
authored
Jun 05, 2025
by
Chengji Yao
Committed by
GitHub
Jun 06, 2025
Browse files
[TPU] update torch_xla pin (#19231)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
f8a1a2d1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
7 deletions
+8
-7
requirements/tpu.txt
requirements/tpu.txt
+5
-5
tests/tpu/test_moe_pallas.py
tests/tpu/test_moe_pallas.py
+1
-1
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+2
-1
No files found.
requirements/tpu.txt
View file @
b61dc5f9
...
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250
529
torchvision==0.2
2
.0.dev20250
529
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
529
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
529
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
529
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250
605
torchvision==0.2
3
.0.dev20250
605
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
605
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
605
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250
605
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
tests/tpu/test_moe_pallas.py
View file @
b61dc5f9
...
...
@@ -27,7 +27,7 @@ TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@
pytest
.
mark
.
parametrize
(
"m"
,
[
8
,
16
,
64
,
2048
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
51
2
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
51
1
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
...
...
vllm/v1/worker/tpu_worker.py
View file @
b61dc5f9
...
...
@@ -100,7 +100,8 @@ class TPUWorker:
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os
.
environ
[
"LIBTPU_INIT_ARGS"
]
=
(
"--xla_tpu_force_1d_allreduce_at_chunk_count=1"
)
os
.
environ
.
get
(
"LIBTPU_INIT_ARGS"
,
""
)
+
" --xla_tpu_force_1d_allreduce_at_chunk_count=1"
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment