OpenDAS / FastMoE

Commit 1cfc5462 (unverified), authored Feb 26, 2021 by Rick Ho, committed via GitHub on Feb 26, 2021

Merge pull request #8 from laekov/simple_mlp

Simple mlp

Parents: b56c8043, 1dc7e73e
Showing 4 changed files with 54 additions and 45 deletions.
Changed files:

- examples/transformer-xl/mem_transformer.py (+22 −4)
- fmoe/layers.py (+28 −17)
- fmoe/megatron.py (+3 −7)
- fmoe/transformer.py (+1 −17)
examples/transformer-xl/mem_transformer.py (view file @ 1cfc5462)

```diff
@@ -384,11 +384,29 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
             nn.Dropout(dropout))
         super().__init__(num_expert=moe_num_expert, d_model=d_model,
                 d_hidden=d_inner, top_k=moe_top_k,
-                do_lnorm=True, pre_lnorm=pre_lnorm,
-                activation=activation, dropout=dropout)
+                activation=activation)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
-        x = super().forward(x)
-        return x
+    def forward(self, inp):
+        if self.pre_lnorm:
+            ##### layer normalization + positionwise feed-forward
+            core_out = super().forward(self.layer_norm(inp))
+            core_out = self.dropout(core_out)
+
+            ##### residual connection
+            output = core_out + inp
+        else:
+            ##### positionwise feed-forward
+            core_out = super().forward(inp)
+            core_out = self.dropout(core_out)
+
+            ##### residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
 
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
```
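Since FMoETransformerMLP no longer applies layer normalization, dropout, or the residual connection itself (see the fmoe/transformer.py hunk below), the example now does that wrapping explicitly. Here is a minimal, self-contained sketch of the same pre-/post-LayerNorm pattern, with a plain nn.Linear MLP standing in for the MoE forward so it runs without fmoe installed; the class name and stand-in MLP are illustrative, not part of the repository.

```python
import torch
from torch import nn


class PositionwiseFFWrapper(nn.Module):
    """Illustrative stand-in for CustomizedMoEPositionwiseFF's wrapping logic."""

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        # Plain MLP in place of the FastMoE super().forward() call.
        self.core = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(), nn.Linear(d_inner, d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward, then residual
            core_out = self.dropout(self.core(self.layer_norm(inp)))
            return core_out + inp
        # positionwise feed-forward, then residual + layer normalization
        core_out = self.dropout(self.core(inp))
        return self.layer_norm(inp + core_out)


if __name__ == "__main__":
    x = torch.randn(4, 16, 32)
    for pre in (True, False):
        y = PositionwiseFFWrapper(32, 64, 0.1, pre_lnorm=pre)(x)
        print(pre, y.shape)  # torch.Size([4, 16, 32]) in both cases
```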
fmoe/layers.py (view file @ 1cfc5462)

```diff
@@ -61,28 +61,39 @@ class FMoELinear(nn.Module):
         '''
         x = MOELinear.apply(inp, self.weight, fwd_expert_count)
         if self.bias is not None:
-            # TODO: torch.repeat_interleave seems have wrong
-            # behaviors in backward, leading to incorrect
-            # gradient computation for bias.
-            # Thus we use a for-loop to manually expand the bias.
-            # This part should finally goes to MOELinear.apply.
-            # bias = torch.repeat_interleave(self.bias,
-            #         fwd_expert_count.to(self.bias.device), dim=0)
-            bias = []
-            for i in range(self.num_expert):
-                if fwd_expert_count[i] > 0:
-                    bias.append(
-                        self.bias[i].unsqueeze(0).expand(
-                            fwd_expert_count[i], -1
-                        )
-                    )
-            bias = torch.cat(bias, dim=0)
+            # TODO: torch.repeat_interleave seems have numerical
+            # instability in backward, leading to incorrect
+            # gradient computation for solution 1 and 2.
+            # Solution 3 uses a for-loop to expand the bias,
+            # but is 50% slower.
+            # This part should finally goes to MOELinear.apply,
+            # like MOELinear.apply(x, weight, bias, count)
+            # Solution 1
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            # Solution 2
+            # bias_idx = torch.arange(self.num_expert)\
+            #         .repeat_interleave(fwd_expert_count)
+            # bias = self.bias[bias_idx]
+            # Solution 3
+            # bias = []
+            # for i in range(self.num_expert):
+            #     if fwd_expert_count[i] > 0:
+            #         bias.append(
+            #             self.bias[i].unsqueeze(0).expand(
+            #                 fwd_expert_count[i], -1
+            #             )
+            #         )
+            # bias = torch.cat(bias, dim=0)
             x = x + bias
         return x
 
     def extra_repr(self) -> str:
         return 'num_expert={}, in_features={}, \
 out_features={}, bias={}, rank={}'.format(
                 self.num_expert, self.in_feat, self.out_feat,
                 self.bias is not None, self.rank)
```
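The TODO above sketches three ways to expand the per-expert bias of shape (num_expert, out_feat) into a per-token bias. The following stand-alone snippet (stock PyTorch only, not part of the repository) implements all three and checks that they agree in the forward pass; the tensors and counts are made up for illustration, and tokens are assumed to be grouped by expert, as they are after FastMoE's scatter step.

```python
import torch

num_expert, out_feat = 4, 8
bias = torch.randn(num_expert, out_feat)        # per-expert bias rows
fwd_expert_count = torch.tensor([3, 0, 5, 2])   # tokens routed to each expert

# Solution 1: repeat each bias row according to its expert's token count.
b1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

# Solution 2: build an expert-id index per token, then gather rows.
bias_idx = torch.arange(num_expert).repeat_interleave(fwd_expert_count)
b2 = bias[bias_idx]

# Solution 3: explicit loop, skipping experts that received no tokens.
chunks = [
    bias[i].unsqueeze(0).expand(int(fwd_expert_count[i]), -1)
    for i in range(num_expert)
    if fwd_expert_count[i] > 0
]
b3 = torch.cat(chunks, dim=0)

assert torch.equal(b1, b2) and torch.equal(b1, b3)
print(b1.shape)  # torch.Size([10, 8]), one bias row per routed token
```

The forward results are identical; the commit comment only claims they differ in backward behaviour and speed, which this snippet does not measure.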
fmoe/megatron.py (view file @ 1cfc5462)

```diff
@@ -3,8 +3,6 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
-import torch
-
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 from .utils import get_torch_default_comm
@@ -28,12 +26,10 @@ class MegatronMLP(FMoETransformerMLP):
                 d_model=args.hidden_size,
                 d_hidden=args.hidden_hidden_size,
                 world_size=world_size,
                 mp_group=group,
                 expert_dp_comm='none' if args.distributed_experts else 'dp')
-        self.bias = torch.nn.parameter.Parameter(
-            torch.zeros(args.hidden_size, dtype=torch.float32)
-        )
 
     def forward(self, inp):
-        return super().forward(inp), self.bias
+        output = super().forward(inp)
+        bias = output.new_zeros(output.size(-1), requires_grad=False)
+        return output, bias
 
 def fmoefy(model, num_experts=None, distributed_experts=True,
```
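MegatronMLP keeps returning an (output, bias) pair because, as far as I understand Megatron-LM v2.0, the surrounding transformer block expects its MLP to return the output and a separate bias that the caller adds itself. After this change the returned bias is a zero tensor created from the output, so it lands on the right device and dtype without wasting a trainable parameter. A small illustrative sketch with a plain linear layer standing in for MegatronMLP (the class name here is hypothetical, not from the repository):

```python
import torch
from torch import nn


class FakeMegatronMLP(nn.Module):
    """Stand-in that mimics the (output, bias) return convention."""

    def __init__(self, hidden_size):
        super().__init__()
        self.core = nn.Linear(hidden_size, hidden_size)

    def forward(self, inp):
        output = self.core(inp)
        # Same trick as the diff: a zero, non-learnable bias on output's device/dtype.
        bias = output.new_zeros(output.size(-1), requires_grad=False)
        return output, bias


x = torch.randn(2, 5, 16)
out, bias = FakeMegatronMLP(16)(x)
hidden = out + bias            # what a Megatron-style caller would do; a no-op here
assert torch.equal(hidden, out)
```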
fmoe/transformer.py (view file @ 1cfc5462)

```diff
@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
             d_hidden=4096, world_size=1, mp_group=None,
-            activation=torch.nn.functional.gelu,
+            activation=torch.nn.GELU(),
             gate=NaiveGate, top_k=2,
-            do_lnorm=False, pre_lnorm=False,
-            expert_dp_comm='none',
-            dropout=0.1):
+            expert_dp_comm='none'):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
-        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
-        if do_lnorm:
-            self.layer_norm = nn.LayerNorm(d_model)
-            self.pre_lnorm = pre_lnorm
-        else:
-            self.pre_lnorm = None
         self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
         '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm is not None and self.pre_lnorm:
-            inp = self.layer_norm(inp)
         output = super().forward(inp)
-        output = self.dropout(output)
-        output += inp
-        if self.pre_lnorm is not None and not self.pre_lnorm:
-            output = self.layer_norm(output)
         return output.reshape(original_shape)
```
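After this simplification FMoETransformerMLP only flattens the input to a (num_tokens, d_model) batch, runs the MoE forward, and restores the original shape; normalization, dropout, and the residual connection are the caller's job, as in the transformer-xl example above. The default activation also becomes a torch.nn.GELU() module instance rather than the functional torch.nn.functional.gelu; both are callables, so the experts can apply either one the same way. A tiny sketch of the reshape contract, with an identity standing in for the MoE call (the helper name is made up for illustration):

```python
import torch


def moe_forward_flat(inp: torch.Tensor, d_model: int) -> torch.Tensor:
    """Fold leading dims, run the (stand-in) MoE on flat tokens, restore shape."""
    original_shape = inp.shape
    flat = inp.reshape(-1, d_model)     # (batch * seq, d_model)
    out = flat                          # stand-in for the FMoE forward pass
    return out.reshape(original_shape)  # back to (batch, seq, d_model)


x = torch.randn(3, 7, 32)
assert moe_forward_flat(x, 32).shape == x.shape
```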