Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
263a870e
Unverified
Commit
263a870e
authored
Jan 12, 2025
by
Avshalom Manevich
Committed by
GitHub
Jan 12, 2025
Browse files
[Hardware][TPU] workaround fix for MoE on TPU (#11764)
parent
8bddb735
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
1 deletion
+60
-1
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+7
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-1
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
+51
-0
No files found.
tests/kernels/test_moe.py
View file @
263a870e
...
@@ -14,6 +14,8 @@ from vllm import _custom_ops as ops
...
@@ -14,6 +14,8 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
)
fused_topk
,
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
marlin_quantize
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
...
@@ -46,6 +48,11 @@ def test_fused_moe(
...
@@ -46,6 +48,11 @@ def test_fused_moe(
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
263a870e
...
@@ -20,7 +20,8 @@ if current_platform.is_cuda_alike():
...
@@ -20,7 +20,8 @@ if current_platform.is_cuda_alike():
else
:
else
:
fused_experts
=
None
# type: ignore
fused_experts
=
None
# type: ignore
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
# Workaround: the iterative MoE implementation is used on TPU until moe_pallas is fixed.
from
.moe_torch_iterative
import
fused_moe
as
fused_moe_pallas
else
:
else
:
fused_moe_pallas
=
None
# type: ignore
fused_moe_pallas
=
None
# type: ignore
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
0 → 100644
View file @
263a870e
import
torch
import
torch.nn.functional
as
F
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """Compute a mixture-of-experts layer by iterating over every expert.

    Each expert is applied densely to all tokens. A token's contribution
    from an expert is scaled by its routing weight for that expert, which
    is zero when the expert is not among the token's top-k choices.

    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
        topk: number of experts selected per token.
        renormalize: if True, rescale the selected routing weights so
            they sum to 1 for each token.

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    dtype = hidden_states.dtype

    # reshape (rather than view) so that non-contiguous inputs whose leading
    # dimensions cannot be merged in place are accepted; for contiguous
    # inputs reshape returns the same view as before.
    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
    gating_output = gating_output.reshape(num_tokens, num_experts)

    # Routing: softmax in float32, then keep the top-k experts per token.
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    final_hidden_states = None
    for expert_idx in range(num_experts):
        expert_w1 = w1[expert_idx]
        expert_w2 = w2[expert_idx]

        # Per-token weight for this expert: the routing weight if the expert
        # was selected for the token, otherwise 0.
        expert_mask = (selected_experts == expert_idx)
        expert_weights = (topk_weights * expert_mask).sum(dim=-1,
                                                          keepdim=True)

        # The first half of the w1 projection is passed through SiLU and
        # gates the second half.
        x = F.linear(hidden_states, expert_w1)
        gate = F.silu(x[:, :intermediate_size])
        x = x[:, intermediate_size:] * gate
        x = F.linear(x, expert_w2)

        current_hidden_states = x * expert_weights
        if final_hidden_states is None:
            final_hidden_states = current_hidden_states
        else:
            final_hidden_states = final_hidden_states + current_hidden_states

    return final_hidden_states.reshape(orig_shape)  # type: ignore
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment