Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
f6afdbee
Commit
f6afdbee
authored
Feb 24, 2021
by
Rick Ho
Browse files
add megatron back
parent
72e40c74
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
139 additions
and
0 deletions
+139
-0
fmoe/gates.py
fmoe/gates.py
+14
-0
fmoe/megatron.py
fmoe/megatron.py
+125
-0
No files found.
fmoe/gates.py
View file @
f6afdbee
...
@@ -7,6 +7,20 @@ import torch.nn as nn
...
@@ -7,6 +7,20 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
class
ZeroGate
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
num_expert
,
world_size
,
top_k
=
2
):
super
().
__init__
()
def
forward
(
self
,
inp
):
r
'''
The naive implementation simply calculates the top-k of a linear layer's
output.
'''
idx
=
torch
.
zeros
(
inp
.
shape
[
0
],
dtype
=
torch
.
int64
,
device
=
inp
.
device
)
score
=
torch
.
ones
(
inp
.
shape
[
0
],
device
=
inp
.
device
)
return
idx
,
score
.
reshape
(
-
1
,
1
,
1
)
class
NaiveGate
(
nn
.
Module
):
class
NaiveGate
(
nn
.
Module
):
r
'''
r
'''
A naive gate implementation that defines the standard behavior of the gate
A naive gate implementation that defines the standard behavior of the gate
...
...
fmoe/megatron.py
0 → 100644
View file @
f6afdbee
r
'''
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.transformer
import
FMoETransformerMLP
from
.distributed
import
DistributedGroupedDataParallel
from
.utils
import
get_torch_default_comm
class
_MegatronMLP
(
nn
.
Module
):
def
__init__
(
self
,
args
,
group
):
super
().
__init__
()
self
.
fc1
=
nn
.
Linear
(
args
.
hidden_size
,
args
.
hidden_hidden_size
)
self
.
fc2
=
nn
.
Linear
(
args
.
hidden_hidden_size
,
args
.
hidden_size
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
F
.
gelu
(
x
)
x
=
self
.
fc2
(
x
)
return
x
,
torch
.
zeros_like
(
x
)
class
MegatronMLP
(
FMoETransformerMLP
):
r
'''
Make the FMoETransformerMLP layer that distributes experts across
communication group `group` to replace the original MLP layer in Megatron.
'''
def
__init__
(
self
,
args
,
group
):
assert
(
args
.
seq_length
*
args
.
micro_batch_size
%
args
.
tensor_model_parallel_size
==
0
),
"Batch size x sequence length should be multiple of mp size"
if
not
args
.
distributed_experts
:
world_size
=
1
else
:
world_size
=
args
.
world_size
super
().
__init__
(
args
.
num_experts
,
top_k
=
args
.
top_k
,
d_model
=
args
.
hidden_size
,
d_hidden
=
args
.
hidden_hidden_size
,
world_size
=
world_size
,
mp_group
=
group
,
expert_dp_comm
=
'none'
if
args
.
distributed_experts
else
'dp'
)
self
.
hidden_size
=
args
.
hidden_size
def
forward
(
self
,
inp
):
return
super
().
forward
(
inp
),
torch
.
zeros
(
self
.
hidden_size
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
def
fmoefy
(
model
,
num_experts
=
None
,
distributed_experts
=
True
,
hidden_hidden_size
=
None
,
top_k
=
None
):
r
'''
Replace MLP layers in a transformer-based model in Megatron by MoE.
* `model` should be a standard Megatron model that has
`model.language_model.transformer.layers` as transformer layers, which is an
array of transformer blocks that contain an `mlp` member.
* `distributed_expert` is set to True if different experts are located in
different workers. Otherwise, the experts on the workers are identical, and
they are trained in data-parallel mode. This can be useful when testing on
small models that do not require high training throughput or large parameter
capacity.
Note that pipeline parallel is not supported yet. When distributed experts
are enabled, their communicator should be Megatron's
tensor_model_parall_comm x data_parallel_comm, which is not created.
'''
from
megatron
import
get_args
args
=
get_args
()
if
num_experts
is
not
None
:
args
.
num_experts
=
num_experts
assert
(
'num_experts'
in
args
),
'num_experts should be specified in arguments or fmoefy function'
if
hidden_hidden_size
is
not
None
:
args
.
hidden_hidden_size
=
hidden_hidden_size
elif
not
hasattr
(
args
,
'hidden_hidden_size'
):
args
.
hidden_hidden_size
=
args
.
hidden_size
*
4
if
top_k
is
not
None
:
args
.
top_k
=
top_k
elif
not
hasattr
(
args
,
'top_k'
):
args
.
top_k
=
2
# Set distributed_experts to None to use default setting in args
if
distributed_experts
is
not
None
:
args
.
distributed_experts
=
distributed_experts
for
l
in
model
.
language_model
.
transformer
.
layers
:
l
.
mlp
=
MegatronMLP
(
args
,
None
)
return
model
class
DistributedDataParallel
(
DistributedGroupedDataParallel
):
r
'''
A wrapper that is used to replace the DDP module provided by Megatron, which
is adapted to enable the sophiscated parallel and reduction strategies in
Fast MoE.
'''
def
__init__
(
self
,
module
):
from
megatron
import
mpu
super
().
__init__
(
module
,
mp_group
=
mpu
.
get_model_parallel_group
(),
dp_group
=
mpu
.
get_data_parallel_group
()
)
def
state_dict
(
self
,
*
args
,
**
kwargs
):
r
'''
Keep consitency with Megatron
'''
return
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
def
state_dict_for_save_checkpoint
(
self
,
*
args
,
**
kwargs
):
r
'''
Keep consitency with Megatron
'''
return
self
.
module
.
state_dict_for_save_checkpoint
(
*
args
,
**
kwargs
)
def
load_state_dict
(
self
,
*
args
,
**
kwargs
):
r
'''
Keep consitency with Megatron
'''
return
self
.
module
.
load_state_dict
(
*
args
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment