OpenDAS / FastMoE · Commits

Commit dc3db673
Authored Feb 02, 2021 by Rick Ho
fix replica condition and minor optimizations
Parent: a8ecd3d7

Showing 2 changed files with 18 additions and 21 deletions (+18 -21):
  fmoe/layers.py    +17 -20
  fmoe/megatron.py  +1 -1
fmoe/layers.py

@@ -99,12 +99,15 @@ class FMoETransformerMLP(nn.Module):
         )

     def forward(self, inp: torch.Tensor):
-        if self.num_expert != 1:
-            B: int = inp.shape[1]
+        original_shape = inp.shape
+        inp = inp.reshape(-1, self.d_model)
+        if self.model_parallel_size > 1:
+            B: int = inp.shape[0]
             local_batch_size = B // self.model_parallel_size
             batch_start = local_batch_size * self.model_parallel_rank
             batch_end = min(batch_start + local_batch_size, B)
-            inp = inp[:, batch_start:batch_end, :].contiguous()
+            inp = inp[batch_start:batch_end]
         residual = inp
         if self.pre_lnorm:
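The reworked forward() first flattens the (batch, seq_len, d_model) input to a 2-D (batch * seq_len, d_model) tensor and slices out this rank's rows only when model parallelism is actually enabled; the old gate, num_expert != 1, triggered the slice even without model parallelism. A minimal standalone sketch of that slicing logic in plain PyTorch (slice_local_rows and its arguments are illustrative names, not FastMoE's API):

    import torch

    def slice_local_rows(inp, model_parallel_rank, model_parallel_size, d_model):
        # Flatten (batch, seq_len, d_model) -> (batch * seq_len, d_model),
        # mirroring the reshape added in this commit.
        original_shape = inp.shape
        inp = inp.reshape(-1, d_model)
        if model_parallel_size > 1:
            B = inp.shape[0]
            local_batch_size = B // model_parallel_size
            batch_start = local_batch_size * model_parallel_rank
            batch_end = min(batch_start + local_batch_size, B)
            # Row slicing on a contiguous 2-D tensor stays contiguous,
            # so no .contiguous() call is needed here.
            inp = inp[batch_start:batch_end]
        return inp, original_shape

    # Example: batch=2, seq_len=8, d_model=16 split over 4 ranks -> 4 rows per rank.
    x = torch.randn(2, 8, 16)
    local, shape = slice_local_rows(x, model_parallel_rank=1,
                                    model_parallel_size=4, d_model=16)
    print(local.shape)  # torch.Size([4, 16])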
@@ -112,9 +115,9 @@ class FMoETransformerMLP(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
+        # to: (BxLxtop_k) x d_model
+        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_full_forward(
             inp,
             gate_top_k_idx,
@@ -124,26 +127,20 @@ class FMoETransformerMLP(nn.Module):
             self.world_size,
         )
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+        # to: (BxL) x top_k x d_model
+        core_out = x.view(-1, self.top_k, self.d_model)
+        # to: (BxL) x 1 x d_model
+        core_out = torch.bmm(gate_score, core_out)
         output = core_out + residual

         if not self.pre_lnorm:
             output = self.layer_norm(output)

-        if self.num_expert != 1:
+        if self.model_parallel_size > 1:
             world_size = self.model_parallel_size
-            if world_size == 1:
-                return output, self.bias
             rank = self.model_parallel_rank
             tensor_list = [torch.empty_like(output) for _ in range(world_size)]
             tensor_list[rank] = output
             torch.distributed.all_gather(tensor_list, output, group=self.group)
-            # Note: torch.cat already creates a contiguous tensor.
-            output = torch.cat(tensor_list, dim=1).contiguous()
+            output = torch.cat(tensor_list, dim=1)
-        return output, self.bias
+        return output.reshape(original_shape), self.bias
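On the way out, the expert outputs are viewed as (rows x top_k x d_model), combined with the gating scores by a batched matrix multiply, and, when model_parallel_size > 1, the per-rank results are all-gathered and concatenated before the final reshape back to original_shape. The sketch below walks through the same shape round trip in a single process: the collective is simulated with torch.empty_like placeholders and a plain torch.cat instead of torch.distributed.all_gather, and all sizes are illustrative, not taken from FastMoE.

    import torch

    top_k, d_model = 2, 16
    batch, seq_len = 2, 8
    model_parallel_size = 4

    original_shape = (batch, seq_len, d_model)
    n_local = batch * seq_len // model_parallel_size          # rows owned by one rank

    # Expert outputs for this rank's rows, one result per (row, expert) pair,
    # plus the gating weights over the top_k experts of each row.
    x = torch.randn(n_local * top_k, d_model)
    gate_score = torch.softmax(torch.randn(n_local, 1, top_k), dim=-1)

    core_out = x.view(-1, top_k, d_model)         # to: n_local x top_k x d_model
    core_out = torch.bmm(gate_score, core_out)    # to: n_local x 1 x d_model

    # Stand-in for torch.distributed.all_gather: every rank contributes a tensor
    # of identical shape; this process only fills its own slot.
    tensor_list = [torch.empty_like(core_out) for _ in range(model_parallel_size)]
    tensor_list[0] = core_out                     # pretend this process is rank 0
    output = torch.cat(tensor_list, dim=1)        # torch.cat already returns a contiguous tensor

    print(output.reshape(original_shape).shape)   # torch.Size([2, 8, 16])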
fmoe/megatron.py

@@ -3,7 +3,7 @@ from .layers import FMoETransformerMLP
 def create_moe_mlp(args, model_parallel_rank, group):
     assert (
-        args.num_experts % args.model_parallel_size == 0
+        args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Num experts should be multiple of mp size"
     num_experts = args.num_experts // args.model_parallel_size
     fmoe = FMoETransformerMLP(
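The one-line change swaps the divisibility check: instead of requiring the expert count to be divisible by the model-parallel size, the assertion now requires the number of tokens per step (seq_length * batch_size) to split evenly across ranks, matching the per-rank row slicing in forward(). A quick illustration with made-up argument values (SimpleNamespace stands in for Megatron's parsed args, and the message string is paraphrased rather than taken from the diff):

    from types import SimpleNamespace

    # Illustrative values only; the real ones come from Megatron-LM's argument parser.
    args = SimpleNamespace(num_experts=8, seq_length=1024, batch_size=4,
                           model_parallel_size=2)

    # Tokens per step must divide evenly across model-parallel ranks
    # (* and % share precedence, so this reads (seq_length * batch_size) % mp_size).
    assert args.seq_length * args.batch_size % args.model_parallel_size == 0, \
        "seq_length * batch_size should be a multiple of model_parallel_size"

    # Each rank still hosts an equal share of the experts.
    num_experts = args.num_experts // args.model_parallel_size
    print(num_experts)  # 4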