Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b17ea26
Unverified
Commit
3b17ea26
authored
May 20, 2025
by
Michael Goin
Committed by
GitHub
May 20, 2025
Browse files
[TPU] Re-enable the Pallas MoE kernel (#18025)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent
23baa218
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing 3 changed files with 24 additions and 9 deletions
+24
-9
requirements/tpu.txt
requirements/tpu.txt
+5
-5
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-2
vllm/model_executor/layers/fused_moe/moe_pallas.py
vllm/model_executor/layers/fused_moe/moe_pallas.py
+18
-2
No files found.
requirements/tpu.txt
View file @
3b17ea26
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
...
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250430
torch==2.8.0.dev20250518
torchvision==0.22.0.dev20250430
torchvision==0.22.0.dev20250518
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
vllm/model_executor/layers/fused_moe/layer.py
View file @
3b17ea26
...
@@ -50,8 +50,7 @@ if is_rocm_aiter_moe_enabled():
...
@@ -50,8 +50,7 @@ if is_rocm_aiter_moe_enabled():
else
:
else
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
# the iterative moe implementation is used until the moe_pallas is fixed
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
from
.moe_torch_iterative
import
fused_moe
as
fused_moe_pallas
else
:
else
:
fused_moe_pallas
=
None
# type: ignore
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/fused_moe/moe_pallas.py
View file @
3b17ea26
...
@@ -2,7 +2,23 @@
...
@@ -2,7 +2,23 @@
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch_xla.experimental.custom_kernel
import
_histogram
def
_histogram
(
input
:
torch
.
Tensor
,
min
:
int
,
max
:
int
)
->
torch
.
Tensor
:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert
input
.
dtype
==
torch
.
int32
,
"input must be of torch.int32 dtype."
assert
min
<=
max
,
"min must be less than or equal to max."
def
searchsorted
(
sorted_sequence
:
torch
.
Tensor
,
values_to_search
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
(
sorted_sequence
.
unsqueeze
(
1
)
==
values_to_search
).
sum
(
dim
=
1
)
bin_edges
=
torch
.
linspace
(
min
,
max
,
max
-
min
+
1
,
dtype
=
input
.
dtype
).
to
(
input
.
device
)
return
searchsorted
(
bin_edges
,
input
).
to
(
torch
.
int32
)
def
fused_moe
(
def
fused_moe
(
...
@@ -61,7 +77,7 @@ def fused_moe(
...
@@ -61,7 +77,7 @@ def fused_moe(
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
*
topk_weights
.
unsqueeze
_
(
dim
=-
1
)
x
=
x
*
topk_weights
.
unsqueeze
(
dim
=-
1
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
reshape
(
orig_shape
)
x
=
x
.
reshape
(
orig_shape
)
return
x
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment