Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1cc9f33
Unverified
Commit
a1cc9f33
authored
May 29, 2025
by
Chengji Yao
Committed by
GitHub
May 29, 2025
Browse files
[TPU] remove transpose ops in moe kernel (#18923)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent
a521ef06
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
13 deletions
+8
-13
requirements/tpu.txt
requirements/tpu.txt
+5
-5
tests/tpu/test_moe_pallas.py
tests/tpu/test_moe_pallas.py
+1
-1
vllm/model_executor/layers/fused_moe/moe_pallas.py
vllm/model_executor/layers/fused_moe/moe_pallas.py
+2
-7
No files found.
requirements/tpu.txt
View file @
a1cc9f33
...
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250518
-torchvision==0.22.0.dev20250518
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250529
+torchvision==0.22.0.dev20250529
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
tests/tpu/test_moe_pallas.py
View file @
a1cc9f33
...
...
@@ -26,7 +26,7 @@ TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
...
...
vllm/model_executor/layers/fused_moe/moe_pallas.py
View file @
a1cc9f33
...
...
@@ -67,15 +67,10 @@ def fused_moe(
    token_indices = token_indices[topk_argsort_indices]
    group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
-    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
-    # from HF Transformers.
-    w1 = w1.transpose(1, 2)
-    w2 = w2.transpose(1, 2)
    x = hidden_states[token_indices]
-    x = torch.ops.xla.gmm(x, w1, group_sizes)
+    x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
    x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
-    x = torch.ops.xla.gmm(x, w2, group_sizes)
+    x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
    x = x * topk_weights.unsqueeze(dim=-1)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment