OpenDAS / FastMoE · Commits

Commit 87dad9d5 (unverified)
Merge pull request #5 from laekov/bias

merge Bias

Authored Feb 21, 2021 by Rick Ho; committed by GitHub, Feb 21, 2021.
Parents: ed9277f9, 01464726

Showing 5 changed files with 121 additions and 75 deletions:

    fmoe/layers.py            +43  -25
    fmoe/transformer.py        +5   -6
    tests/moe.py               +8   -0
    tests/test_dp.py          +18  -15
    tests/test_numerical.py   +47  -29
fmoe/layers.py

@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()

     def reset_parameters(self):
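The new `bias` flag follows the usual `nn.Linear` convention, but allocates one bias vector per expert. Registering `None` when bias is disabled keeps `self.bias` a readable attribute, so the later `self.bias is not None` checks need no `hasattr` guards. A minimal stand-alone sketch of the same pattern (the module name here is illustrative, not fmoe API):

    import torch
    import torch.nn as nn

    class PerExpertLinear(nn.Module):
        # Illustrative stand-in mirroring FMoELinear's parameter layout.
        def __init__(self, num_expert, in_feat, out_feat, bias=True):
            super().__init__()
            # One (out_feat, in_feat) weight matrix per expert.
            self.weight = nn.Parameter(torch.randn(num_expert, out_feat, in_feat))
            if bias:
                # One bias vector per expert, added after the batched matmul.
                self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
            else:
                # register_parameter(..., None) keeps the attribute readable
                # while leaving it out of parameters() and the state dict.
                self.register_parameter('bias', None)

    m = PerExpertLinear(num_expert=4, in_feat=8, out_feat=16, bias=False)
    assert m.bias is None
    assert 'bias' not in dict(m.named_parameters())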
@@ -41,17 +46,32 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        for i in range(self.num_expert):
-            weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size()))
-            self.weight.data[i] = torch.tensor(weight, dtype=dtype, device=device)
+        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
+        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            x = x + bias
+        return x
+
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+                self.num_expert, self.in_feat, self.out_feat,
+                self.bias is not None, self.rank)


 def mark_module_parallel_comm(module, comm):
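In `forward`, the expert outputs arrive as one contiguous batch sorted by expert, so the per-expert biases must be expanded to one row per token before the add. `torch.repeat_interleave` with `fwd_expert_count` as the repeats tensor does exactly that. A small sketch of the alignment (counts and values are made up):

    import torch

    # Say 3 experts with 2-dim outputs; expert 0 received 2 tokens,
    # expert 1 received none, expert 2 received 1 token.
    bias = torch.tensor([[0.1, 0.1],
                         [0.2, 0.2],
                         [0.3, 0.3]])
    fwd_expert_count = torch.tensor([2, 0, 1])

    # Row i of `bias` is repeated fwd_expert_count[i] times along dim 0,
    # matching the expert-sorted layout of the MOELinear output.
    expanded = torch.repeat_interleave(bias, fwd_expert_count, dim=0)
    # expanded is [[0.1, 0.1], [0.1, 0.1], [0.3, 0.3]] -- one row per token.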
@@ -92,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
     r'''
-    A general moe implementation that supports an arbitrary module as the
-    expert
-    Either `expert` or `expert_fn` is required.
+    A general moe implementation that supports an arbitrary module as the
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.

@@ -106,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -125,19 +142,20 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
+        if expert is not None:
             self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
-                    batch_size = fwd_expert_count[i].item()
-                    inp_slice = inp[base_idx:base_idx + batch_size]
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-        self.expert_fn = expert_fn
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
+            batch_size = fwd_expert_count[i].item()
+            inp_slice = inp[base_idx:base_idx + batch_size]
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)

     def mark_parallel_comm(self):
         r'''
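The net effect of this hunk is that `expert_fn` becomes an overridable method that dispatches on the type of `self.experts`: a single `nn.Module` (such as `_Expert` below) receives the whole sorted batch plus the count array in one call, while a plain list of per-expert modules is fed slice by slice. A stand-alone restatement of that dispatch, for reference:

    import torch
    import torch.nn as nn

    def expert_dispatch(experts, inp, fwd_expert_count):
        # Mirrors the new FMoE.expert_fn: a batched expert module
        # handles everything in one call ...
        if isinstance(experts, nn.Module):
            return experts(inp, fwd_expert_count)
        # ... otherwise each listed expert gets its contiguous slice
        # of the expert-sorted input.
        outputs, base_idx = [], 0
        for i, count in enumerate(fwd_expert_count.tolist()):
            outputs.append(experts[i](inp[base_idx:base_idx + count]))
            base_idx += count
        return torch.cat(outputs, dim=0)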
fmoe/transformer.py

@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
+                rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
+                rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
@@ -47,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
-
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
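With the closure gone, `FMoETransformerMLP` relies on the base class's `isinstance(self.experts, nn.Module)` path: `_Expert` is assigned after `super().__init__`, and the inherited `expert_fn` forwards `(inp, fwd_expert_count)` to it directly. The switch to explicit keywords in the `FMoELinear` calls matters too: positionally, `rank` used to land in the slot that is now `bias`, so `bias=True, rank=rank` avoids a silent argument shift. A construction sketch (assuming an installed fmoe build; sizes are arbitrary):

    from fmoe.layers import FMoELinear

    # Keyword form matching the new signature; reset_parameters runs
    # inside __init__, initializing both weights and biases.
    htoh4 = FMoELinear(num_expert=4, in_feat=1024, out_feat=4096,
                       bias=True, rank=0)
    # print(htoh4) now reports num_expert / in_features / out_features /
    # bias / rank via the extra_repr added in this commit.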
tests/moe.py

@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k

     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
             -1, self.d_model
         )
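For each expert `i`, the brute-force reference now computes the full biased FFN `act(x @ W1.T + b1) @ W2.T + b2`, the dense equivalent of what the fused `FMoELinear` pair computes per expert. Spelled out for a single expert (shapes only, random values, activation chosen arbitrarily):

    import torch
    import torch.nn.functional as F

    d_model, d_hidden, tokens = 16, 64, 5
    w_htoh4 = torch.randn(d_hidden, d_model)   # one expert's first weight
    b_htoh4 = torch.randn(d_hidden)            # its bias (new in this commit)
    w_h4toh = torch.randn(d_model, d_hidden)   # second weight
    b_h4toh = torch.randn(d_model)             # second bias (new in this commit)

    x = torch.randn(tokens, d_model)
    h = F.gelu(x @ w_htoh4.t() + b_htoh4)      # activation is configurable
    y = h @ w_h4toh.t() + b_h4toh
    assert y.shape == (tokens, d_model)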
tests/test_dp.py

@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
 n_devices = int(os.environ.get("N_GPUS", "2"))

+
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=1,
+            mp_group=None,
+            top_k=top_k,
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -26,22 +39,12 @@ def test_fmoe_dp(
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=1,
-        mp_group=None,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()

     moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))

     for i in range(5):
         output = moe_dp(torch.rand(batch_size, d_model).cuda())


 if __name__ == '__main__':
     test_fmoe_dp(4, 2, 4, 16, 32)
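Moving the expert into a `MyMoE` subclass (instead of a closure captured by `expert_fn`) is what makes the `DataParallel` test meaningful: parameters reachable only through a captured closure are invisible to module replication, while a registered submodule is cloned onto every device. A toy illustration of the difference (not fmoe code):

    import torch.nn as nn

    class ClosureWrapped(nn.Module):
        # The pattern this commit removes: the expert lives only inside
        # a closure, so it is never registered as a submodule.
        def __init__(self):
            super().__init__()
            expert = nn.Linear(4, 4)
            self.fn = lambda x: expert(x)

    class SubmoduleWrapped(nn.Module):
        # The pattern the tests adopt: the expert is a registered attribute.
        def __init__(self):
            super().__init__()
            self.expert = nn.Linear(4, 4)

    assert len(list(ClosureWrapped().parameters())) == 0    # invisible to DataParallel
    assert len(list(SubmoduleWrapped().parameters())) == 2  # weight and bias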
tests/test_numerical.py

@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
             assert False

+
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=world_size,
+            mp_group=mp_group,
+            top_k=top_k,
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -74,20 +88,8 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=world_size,
-        mp_group=mp_group,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
+            activation).cuda()

     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -99,38 +101,54 @@ def test_fmoe_linear(
     ).cuda()

     if world_size == 1:
-        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
-        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
+        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
     else:
         weight_htoh4_array = [
-            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.htoh4.weight.data)
+            for _ in range(world_size)
         ]
-        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        bias_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.bias.data)
+            for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_htoh4_array,
+                moe.experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array,
+                moe.experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

         weight_h4toh_array = [
-            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.h4toh.weight.data)
+            for _ in range(world_size)
         ]
-        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        bias_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.bias.data)
+            for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_h4toh_array,
+                moe.experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array,
+                moe.experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model,
             top_k, rank, mp_group)

-    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
+    moe_out_list = (moe_out, moe.experts.htoh4.weight.grad,
+            moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad,
+            moe.experts.h4toh.bias.grad)
+    raw_out_list = (raw_out, moe_raw.weight_htoh4.grad,
+            moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad,
+            moe_raw.bias_h4toh.grad)

     if world_size > 1:
-        _, htoh4_grad, h4toh_grad = raw_out_list
-        torch.distributed.all_reduce(htoh4_grad)
-        torch.distributed.all_reduce(h4toh_grad)
+        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
+        torch.distributed.all_reduce(htoh4_w_grad)
+        torch.distributed.all_reduce(h4toh_w_grad)
+        torch.distributed.all_reduce(htoh4_b_grad)
+        torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_grad = htoh4_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
-        h4toh_grad = h4toh_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
-        raw_out_list = _, htoh4_grad, h4toh_grad
+        htoh4_w_grad = htoh4_w_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        h4toh_w_grad = h4toh_w_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        htoh4_b_grad = htoh4_b_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        h4toh_b_grad = h4toh_b_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        raw_out_list = (_, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad,
+                h4toh_b_grad)

-    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
+    names = ["output", "htoh4 weight grad", "h4toh weight grad",
+            "htoh4 bias grad", "h4toh bias grad"]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)
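In the distributed branch, the brute-force reference holds all `num_expert * world_size` experts, so weight and bias gradients are first `all_reduce`d, then each rank keeps only the slice for its own experts, divided by the model-parallel group size. The slicing arithmetic on its own (no process group required):

    import torch

    world_size, num_expert, mp_size = 2, 4, 1
    rank = 1
    # A gradient covering every expert on every rank, as the brute-force
    # reference stores it after concatenation along dim 0:
    full_grad = torch.arange(world_size * num_expert, dtype=torch.float32)

    # Each rank keeps the rows for its own experts and rescales:
    local_grad = full_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
    # local_grad == tensor([4., 5., 6., 7.]) for rank 1.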