OpenDAS / FastMoE · Commits

Commit 1c555604, authored Nov 23, 2021 by Jiezhong Qiu
parent 27c8c2f3

    check moe input/output batch sizes are the same

Showing 1 changed file with 21 additions and 6 deletions (+21 -6)
fmoe/layers.py

@@ -179,6 +179,14 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        moe_inp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
+        )
+        assert all(
+            [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
+        ), "MoE inputs must have the same batch size"
+
         if self.world_size > 1:

             def ensure_comm_func(tensor):
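For context, a minimal standalone sketch of what the new input check does; it is not part of the commit, and it assumes the dm-tree package (imported as tree, as in fmoe/layers.py) and PyTorch are installed. The moe_inp value here is a hypothetical nested input, not the module's real argument.

import torch
import tree  # dm-tree

# A hypothetical nested MoE input: a tuple of feature tensors.
moe_inp = (torch.randn(4, 16), torch.randn(4, 32))

# Collect the leading (batch) dimension of every leaf tensor.
batch_sizes = tree.flatten(
    tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
)
print(batch_sizes)  # [4, 4]

# The assertion fails as soon as any leaf disagrees on batch size.
assert all(
    batch_size == batch_sizes[0] for batch_size in batch_sizes
), "MoE inputs must have the same batch size"

Using tree.map_structure followed by tree.flatten lets the check work uniformly whether moe_inp is a single tensor, a tuple, or a deeper nesting.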
@@ -199,14 +207,14 @@ class FMoE(nn.Module):
         if self.gate_hook is not None:
             self.gate_hook(gate_top_k_idx, gate_score, None)

-        # TODO: to fix
-        def delete_mask_func(tensor):
-            # to: (BxL') x d_model
-            tensor = tensor[mask == 0, :]
-            return tensor
-
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
+            # TODO: to fix
+            def delete_mask_func(tensor):
+                # to: (BxL') x d_model
+                tensor = tensor[mask == 0, :]
+                return tensor
+
             mask = self.mask.view(-1)
             moe_inp = tree.map_structure(delete_mask_func, moe_inp)
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
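As a side note, a minimal sketch of the boolean-indexing step inside delete_mask_func; it is not part of the commit and uses standalone tensors in place of the module's self.mask and moe_inp.

import torch

d_model = 8
mask = torch.tensor([0, 1, 0, 1])  # nonzero marks tokens to drop
tokens = torch.randn(4, d_model)   # (BxL) x d_model

# Keep only the rows where the mask is 0, yielding (BxL') x d_model.
kept = tokens[mask == 0, :]
print(kept.shape)  # torch.Size([2, 8])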
@@ -263,4 +271,11 @@ class FMoE(nn.Module):
             )
             moe_outp = tree.map_structure(all_gather_func, moe_outp)

+        moe_outp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
+        )
+        assert all(
+            [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
+        ), "MoE outputs must have the same batch size"
         return moe_outp
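The output check mirrors the input check, so a batch-size mismatch introduced anywhere in the forward pass fails loudly rather than surfacing later as a shape error. A minimal sketch of it catching a mismatch, again not part of the commit and using a hypothetical nested output:

import torch
import tree  # dm-tree

# Hypothetical outputs whose leaves disagree on batch size.
moe_outp = {"hidden": torch.randn(4, 16), "aux": torch.randn(3, 16)}

batch_sizes = tree.flatten(
    tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
)

try:
    assert all(
        batch_size == batch_sizes[0] for batch_size in batch_sizes
    ), "MoE outputs must have the same batch size"
except AssertionError as err:
    print(err)  # MoE outputs must have the same batch size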