OpenDAS / FastMoE · Commits · 1cfc5462

Unverified commit 1cfc5462, authored Feb 26, 2021 by Rick Ho, committed by GitHub on Feb 26, 2021.

Merge pull request #8 from laekov/simple_mlp

Simple mlp

Parents: b56c8043, 1dc7e73e
Showing 4 changed files with 54 additions and 45 deletions (+54 -45):

examples/transformer-xl/mem_transformer.py   +22  -4
fmoe/layers.py                               +28  -17
fmoe/megatron.py                             +3   -7
fmoe/transformer.py                          +1   -17
examples/transformer-xl/mem_transformer.py

@@ -384,11 +384,29 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
             nn.Dropout(dropout)
         )
         super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
-                         do_lnorm=True, pre_lnorm=pre_lnorm, activation=activation, dropout=dropout)
+                         activation=activation)
+        self.pre_lnorm = pre_lnorm
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
-        x = super().forward(x)
-        return x
+    def forward(self, inp):
+        if self.pre_lnorm:
+            ##### layer normalization + positionwise feed-forward
+            core_out = super().forward(self.layer_norm(inp))
+            core_out = self.dropout(core_out)
+
+            ##### residual connection
+            output = core_out + inp
+        else:
+            ##### positionwise feed-forward
+            core_out = super().forward(inp)
+            core_out = self.dropout(core_out)
+
+            ##### residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
 
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
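The effect of this hunk is that the Transformer-XL example now applies layer normalization, dropout, and the residual connection itself instead of delegating them to FMoETransformerMLP. As a rough standalone illustration of the two branches in the new forward() (using a plain nn.Sequential in place of the MoE feed-forward; all names and sizes here are made up for the sketch):

```python
import torch
import torch.nn as nn

d_model = 16
# Stand-in for the MoE positionwise feed-forward (super().forward in the diff).
ffn = nn.Sequential(nn.Linear(d_model, 64), nn.GELU(), nn.Linear(64, d_model))
layer_norm = nn.LayerNorm(d_model)
dropout = nn.Dropout(0.1)

inp = torch.randn(4, d_model)

# pre_lnorm=True: normalize first, then feed-forward, then add the raw input back.
pre_norm_out = dropout(ffn(layer_norm(inp))) + inp

# pre_lnorm=False: feed-forward first, then normalize the residual sum.
post_norm_out = layer_norm(inp + dropout(ffn(inp)))
```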
fmoe/layers.py

@@ -61,22 +61,33 @@ class FMoELinear(nn.Module):
         '''
         x = MOELinear.apply(inp, self.weight, fwd_expert_count)
         if self.bias is not None:
-            # TODO: torch.repeat_interleave seems have wrong
-            # behaviors in backward, leading to incorrect
-            # gradient computation for bias.
-            # Thus we use a for-loop to manually expand the bias.
-            # This part should finally goes to MOELinear.apply.
-            # bias = torch.repeat_interleave(self.bias,
-            #     fwd_expert_count.to(self.bias.device), dim=0)
-            bias = []
-            for i in range(self.num_expert):
-                if fwd_expert_count[i] > 0:
-                    bias.append(
-                        self.bias[i].unsqueeze(0).expand(
-                            fwd_expert_count[i], -1
-                        )
-                    )
-            bias = torch.cat(bias, dim=0)
+            # TODO: torch.repeat_interleave seems have numerical
+            # instability in backward, leading to incorrect
+            # gradient computation for solution 1 and 2.
+            # Solution 3 uses a for-loop to expand the bias,
+            # but is 50% slower.
+            # This part should finally goes to MOELinear.apply,
+            # like MOELinear.apply(x, weight, bias, count)
+
+            # Solution 1
+            bias = torch.repeat_interleave(self.bias,
+                fwd_expert_count.to(self.bias.device), dim=0
+            )
+
+            # Solution 2
+            # bias_idx = torch.arange(self.num_expert)\
+            #         .repeat_interleave(fwd_expert_count)
+            # bias = self.bias[bias_idx]
+
+            # Solution 3
+            # bias = []
+            # for i in range(self.num_expert):
+            #     if fwd_expert_count[i] > 0:
+            #         bias.append(
+            #             self.bias[i].unsqueeze(0).expand(
+            #                 fwd_expert_count[i], -1
+            #             )
+            #         )
+            # bias = torch.cat(bias, dim=0)
+
             x = x + bias
         return x
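All three "solutions" in the new comment expand the per-expert bias of shape (num_expert, d_model) into one bias row per routed token, ordered by expert. A small self-contained check of that equivalence (hypothetical sizes; the backward-pass stability concern mentioned in the TODO is not exercised here):

```python
import torch

num_expert, d_model = 4, 8
bias = torch.randn(num_expert, d_model)
fwd_expert_count = torch.tensor([3, 0, 2, 1])   # tokens routed to each expert

# Solution 1: repeat each expert's bias row `count` times.
b1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

# Solution 2: build a per-token expert index, then gather rows.
bias_idx = torch.arange(num_expert).repeat_interleave(fwd_expert_count)
b2 = bias[bias_idx]

# Solution 3: explicit loop, skipping experts that received no tokens.
chunks = [bias[i].unsqueeze(0).expand(c, -1)
          for i, c in enumerate(fwd_expert_count.tolist()) if c > 0]
b3 = torch.cat(chunks, dim=0)

assert b1.shape == (int(fwd_expert_count.sum()), d_model)
assert torch.equal(b1, b2) and torch.equal(b1, b3)
```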
fmoe/megatron.py

@@ -3,8 +3,6 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
-import torch
-
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 from .utils import get_torch_default_comm

@@ -28,12 +26,10 @@ class MegatronMLP(FMoETransformerMLP):
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
             world_size=world_size, mp_group=group,
             expert_dp_comm='none' if args.distributed_experts else 'dp')
-        self.bias = torch.nn.parameter.Parameter(
-            torch.zeros(args.hidden_size, dtype=torch.float32)
-        )
 
     def forward(self, inp):
-        return super().forward(inp), self.bias
+        output = super().forward(inp)
+        bias = output.new_zeros(output.size(-1), requires_grad=False)
+        return output, bias
 
 def fmoefy(model, num_experts=None, distributed_experts=True,
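MegatronMLP keeps returning an (output, bias) pair because Megatron-LM's transformer layer applies the MLP bias together with dropout and the residual connection outside the MLP module. Since FastMoE already folds its expert biases into output, the adaptor now hands back a throwaway all-zero bias created with new_zeros (same dtype and device as output) instead of a separate learnable Parameter, which also removes the need for the torch import. A rough sketch of the caller-side arithmetic this keeps intact (shapes and the dropout probability are illustrative, not taken from Megatron):

```python
import torch
import torch.nn.functional as F

# Stand-ins for what MegatronMLP.forward now returns.
output = torch.randn(4, 16)                                    # super().forward(inp)
bias = output.new_zeros(output.size(-1), requires_grad=False)  # all-zero placeholder

# Megatron-style bias-dropout-add performed by the enclosing transformer layer:
residual = torch.randn(4, 16)
layer_out = F.dropout(output + bias, p=0.1) + residual         # `+ bias` is a no-op here
```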
fmoe/transformer.py

@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
                  d_hidden=4096,
                  world_size=1,
                  mp_group=None,
-                 activation=torch.nn.functional.gelu,
+                 activation=torch.nn.GELU(),
                  gate=NaiveGate,
                  top_k=2,
-                 do_lnorm=False,
-                 pre_lnorm=False,
                  expert_dp_comm='none',
-                 dropout=0.1
                  ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                          top_k=top_k, world_size=world_size, mp_group=mp_group)
-        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                                rank=self.mp_rank)
-        self.pre_lnorm = pre_lnorm
-        if do_lnorm:
-            self.layer_norm = nn.LayerNorm(d_model)
-            self.pre_lnorm = pre_lnorm
-        else:
-            self.pre_lnorm = None
         self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):

@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
         '''
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.pre_lnorm is not None and self.pre_lnorm:
-            inp = self.layer_norm(inp)
         output = super().forward(inp)
-        output = self.dropout(output)
-        output += inp
-        if self.pre_lnorm is not None and not self.pre_lnorm:
-            output = self.layer_norm(output)
         return output.reshape(original_shape)
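After this commit, FMoETransformerMLP only reshapes the input to (-1, d_model), runs the MoE feed-forward, and restores the original shape; layer normalization, dropout, and the residual connection are left to the caller, as the mem_transformer.py change above now does. A minimal usage sketch of that contract (illustrative sizes; FastMoE's expert kernels are CUDA extensions, so this assumes the fmoe_cuda extension is built and a GPU is available):

```python
import torch
import torch.nn.functional as F
from fmoe.transformer import FMoETransformerMLP

# Small, illustrative configuration; defaults cover gate=NaiveGate and top_k=2.
mlp = FMoETransformerMLP(num_expert=4, d_model=16, d_hidden=32,
                         activation=torch.nn.GELU()).cuda()

x = torch.randn(10, 2, 16, device='cuda')   # (seq_len, batch, d_model)
y = mlp(x)
assert y.shape == x.shape                   # reshape(original_shape) at the end

# Norm / dropout / residual now live in the caller, e.g. a post-norm block:
layer_norm = torch.nn.LayerNorm(16).cuda()
block_out = layer_norm(x + F.dropout(y, p=0.1))
```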