Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
1c555604
Commit
1c555604
authored
Nov 23, 2021
by
Jiezhong Qiu
Browse files
check moe input/output batch sizes are the same
parent
27c8c2f3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
6 deletions
+21
-6
fmoe/layers.py
fmoe/layers.py
+21
-6
No files found.
fmoe/layers.py
View file @
1c555604
...
...
@@ -179,6 +179,14 @@ class FMoE(nn.Module):
according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight.
"""
moe_inp_batch_size
=
tree
.
flatten
(
tree
.
map_structure
(
lambda
tensor
:
tensor
.
shape
[
0
],
moe_inp
)
)
assert
all
(
[
batch_size
==
moe_inp_batch_size
[
0
]
for
batch_size
in
moe_inp_batch_size
]
),
"MoE inputs must have the same batch size"
if
self
.
world_size
>
1
:
def
ensure_comm_func
(
tensor
):
...
...
@@ -199,14 +207,14 @@ class FMoE(nn.Module):
if
self
.
gate_hook
is
not
None
:
self
.
gate_hook
(
gate_top_k_idx
,
gate_score
,
None
)
# TODO: to fix
def
delete_mask_func
(
tensor
):
# to: (BxL') x d_model
tensor
=
tensor
[
mask
==
0
,
:]
return
tensor
# delete masked tensors
if
self
.
mask
is
not
None
and
self
.
mask_dict
is
not
None
:
# TODO: to fix
def
delete_mask_func
(
tensor
):
# to: (BxL') x d_model
tensor
=
tensor
[
mask
==
0
,
:]
return
tensor
mask
=
self
.
mask
.
view
(
-
1
)
moe_inp
=
tree
.
map_structure
(
delete_mask_func
,
moe_inp
)
gate_top_k_idx
=
gate_top_k_idx
[
mask
==
0
,
:]
...
...
@@ -263,4 +271,11 @@ class FMoE(nn.Module):
)
moe_outp
=
tree
.
map_structure
(
all_gather_func
,
moe_outp
)
moe_outp_batch_size
=
tree
.
flatten
(
tree
.
map_structure
(
lambda
tensor
:
tensor
.
shape
[
0
],
moe_outp
)
)
assert
all
(
[
batch_size
==
moe_outp_batch_size
[
0
]
for
batch_size
in
moe_outp_batch_size
]
),
"MoE outputs must have the same batch size"
return
moe_outp
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment