Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
92247c52
Unverified
Commit
92247c52
authored
May 21, 2025
by
bnellnm
Committed by
GitHub
May 20, 2025
Browse files
[Bug] Fix moe_sum signature (#18440)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
0c15c2e4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
1 deletion
+19
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+1
-1
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+18
-0
No files found.
csrc/moe/torch_bindings.cpp
View file @
92247c52
...
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// Calculate the result of moe by summing up the partial results
// from all selected experts.
// from all selected experts.
m
.
def
(
"moe_sum(Tensor
!
input, Tensor output) -> ()"
);
m
.
def
(
"moe_sum(Tensor input, Tensor
!
output) -> ()"
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
// Aligning the number of tokens to be processed by each expert such
// Aligning the number of tokens to be processed by each expert such
...
...
tests/kernels/moe/test_moe.py
View file @
92247c52
...
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
...
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
opcheck
(
torch
.
ops
.
_moe_C
.
moe_align_block_size
,
opcheck
(
torch
.
ops
.
_moe_C
.
moe_align_block_size
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
num_tokens_post_pad
))
@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Compare the ``_moe_C.moe_sum`` op with a plain ``sum`` over dim 1.

    A random ``(m, topk, k)`` tensor is reduced two ways: by the custom op,
    which is handed a pre-allocated ``(m, k)`` tensor as its second argument,
    and by ``torch.sum(..., dim=1)`` as the reference. The two results must
    agree within an absolute tolerance of 2e-2 (no relative tolerance).
    Finally the op is run through ``opcheck`` with the same arguments.
    """
    tensor_opts = {"device": "cuda", "dtype": dtype}

    partials = torch.randn((m, topk, k), **tensor_opts)
    # Uninitialized on purpose: the test relies on the op filling this in.
    result = torch.empty((m, k), **tensor_opts)
    reference = torch.sum(partials, dim=1)

    moe_sum_op = torch.ops._moe_C.moe_sum
    moe_sum_op(partials, result)

    torch.testing.assert_close(result, reference, atol=2e-2, rtol=0)

    opcheck(moe_sum_op, (partials, result))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment