OpenDAS / FastMoE — Commits
Commit 4d48209d
Authored Mar 09, 2021 by Sengxian

Format using black

Parent: 527a8cc9
Showing 12 changed files with 323 additions and 214 deletions (+323 -214)
.pylintrc                 +1   -0
fmoe/distributed.py       +26  -20
fmoe/functions.py         +11  -7
fmoe/gates.py             +16  -14
fmoe/layers.py            +65  -45
fmoe/megatron.py          +62  -42
fmoe/transformer.py       +26  -19
fmoe/utils.py             +5   -5
tests/benchmark_mlp.py    +76  -43
tests/moe.py              +7   -11
tests/test_dp.py          +2   -2
tests/test_numerical.py   +26  -6
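The commit is a purely mechanical formatting pass. As a hedged aside (the exact black version and options used are not recorded in this diff; defaults are assumed), the same normalization can be reproduced or checked through black's Python API:

import black

# Hedged sketch only: black.format_str applies the rules visible in the hunks
# below -- double quotes, trailing commas, and re-wrapped long signatures.
# The black version behind this commit is an assumption, not recorded here.
source = "def f(a = 1, b = 'x'):\n    return {'k' : a}\n"
print(black.format_str(source, mode=black.FileMode()))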
.pylintrc

@@ -142,6 +142,7 @@ disable=print-statement,
         arguments-differ,
         import-outside-toplevel,
         signature-differs,
+        bad-continuation,

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
fmoe/distributed.py

-r'''
+r"""
 Supportive modules to conduct distributed training
-'''
+"""
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

@@ -8,7 +8,7 @@ from .utils import get_torch_default_comm
 class DistributedGroupedDataParallel(nn.Module):
-    r'''
+    r"""
     A customized DDP module to support different all-reduce regions in the
     model. The all-reduce region is defined as an attribution `dp_comm` in the
     weight object.

@@ -20,36 +20,42 @@ class DistributedGroupedDataParallel(nn.Module):
     If it is set to `world`, the gradients is synchronized across all workers,
     regardless their model or data parallel group. This is extremely useful for
     shared layers like the gate.
-    '''
+    """

-    def __init__(self, module, mp_group=None, dp_group=None, world_group=None, auto_allreduce=False):
-        assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
+    def __init__(
+        self,
+        module,
+        mp_group=None,
+        dp_group=None,
+        world_group=None,
+        auto_allreduce=False,
+    ):
+        assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
         super().__init__()
         self.module = module
         self.comms = dict()
         if mp_group is not None:
-            self.comms['mp'] = mp_group
+            self.comms["mp"] = mp_group
         if dp_group is not None:
-            self.comms['dp'] = dp_group
+            self.comms["dp"] = dp_group
         else:
-            self.comms['dp'] = get_torch_default_comm()
+            self.comms["dp"] = get_torch_default_comm()
         if world_group is None:
-            self.comms['world'] = get_torch_default_comm()
+            self.comms["world"] = get_torch_default_comm()
         else:
-            self.comms['world'] = world_group

-        def allreduce_params(no_scale=False, reduce_after=False,
-                fp32_allreduce=False):
+            self.comms["world"] = world_group

+        def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
             groups = dict()
             for p in self.module.parameters():
                 if not p.requires_grad or p.grad is None:
                     continue
-                if hasattr(p, 'dp_comm'):
+                if hasattr(p, "dp_comm"):
                     dp_comm = p.dp_comm
                 else:
-                    dp_comm = 'dp'
+                    dp_comm = "dp"
                 group_key = (dp_comm, p.dtype)
                 if group_key not in groups:
                     groups[group_key] = [p]

@@ -81,10 +87,10 @@ class DistributedGroupedDataParallel(nn.Module):
             for p in self.module.parameters():
                 if not p.requires_grad or p.grad is None:
                     continue
-                if hasattr(p, 'dp_comm'):
+                if hasattr(p, "dp_comm"):
                     dp_comm = p.dp_comm
                 else:
-                    dp_comm = 'dp'
+                    dp_comm = "dp"
                 group_key = (dp_comm, p.dtype)
                 if group_key not in groups:
                     groups[group_key] = [p]

@@ -103,7 +109,7 @@ class DistributedGroupedDataParallel(nn.Module):
             d.copy_(s)

     def forward(self, *args, **kwargs):
-        r'''
+        r"""
         Directly call the module's forward function.
-        '''
+        """
         return self.module(*args, **kwargs)
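For context on the `dp_comm` mechanism touched above, here is a hedged usage sketch (it assumes torch.distributed has already been initialized so the default communicator behind the "dp" and "world" regions exists, and a CUDA device is available; the training loop is omitted):

import torch
import torch.nn as nn
from fmoe.distributed import DistributedGroupedDataParallel

model = nn.Linear(16, 16).cuda()
# A parameter opts into a different all-reduce region through its `dp_comm`
# attribute; the docstring above suggests "world" for shared layers such as
# the gate so their gradients are synchronized across every worker.
setattr(model.weight, "dp_comm", "world")

ddp = DistributedGroupedDataParallel(model)
out = ddp(torch.randn(4, 16).cuda())  # forward() simply delegates to the wrapped module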
fmoe/functions.py

@@ -40,8 +40,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
-    fwd_expert_count = global_expert_count.view(world_size,
-            num_expert).sum(dim=0)
+    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,

@@ -58,6 +57,7 @@ class MOEScatter(Function):
     If `world_size` is greater than 1, the samples will first be locally
     scattered, and then exchanged across workers.
     """
+
     @staticmethod
     def forward(
         ctx,

@@ -107,6 +107,7 @@ class MOELinear(Function):
     r"""
     Computes linear operators within one GPU on different experts simutaneously.
     """
+
     @staticmethod
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
         (global_output_buf,) = fmoe_cuda.forward(

@@ -130,6 +131,7 @@ class MOEGather(Function):
     Gather output samples from contiguous alone experts back to [batch x
     sequences]. Works symmetrically with MOEScatter.
     """
+
     @staticmethod
     def forward(
         ctx,

@@ -176,9 +178,10 @@ class MOEGather(Function):
 class AllGather(Function):
-    r'''
+    r"""
     A wrapper for the All-Gather function to support auto-differentiation.
-    '''
+    """
+
     @staticmethod
     def forward(ctx, inp, rank, world_size, group):
         tensor_list = [torch.empty_like(inp) for _ in range(world_size)]

@@ -191,13 +194,14 @@ class AllGather(Function):
     @staticmethod
     def backward(ctx, grad_out):
         rank, dim0 = ctx.args
         return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None


 class Slice(Function):
-    r'''
+    r"""
     A wrapper for the Slice function to support auto-differentiation.
-    '''
+    """
+
     @staticmethod
     def forward(ctx, inp, rank, world_size, group):
         B: int = inp.shape[0]
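The classes in this file are thin torch.autograd.Function wrappers. Below is a hedged, single-process sketch of the same pattern that AllGather's backward uses (keeping only this rank's slice of the incoming gradient); the names are illustrative, it performs no real communication, and it is not part of the fmoe API:

import torch
from torch.autograd import Function

class FakeAllGather(Function):
    @staticmethod
    def forward(ctx, inp, rank, world_size):
        # Remember which slice of the gathered tensor belongs to this rank.
        ctx.args = rank, inp.shape[0]
        # Pretend every rank holds an identical shard (no dist.all_gather here).
        return torch.cat([inp] * world_size, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        # Keep only the gradient rows that correspond to this rank's input.
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None

x = torch.randn(4, 8, requires_grad=True)
y = FakeAllGather.apply(x, 0, 3)
y.sum().backward()  # x.grad has shape (4, 8)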
fmoe/gates.py

-r'''
+r"""
 Different implementations of the Gate are located here.
 The `NaiveGate` is the reference to implement any other gate.
-'''
+"""
 import torch
 import torch.nn as nn
 import torch.nn.functional as F


 class ZeroGate(nn.Module):
-    r'''
+    r"""
     Guide all input samples to gate 0.
-    '''
+    """
+
     def __init__(self, _1, _2, _3, top_k=2):
         super().__init__()
         self.top_k = top_k

     def forward(self, inp):
-        r'''
+        r"""
         All output to expert 1
-        '''
-        idx = torch.zeros(inp.shape[0] * self.top_k,
-                dtype=torch.int64, device=inp.device)
-        score = torch.ones(inp.shape[0] * self.top_k,
-                device=inp.device) / self.top_k
+        """
+        idx = torch.zeros(
+            inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
+        )
+        score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
         return idx, score.reshape(-1, 1, self.top_k)


 class NaiveGate(nn.Module):
-    r'''
+    r"""
     A naive gate implementation that defines the standard behavior of the gate
     which determines which experts the tokens are going to.
     Both the indecies and the score, or confidence, are output to the parent
     module.
     The load-balance strategies are also designed to be implemented within the
     `Gate` module.
-    '''
+    """
+
     def __init__(self, d_model, num_expert, world_size, top_k=2):
         super().__init__()
         self.gate = nn.Linear(d_model, num_expert * world_size)
         self.top_k = top_k

     def forward(self, inp):
-        r'''
+        r"""
         The naive implementation simply calculates the top-k of a linear layer's
         output.
-        '''
+        """
         gate = self.gate(inp)
         gate_top_k_val, gate_top_k_idx = torch.topk(
             gate, k=self.top_k, dim=-1, largest=True, sorted=False
 ...
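A hedged usage sketch for NaiveGate, with the constructor arguments as defined above (the tail of its forward is truncated in this hunk, so the exact output shapes are not restated here; the indices identify chosen experts and the scores their weights):

import torch
from fmoe.gates import NaiveGate

# One gate over 4 experts on a single worker, routing each token to its top 2.
gate = NaiveGate(d_model=16, num_expert=4, world_size=1, top_k=2)
tokens = torch.randn(8, 16)   # 8 tokens of width d_model
idx, score = gate(tokens)     # expert indices and the corresponding gate scores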
fmoe/layers.py

-r'''
+r"""
 Layers that FMoE provides to users
-'''
+"""
 import torch
 import torch.nn as nn

@@ -11,14 +11,21 @@ from .gates import NaiveGate
 class FMoELinear(nn.Module):
-    r'''
+    r"""
     A linear layer that contains multiple experts.
     As multiple experts can be placed on the same worker, the computation can be
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
-    '''
-    def __init__(self, num_expert: int, in_feat: int, out_feat: int, bias: bool = True, rank: int = 0):
+    """
+
+    def __init__(
+        self,
+        num_expert: int,
+        in_feat: int,
+        out_feat: int,
+        bias: bool = True,
+        rank: int = 0,
+    ):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat

@@ -28,12 +35,12 @@ class FMoELinear(nn.Module):
         if bias:
             self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

     def forward(self, inp, fwd_expert_count):
-        r'''
+        r"""
         Call MOE function
-        '''
+        """
         x = MOELinear.apply(inp, self.weight, fwd_expert_count)
         if self.bias is not None:
             # TODO: torch.repeat_interleave seems have numerical

@@ -45,8 +52,9 @@ class FMoELinear(nn.Module):
             # like MOELinear.apply(x, weight, bias, count)
             # Solution 1
-            bias = torch.repeat_interleave(self.bias,
-                    fwd_expert_count.to(self.bias.device), dim=0)
+            bias = torch.repeat_interleave(
+                self.bias, fwd_expert_count.to(self.bias.device), dim=0
+            )
             # Solution 2
             # bias_idx = torch.arange(self.num_expert)\

@@ -67,24 +75,27 @@ class FMoELinear(nn.Module):
         return x

     def extra_repr(self) -> str:
-        return 'num_expert={}, in_features={}, \
-out_features={}, bias={}, rank={}'.format(
-            self.num_expert, self.in_feat,
-            self.out_feat, self.bias is not None, self.rank
+        return "num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}".format(
+            self.num_expert,
+            self.in_feat,
+            self.out_feat,
+            self.bias is not None,
+            self.rank,
         )


 def mark_module_parallel_comm(module, comm):
-    r'''
+    r"""
     Mark all parameters in `module` as doing data parallel in `comm`, where
     `comm` may be one of `'world', 'dp', 'none'`.
-    '''
+    """
     for p in module.parameters():
-        setattr(p, 'dp_comm', comm)
+        setattr(p, "dp_comm", comm)


 def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
-    r'''
+    r"""
     A private function that performs the following steps to complete the MoE
     computation.
     * Count the number of tokens from each worker to each expert.

@@ -94,14 +105,16 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     * Gather the output features of experts back, and reorder them as sentences.
     Intermediate results like expert counts are hidden from users by this
     function.
-    '''
+    """
     (
-        pos, local_expert_count, global_expert_count, fwd_expert_count,
-        fwd_batch_size
+        pos,
+        local_expert_count,
+        global_expert_count,
+        fwd_expert_count,
+        fwd_batch_size,
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
-        world_size
+        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(

@@ -111,7 +124,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
-    r'''
+    r"""
     A general moe implementation that supports an arbitrary module as the
     expert.
     * `num_expert` stands for the number of experts on **each** worker.

@@ -126,9 +139,18 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    '''
-    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None, top_k=2, gate=NaiveGate, expert=None):
+    """
+
+    def __init__(
+        self,
+        num_expert=32,
+        d_model=1024,
+        world_size=1,
+        mp_group=None,
+        top_k=2,
+        gate=NaiveGate,
+        expert=None,
+    ):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model

@@ -143,34 +165,33 @@ class FMoE(nn.Module):
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         if expert is not None:
-            self.experts = nn.ModuleList([expert(d_model)
-                for _ in range(num_expert)])
+            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True

     def expert_fn(self, inp, fwd_expert_count):
-        r'''
+        r"""
         The default expert function which either calls the experts as a whole
         or as separate experts.
-        '''
+        """
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         outputs = []
         base_idx = 0
         for i in range(self.num_expert):
             batch_size = fwd_expert_count[i].item()
             inp_slice = inp[base_idx : base_idx + batch_size]
             outputs.append(self.experts[i](inp_slice))
             base_idx += batch_size
         return torch.cat(outputs, dim=0)

-    def mark_parallel_comm(self, expert_dp_comm='none'):
-        r'''
+    def mark_parallel_comm(self, expert_dp_comm="none"):
+        r"""
         Automatically mark the data parallel comms of the parameters within the
         module. This can be typically called at the end of the __init__ function
         in child classes.
-        '''
+        """
         if self.experts is not None:
             comm = expert_dp_comm
             if isinstance(self.experts, list):

@@ -178,29 +199,28 @@ class FMoE(nn.Module):
                 mark_module_parallel_comm(e, comm)
             else:
                 mark_module_parallel_comm(self.experts, comm)
-        mark_module_parallel_comm(self.gate, 'world')
+        mark_module_parallel_comm(self.gate, "world")

     def forward(self, inp):
-        r'''
+        r"""
         The FMoE module first computes gate output, and then conduct MoE forward
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
-        '''
+        """
         if self.mp_size > 1:
-            inp = Slice.apply(inp,
-                self.mp_rank, self.mp_size, self.mp_group)
+            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
-            self.num_expert, self.world_size)
+        x = _fmoe_general_global_forward(
+            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+        )
         # to: (BxL) x top_k x d_model
         x = x.view(-1, self.top_k, self.d_model)
         # to: (BxL) x d_model
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
         if self.mp_size > 1:
-            x = AllGather.apply(x,
-                self.mp_rank, self.mp_size, self.mp_group)
+            x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
         return x
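A hedged sketch of the FMoE constructor documented above, using a user-supplied expert class (the expert is instantiated num_expert times with d_model as its only argument). It assumes a CUDA device and the compiled fmoe_cuda extension; TinyExpert is an illustrative name, not part of the library:

import torch
import torch.nn as nn
from fmoe.layers import FMoE

class TinyExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        # expert_fn() calls each expert on its slice of routed tokens
        return self.fc(x)

moe = FMoE(num_expert=4, d_model=32, world_size=1, top_k=2, expert=TinyExpert).cuda()
y = moe(torch.randn(16, 32).cuda())  # one output row of width d_model per input token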
fmoe/megatron.py

-r'''
+r"""
 The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
-'''
+"""
 import math
 import numpy as np
 import torch

@@ -14,27 +14,30 @@ from .distributed import DistributedGroupedDataParallel
 class _FakeMegatronMLP(nn.Module):
-    r'''
+    r"""
     A fake mlp without model parallelism for correctness testing
-    '''
+    """
+
     def __init__(self, args, _):
         super().__init__()
         self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
         self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

     def forward(self, x):
-        r'''
+        r"""
         Directly use GeLU
-        '''
+        """
         x = self.fc1(x)
         x = F.gelu(x)
         x = self.fc2(x)
         return x, torch.zeros_like(x)


 def _megatron_init_method(self, rng, sigma):
-    r'''
+    r"""
     Init method based on N(0, sigma).
     Copied from Megatron-LM
-    '''
+    """
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))

@@ -45,12 +48,13 @@ def _megatron_init_method(self, rng, sigma):
     with torch.no_grad():
         self.bias.zero_()


 def _random_init_weight(self, rng):
-    r'''
+    r"""
     Copied from torch.nn.init.kaiming_uniform_
-    '''
-    fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
-    gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
+    """
+    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
+    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
     std = gain / math.sqrt(fan)
     bound = math.sqrt(3.0) * std
     device = self.weight.device

@@ -66,23 +70,29 @@ def _random_init_weight(self, rng):
 class MegatronMLP(FMoETransformerMLP):
-    r'''
+    r"""
     Make the FMoETransformerMLP layer that distributes experts across
     communication group `group` to replace the original MLP layer in Megatron.
-    '''
+    """
+
     def __init__(self, args, group):
-        assert (args.seq_length * args.micro_batch_size
-                % args.tensor_model_parallel_size == 0
+        assert (
+            args.seq_length * args.micro_batch_size
+            % args.tensor_model_parallel_size == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
             world_size = 1
         else:
             world_size = args.world_size
-        super().__init__(args.num_experts,
-                top_k=args.top_k,
-                d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
-                world_size=world_size, mp_group=group,
-                expert_dp_comm='none' if args.distributed_experts else 'dp')
+        super().__init__(
+            args.num_experts,
+            top_k=args.top_k,
+            d_model=args.hidden_size,
+            d_hidden=args.hidden_hidden_size,
+            world_size=world_size,
+            mp_group=group,
+            expert_dp_comm="none" if args.distributed_experts else "dp",
+        )
         self.hidden_size = args.hidden_size
         if args.distributed_experts:
             self.rank = args.rank

@@ -93,24 +103,31 @@ class MegatronMLP(FMoETransformerMLP):
         self.reset_parameters()

     def reset_parameters(self):
-        r'''
+        r"""
         Initialize the weight as linear layers.
         As megatron is using fixed random seed for some nasty stuff, an
         additional numpy rng is used.
-        '''
+        """
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
         _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
         _megatron_init_method(self.experts.h4toh, rng, std)

     def forward(self, inp):
-        return super().forward(inp), torch.zeros(self.hidden_size,
-            dtype=inp.dtype, device=inp.device)
+        return (
+            super().forward(inp),
+            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
+        )


-def fmoefy(model, num_experts=None, distributed_experts=True,
-        hidden_hidden_size=None, top_k=None):
-    r'''
+def fmoefy(
+    model,
+    num_experts=None,
+    distributed_experts=True,
+    hidden_hidden_size=None,
+    top_k=None,
+):
+    r"""
     Replace MLP layers in a transformer-based model in Megatron by MoE.
     * `model` should be a standard Megatron model that has
     `model.language_model.transformer.layers` as transformer layers, which is an

@@ -123,24 +140,25 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
     Note that pipeline parallel is not supported yet. When distributed experts
     are enabled, their communicator should be Megatron's
     tensor_model_parall_comm x data_parallel_comm, which is not created.
-    '''
+    """
     from megatron import get_args
     from megatron import mpu

     args = get_args()
     if num_experts is not None:
         args.num_experts = num_experts
     assert (
-        'num_experts' in args
-    ), 'num_experts should be specified in arguments or fmoefy function'
+        "num_experts" in args
+    ), "num_experts should be specified in arguments or fmoefy function"
     if hidden_hidden_size is not None:
         args.hidden_hidden_size = hidden_hidden_size
-    elif not hasattr(args, 'hidden_hidden_size'):
+    elif not hasattr(args, "hidden_hidden_size"):
         args.hidden_hidden_size = args.hidden_size * 4
     if top_k is not None:
         args.top_k = top_k
-    elif not hasattr(args, 'top_k'):
+    elif not hasattr(args, "top_k"):
         args.top_k = 2
     # Set distributed_experts to None to use default setting in args

@@ -153,33 +171,35 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
 class DistributedDataParallel(DistributedGroupedDataParallel):
-    r'''
+    r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
     is adapted to enable the sophiscated parallel and reduction strategies in
     Fast MoE.
-    '''
+    """
+
     def __init__(self, module):
         from megatron import mpu

         super().__init__(
             module,
             mp_group=mpu.get_model_parallel_group(),
-            dp_group=mpu.get_data_parallel_group()
+            dp_group=mpu.get_data_parallel_group(),
         )

     def state_dict(self, *args, **kwargs):
-        r'''
+        r"""
         Keep consitency with Megatron
-        '''
+        """
         return self.module.state_dict(*args, **kwargs)

     def state_dict_for_save_checkpoint(self, *args, **kwargs):
-        r'''
+        r"""
         Keep consitency with Megatron
-        '''
+        """
         return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

     def load_state_dict(self, *args, **kwargs):
-        r'''
+        r"""
         Keep consitency with Megatron
-        '''
+        """
        return self.module.load_state_dict(*args, **kwargs)
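A hedged sketch of the two-line integration the module docstring above describes, inside an otherwise standard Megatron-LM v2.0 training script (the surrounding Megatron setup is assumed and not shown; it is also an assumption here that fmoefy returns the modified model, which the truncated hunk does not confirm):

from fmoe.megatron import fmoefy, DistributedDataParallel

# `model` is assumed to be a Megatron model with language_model.transformer.layers.
model = fmoefy(model, num_experts=8, top_k=2)   # swap Megatron MLPs for MoE MLPs
model = DistributedDataParallel(model)          # FastMoE-aware replacement for Megatron's DDP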
fmoe/transformer.py

-r'''
+r"""
 Adaption to act as the MLP layer using an MoE MLP layer in transformer.
-'''
+"""
 import torch
 import torch.nn as nn
 from .gates import NaiveGate

@@ -8,23 +8,22 @@ from .layers import FMoE, FMoELinear
 class _Expert(nn.Module):
-    r'''
+    r"""
     An expert using 2 FMoELinear modules to speed up the computation of experts
     within one worker.
-    '''
+    """
+
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
-                bias=True, rank=rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
-        r'''
+        r"""
         First expand input to 4h (the hidden size is variable, but is called h4
         for convenience). Then perform activation. Finally shirink back to h.
-        '''
+        """
         x = self.htoh4(inp, fwd_expert_count)
         x = self.activation(x)
         x = self.h4toh(x, fwd_expert_count)

@@ -32,11 +31,12 @@ class _Expert(nn.Module):
 class FMoETransformerMLP(FMoE):
-    r'''
+    r"""
     A complete MoE MLP module in a Transformer block.
     * `activation` is the activation function to be used in MLP in each expert.
     * `d_hidden` is the dimension of the MLP layer.
-    '''
+    """
+
     def __init__(
         self,
         num_expert=32,

@@ -47,19 +47,26 @@ class FMoETransformerMLP(FMoE):
         activation=torch.nn.GELU(),
         gate=NaiveGate,
         top_k=2,
-        expert_dp_comm='none'
+        expert_dp_comm="none",
     ):
-        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group)
-        self.experts = _Expert(num_expert, d_model, d_hidden, activation,
-                rank=self.mp_rank)
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=gate,
+            top_k=top_k,
+            world_size=world_size,
+            mp_group=mp_group,
+        )
+        self.experts = _Expert(
+            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
+        )
         self.mark_parallel_comm(expert_dp_comm)

     def forward(self, inp: torch.Tensor):
-        r'''
+        r"""
         This module wraps up the FMoE module with reshape, residual and layer
         normalization.
-        '''
+        """
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
         output = super().forward(inp)
 ...
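A hedged sketch of FMoETransformerMLP used as a drop-in transformer FFN: the forward above flattens the input to (-1, d_model) before the MoE pass, so any leading shape works. A CUDA device and the compiled fmoe extension are assumed, and the output is assumed to be restored to the input shape (the end of forward is truncated in this hunk):

import torch
from fmoe.transformer import FMoETransformerMLP

ffn = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256, top_k=2).cuda()
x = torch.randn(10, 2, 64).cuda()   # e.g. (seq_len, batch, d_model)
y = ffn(x)                          # MoE MLP applied token-wise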
fmoe/utils.py

-r'''
+r"""
 Utils to play with PyTorch.
-'''
+"""
 import torch.distributed as dist

 # pylint: disable=broad-except
 # pylint: disable=protected-access
 def get_torch_default_comm():
-    r'''
+    r"""
     The NCCL communicator is needed so that Fast MoE can perform customized
     communication operators in the C code. However, it is not a publicly
     available variable. Therefore, a hacking class of the `ProcessGroupNCCL`

@@ -15,7 +15,7 @@ def get_torch_default_comm():
     communicator out from the object. As PyTorch's private interface varies from
     time to time, different hacking techniques are tried one-by-one to be
     compatible with various versions of PyTorch.
-    '''
+    """
     try:
         comm = dist.distributed_c10d._get_default_group()
         return comm

@@ -27,4 +27,4 @@ def get_torch_default_comm():
         return comm
     except Exception as _:
         pass
-    raise RuntimeError('Unsupported PyTorch version')
+    raise RuntimeError("Unsupported PyTorch version")
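A hedged note on this helper: it can only extract a communicator after the default process group exists. A minimal sketch, where the backend choice and env-based initialization are assumptions for illustration:

import torch.distributed as dist
from fmoe.utils import get_torch_default_comm

# Assumes MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set in the environment.
dist.init_process_group(backend="nccl", init_method="env://")
comm = get_torch_default_comm()  # raises RuntimeError("Unsupported PyTorch version") if every hack fails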
tests/benchmark_mlp.py

@@ -10,18 +10,27 @@ import os
 rank = None
 world_size = None
-dev_name_default = 'cuda:0'
+dev_name_default = "cuda:0"


 class BruteForceMoE(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=1, mp_group=None,
-            activation=torch.nn.functional.gelu,
-            gate=NaiveGate, top_k=1, pre_lnorm=False):
-        assert world_size == 1, 'Distributed brute force is not supported'
+    def __init__(
+        self,
+        num_expert=32,
+        d_model=1024,
+        d_hidden=4096,
+        world_size=1,
+        mp_group=None,
+        activation=torch.nn.functional.gelu,
+        gate=NaiveGate,
+        top_k=1,
+        pre_lnorm=False,
+    ):
+        assert world_size == 1, "Distributed brute force is not supported"
         super().__init__()
-        self.mlp = BruteForceMoELinear(activation, num_expert, d_model,
-                d_hidden, 1, top_k)
+        self.mlp = BruteForceMoELinear(
+            activation, num_expert, d_model, d_hidden, 1, top_k
+        )
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         self.pre_lnorm = pre_lnorm

@@ -43,20 +52,32 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
     if rank == 0:
-        print('Performance test of {} mm size {} {}x{} experts {}x{} topk {}'.format(
-            MOELayer.__name__, batch_size, in_feat, hidden_feat,
-            world_size, num_expert, top_k))
+        print(
+            "Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format(
+                MOELayer.__name__,
+                batch_size,
+                in_feat,
+                hidden_feat,
+                world_size,
+                num_expert,
+                top_k,
+            )
+        )
     if world_size > 1:
-        dev_name = 'cuda'
+        dev_name = "cuda"
     else:
         dev_name = dev_name_default
     inp = torch.rand(batch_size, in_feat).cuda(dev_name)
     inp.requires_grad = True
-    moe = MOELayer(num_expert=num_expert,
-            d_model=in_feat, d_hidden=hidden_feat,
-            world_size=world_size, top_k=top_k).cuda(dev_name)
+    moe = MOELayer(
+        num_expert=num_expert,
+        d_model=in_feat,
+        d_hidden=hidden_feat,
+        world_size=world_size,
+        top_k=top_k,
+    ).cuda(dev_name)
     moe.train()

     # warm up

@@ -64,10 +85,10 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
         _ = moe(inp)

     n_runs = 16
-    tott = 0.
-    backt = 0.
-    maxt = 0.
-    sqtot = 0.
+    tott = 0.0
+    backt = 0.0
+    maxt = 0.0
+    sqtot = 0.0
     for i in range(n_runs):
         ts = time.time()
         o = moe(inp)

@@ -80,36 +101,48 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
         bte = time.time()
         tott += te - ts
         sqtot += (te - ts) ** 2
         maxt = max(maxt, te - ts)
         backt += bte - bts

-    gflops = 2e-9 * n_runs * (in_feat * hidden_feat * batch_size * top_k * 2 +
-            batch_size * in_feat * num_expert) / tott
-    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
-        tott * 1e3 / n_runs, maxt * 1e3,
-        (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 * top_k / n_runs,
-        backt * 1e3 / n_runs, gflops))
+    gflops = (
+        2e-9
+        * n_runs
+        * (
+            in_feat * hidden_feat * batch_size * top_k * 2
+            + batch_size * in_feat * num_expert
+        )
+        / tott
+    )
+    print(
+        "Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format(
+            tott * 1e3 / n_runs,
+            maxt * 1e3,
+            (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 * top_k / n_runs,
+            backt * 1e3 / n_runs,
+            gflops,
+        )
+    )


-if __name__ == '__main__':
-    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
-    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
-    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0')
-    if int(os.environ['WORLD_SIZE']) > 1:
-        torch.distributed.init_process_group(backend='nccl')
+if __name__ == "__main__":
+    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0")
+    if int(os.environ["WORLD_SIZE"]) > 1:
+        torch.distributed.init_process_group(backend="nccl")
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
     else:
         rank = 0
         world_size = 1
-    batch_size = int(os.environ.get('BATCH_SIZE', '4096'))
-    d_model = int(os.environ.get('D_MODEL', '1024'))
-    d_hidden = int(os.environ.get('D_HIDDEN', '4096'))
-    num_expert = int(os.environ.get('NUM_EXPERT', '64'))
-    top_k = int(os.environ.get('TOP_K', '2'))
+    batch_size = int(os.environ.get("BATCH_SIZE", "4096"))
+    d_model = int(os.environ.get("D_MODEL", "1024"))
+    d_hidden = int(os.environ.get("D_HIDDEN", "4096"))
+    num_expert = int(os.environ.get("NUM_EXPERT", "64"))
+    top_k = int(os.environ.get("TOP_K", "2"))
     benchmark_mlp(FMoETransformerMLP, batch_size, d_model,
             d_hidden, num_expert, top_k)
     if world_size == 1:
-        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert,
-                top_k)
+        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k)
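The benchmark reads its configuration from the environment variables shown above (BATCH_SIZE, D_MODEL, D_HIDDEN, NUM_EXPERT, TOP_K). A hedged sketch of driving a run programmatically with those knobs; the path to the script is relative to the repository root, and a CUDA machine with the fmoe extension is assumed:

import os
import subprocess

# Defaults match the fallbacks in the hunk above; override as needed.
os.environ.update({"BATCH_SIZE": "4096", "D_MODEL": "1024",
                   "D_HIDDEN": "4096", "NUM_EXPERT": "64", "TOP_K": "2"})
subprocess.run(["python", "tests/benchmark_mlp.py"], check=True)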
tests/moe.py

@@ -20,24 +20,19 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
-        self.bias_htoh4 = nn.Parameter(
-            torch.Tensor(num_expert * world_size, d_hidden))
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
-        self.bias_h4toh = nn.Parameter(
-            torch.Tensor(num_expert * world_size, d_model))
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k

     def forward(self, inp, gate_idx, gate_score):
         gate_long = gate_idx.long()
         batch_size = inp.size(0)
         o = torch.empty(batch_size, self.d_model, dtype=inp.dtype,
                 device=inp.device)
         for i in range(self.weight_htoh4.shape[0]):
-            idx = (gate_idx == i)
+            idx = gate_idx == i
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
             x = x + self.bias_htoh4[i]

@@ -45,8 +40,9 @@ class BruteForceMoELinear(nn.Module):
             x = x @ self.weight_h4toh[i].t()
             x = x + self.bias_h4toh[i]
             o[idx] = x
-        x = torch.bmm(gate_score, o.view(-1, self.top_k,
-            self.d_model)).reshape(-1, self.d_model)
+        x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
+            -1, self.d_model
+        )
         return x
 ...
tests/test_dp.py

@@ -18,7 +18,7 @@ class MyMoE(FMoE):
             gate=NaiveGate,
             world_size=1,
             mp_group=None,
-            top_k=top_k
+            top_k=top_k,
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)

@@ -46,5 +46,5 @@ def test_fmoe_dp(
     output = moe_dp(torch.rand(batch_size, d_model).cuda())


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_fmoe_dp(4, 2, 4, 16, 32)
tests/test_numerical.py

@@ -68,7 +68,6 @@ class MyMoE(FMoE):
         super().__init__(
             num_expert=num_expert,
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
             mp_group=mp_group,

@@ -77,8 +76,8 @@ class MyMoE(FMoE):
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
         rng = np.random.default_rng(1234)
-        _megatron_init_method(self.experts.htoh4, rng, 1.)
-        _megatron_init_method(self.experts.h4toh, rng, 1.)
+        _megatron_init_method(self.experts.htoh4, rng, 1.0)
+        _megatron_init_method(self.experts.h4toh, rng, 1.0)


 @pytest.mark.parametrize("num_expert", [4, 8])

@@ -152,8 +151,22 @@ def test_fmoe_linear(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )
-    moe_out_list = moe_out, moe_grad_in, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
-    raw_out_list = raw_out, raw_grad_in, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
+    moe_out_list = (
+        moe_out,
+        moe_grad_in,
+        moe.experts.htoh4.weight.grad,
+        moe.experts.h4toh.weight.grad,
+        moe.experts.htoh4.bias.grad,
+        moe.experts.h4toh.bias.grad,
+    )
+    raw_out_list = (
+        raw_out,
+        raw_grad_in,
+        moe_raw.weight_htoh4.grad,
+        moe_raw.weight_h4toh.grad,
+        moe_raw.bias_htoh4.grad,
+        moe_raw.bias_h4toh.grad,
+    )
     if world_size > 1:
         _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list

@@ -176,7 +189,14 @@ def test_fmoe_linear(
     )
     raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
-    names = ["output", "input grad", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
+    names = [
+        "output",
+        "input grad",
+        "htoh4 weight grad",
+        "h4toh weight grad",
+        "htoh4 bias grad",
+        "h4toh bias grad",
+    ]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)