OpenDAS / FastMoE / Commits / 6b8d2f2e
Commit 6b8d2f2e authored Feb 03, 2021 by Rick Ho
fmoefy

parent 4b650671
Showing 3 changed files with 38 additions and 32 deletions
README.md         +16  -25
fmoe/layers.py     +3   -3
fmoe/megatron.py  +19   -4
README.md · View file @ 6b8d2f2e

...
@@ -24,6 +24,21 @@ using Fast MoE for training.

## Usage

### FMoEfy a transformer model

Transformers are currently the most popular models to be extended by MoE. Using
Fast MoE, a transformer-based model can be extended to MoE with a one-key plugin,
as shown below.

Assume that there is a PyTorch model `model` with MLP layers located at
`model.language_model.transformer.layers[<idx>].mlp`. Use the following two
lines to easily scale the MLP layers up to multiple experts.

```python
from fmoe.megatron import fmoefy
model = fmoefy(model, num_experts=<number of experts per worker>)
```
### Using Fast MoE as a PyTorch module

Examples can be seen in [examples](examples/). The easiest way is to replace the
...
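As a rough sketch of using Fast MoE as a plain PyTorch module, assuming the constructor and forward signatures visible in the fmoe/layers.py and fmoe/megatron.py hunks below (a positional expert count, `d_model`/`d_hidden` keywords, and a forward that returns an `(output, bias)` pair); the tensor layout and the concrete sizes here are assumptions, not documented behavior:

```python
import torch
from fmoe.layers import FMoETransformerMLP

# Hypothetical single-worker usage; argument names follow the
# create_moe_mlp call in fmoe/megatron.py below.
moe_mlp = FMoETransformerMLP(4, d_model=1024, d_hidden=4096, world_size=1)

x = torch.randn(16, 8, 1024)   # assumed (sequence, batch, d_model) input
out, bias = moe_mlp(x)         # forward returns (output, bias) per layers.py
print(out.shape)               # note: the real kernels may require a CUDA device
```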
@@ -38,28 +53,4 @@ NCCL backend is required to be built with PyTorch. Use environment variable
`USE_NCCL=1` with `setup.py` to enable distributing experts across workers. Note
that the parameters of the MoE layers should then be excluded from the data-parallel
parameter synchronization list.
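One way to perform that exclusion (a sketch only; the `is_moe_param` marker and the helper name are hypothetical, not FastMoE API) is to filter expert parameters out of the list handed to the data-parallel gradient all-reduce:

```python
# Hypothetical helper: build the parameter list that the data-parallel
# wrapper all-reduces, skipping parameters flagged as MoE expert weights.
def data_parallel_params(model):
    return [
        p for p in model.parameters()
        if not getattr(p, "is_moe_param", False)  # hypothetical marker flag
    ]
```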
## Feature Roadmap

### Better all-to-all communication efficiency and computation performance

The dispatching process from the source worker to the expert is time-consuming and
topology-aware, as it is an all-to-all communication. Overlapping or other
communication-reduction techniques can be applied to reduce the overhead of
this step. However, this demands considerable research and coding effort.

### Dynamic expert distribution load balancing

Load imbalance is observed because there is no loss term for load balancing: some
experts are called significantly more often than others. A dynamic scheduler that
duplicates or recycles some experts on some workers may therefore be effective.

### Model-parallelize the experts

To enable larger expert sizes.

### Use a zero-optimizer to reduce memory consumption

### Integrate the top-k gate into the local scatter/gather
fmoe/layers.py · View file @ 6b8d2f2e

...
@@ -71,7 +71,7 @@ class FMoETransformerMLP(nn.Module):
        world_size=1,
        model_parallel_size=1,
        model_parallel_rank=1,
        group=None,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        top_k=2,
        pre_lnorm=False,
...
@@ -83,7 +83,7 @@ class FMoETransformerMLP(nn.Module):
        self.world_size = world_size
        self.model_parallel_size = model_parallel_size
        self.model_parallel_rank = model_parallel_rank
        self.group = group
        self.mp_group = mp_group
        self.activation = activation
        self.pre_lnorm = pre_lnorm
        self.top_k = top_k
...
@@ -140,7 +140,7 @@ class FMoETransformerMLP(nn.Module):
        world_size = self.model_parallel_size
        tensor_list = [torch.empty_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, output, group=self.group)
        torch.distributed.all_gather(tensor_list, output, group=self.mp_group)
        output = torch.cat(tensor_list, dim=1)
        return output.reshape(original_shape), self.bias
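The last hunk gathers each model-parallel rank's partial output and concatenates the slices along dimension 1; the two adjacent `all_gather` calls appear to be the before/after pair of the change, moving the gather from `self.group` to `self.mp_group`. A self-contained sketch of that gather-and-concatenate pattern, with hypothetical names rather than FastMoE code:

```python
import torch
import torch.distributed as dist

def gather_along_dim1(output: torch.Tensor, mp_group=None) -> torch.Tensor:
    """Collect every rank's slice of `output` and stitch the slices along dim 1."""
    world_size = dist.get_world_size(group=mp_group)
    tensor_list = [torch.empty_like(output) for _ in range(world_size)]
    # Collective call: every rank in mp_group must reach this line.
    dist.all_gather(tensor_list, output, group=mp_group)
    return torch.cat(tensor_list, dim=1)
```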
fmoe/megatron.py · View file @ 6b8d2f2e

...
@@ -4,15 +4,30 @@ from .layers import FMoETransformerMLP

def create_moe_mlp(args, model_parallel_rank, group):
    assert (
        args.seq_length * args.batch_size % args.model_parallel_size == 0
    ), "Num experts should be multiple of mp size"
    num_experts = args.num_experts // args.model_parallel_size
    ), "Batch size x sequence length should be multiple of mp size"
    fmoe = FMoETransformerMLP(
        num_experts,
        args.num_experts,
        d_model=args.hidden_size,
        d_hidden=args.hidden_size * 4,
        world_size=args.world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
        group=group,
        mp_group=group,
    )
    return fmoe


def fmoefy(model, num_experts=None):
    from megatron import get_args
    from megatron import mpu
    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        'num_experts' in args
    ), 'num_experts should be specified in arguments or fmoefy function'
    for l in model.language_model.transformer.layers:
        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_rank(), mpu.get_model_parallel_group())
    return model
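For context, a minimal sketch of how this entry point might be driven from a Megatron-LM training script; `build_megatron_model` is a hypothetical placeholder and the expert count is arbitrary, with megatron's `get_args()` and `mpu` assumed to be initialized by the surrounding script:

```python
from fmoe.megatron import fmoefy

model = build_megatron_model()        # hypothetical helper, not FastMoE API
model = fmoefy(model, num_experts=4)  # 4 experts per worker

# Every transformer layer's MLP should now be an FMoETransformerMLP.
for layer in model.language_model.transformer.layers:
    assert type(layer.mlp).__name__ == "FMoETransformerMLP"
```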