Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
a88d1124
Commit
a88d1124
authored
Feb 21, 2021
by
Jiezhong Qiu
Browse files
test bias, dp still not passed
parent
3c24222c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
4 deletions
+34
-4
fmoe/layers.py
fmoe/layers.py
+10
-2
fmoe/transformer.py
fmoe/transformer.py
+4
-2
tests/moe.py
tests/moe.py
+8
-0
tests/test_numerical.py
tests/test_numerical.py
+12
-0
No files found.
fmoe/layers.py
View file @
a88d1124
...
@@ -60,11 +60,19 @@ class FMoELinear(nn.Module):
...
@@ -60,11 +60,19 @@ class FMoELinear(nn.Module):
Call MOE function
Call MOE function
'''
'''
x
=
MOELinear
.
apply
(
inp
,
self
.
weight
,
fwd_expert_count
)
x
=
MOELinear
.
apply
(
inp
,
self
.
weight
,
fwd_expert_count
)
if
self
.
bias
:
if
self
.
bias
is
not
None
:
bias
=
torch
.
repeat_interleave
(
self
.
bias
,
fwd_expert_count
,
dim
=
0
)
bias
=
torch
.
repeat_interleave
(
self
.
bias
,
fwd_expert_count
.
to
(
self
.
bias
.
device
),
dim
=
0
)
x
=
x
+
bias
x
=
x
+
bias
return
x
return
x
def
extra_repr
(
self
)
->
str
:
return
'num_expert={}, in_features={},
\
out_features={}, bias={}, rank={}'
.
format
(
self
.
num_expert
,
self
.
in_feat
,
self
.
out_feat
,
self
.
bias
is
not
None
,
self
.
rank
)
def
mark_module_parallel_comm
(
module
,
comm
):
def
mark_module_parallel_comm
(
module
,
comm
):
r
'''
r
'''
...
...
fmoe/transformer.py
View file @
a88d1124
...
@@ -14,8 +14,10 @@ class _Expert(nn.Module):
...
@@ -14,8 +14,10 @@ class _Expert(nn.Module):
'''
'''
def
__init__
(
self
,
num_expert
,
d_model
,
d_hidden
,
activation
,
rank
=
0
):
def
__init__
(
self
,
num_expert
,
d_model
,
d_hidden
,
activation
,
rank
=
0
):
super
().
__init__
()
super
().
__init__
()
self
.
htoh4
=
FMoELinear
(
num_expert
,
d_model
,
d_hidden
,
rank
)
self
.
htoh4
=
FMoELinear
(
num_expert
,
d_model
,
d_hidden
,
self
.
h4toh
=
FMoELinear
(
num_expert
,
d_hidden
,
d_model
,
rank
)
bias
=
True
,
rank
=
rank
)
self
.
h4toh
=
FMoELinear
(
num_expert
,
d_hidden
,
d_model
,
bias
=
True
,
rank
=
rank
)
self
.
activation
=
activation
self
.
activation
=
activation
def
forward
(
self
,
inp
,
fwd_expert_count
):
def
forward
(
self
,
inp
,
fwd_expert_count
):
...
...
tests/moe.py
View file @
a88d1124
...
@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
...
@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
self
.
weight_htoh4
=
nn
.
Parameter
(
self
.
weight_htoh4
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
*
world_size
,
d_hidden
,
d_model
)
torch
.
Tensor
(
num_expert
*
world_size
,
d_hidden
,
d_model
)
)
)
self
.
bias_htoh4
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
*
world_size
,
d_hidden
)
)
self
.
weight_h4toh
=
nn
.
Parameter
(
self
.
weight_h4toh
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
*
world_size
,
d_model
,
d_hidden
)
torch
.
Tensor
(
num_expert
*
world_size
,
d_model
,
d_hidden
)
)
)
self
.
bias_h4toh
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
*
world_size
,
d_model
)
)
self
.
top_k
=
top_k
self
.
top_k
=
top_k
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
...
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
...
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
idx
=
(
gate_idx
==
i
)
idx
=
(
gate_idx
==
i
)
x
=
inp
[
idx
]
x
=
inp
[
idx
]
x
=
x
@
self
.
weight_htoh4
[
i
].
t
()
x
=
x
@
self
.
weight_htoh4
[
i
].
t
()
x
=
x
+
self
.
bias_htoh4
[
i
]
x
=
self
.
activation
(
x
)
x
=
self
.
activation
(
x
)
x
=
x
@
self
.
weight_h4toh
[
i
].
t
()
x
=
x
@
self
.
weight_h4toh
[
i
].
t
()
x
=
x
+
self
.
bias_h4toh
[
i
]
o
[
idx
]
=
x
o
[
idx
]
=
x
x
=
torch
.
bmm
(
gate_score
,
o
.
view
(
-
1
,
self
.
top_k
,
x
=
torch
.
bmm
(
gate_score
,
o
.
view
(
-
1
,
self
.
top_k
,
self
.
d_model
)).
reshape
(
-
1
,
self
.
d_model
)
self
.
d_model
)).
reshape
(
-
1
,
self
.
d_model
)
...
...
tests/test_numerical.py
View file @
a88d1124
...
@@ -100,19 +100,31 @@ def test_fmoe_linear(
...
@@ -100,19 +100,31 @@ def test_fmoe_linear(
if
world_size
==
1
:
if
world_size
==
1
:
moe_raw
.
weight_htoh4
.
data
=
experts
.
htoh4
.
weight
.
data
.
clone
()
moe_raw
.
weight_htoh4
.
data
=
experts
.
htoh4
.
weight
.
data
.
clone
()
moe_raw
.
bias_htoh4
.
data
=
experts
.
htoh4
.
bias
.
data
.
clone
()
moe_raw
.
weight_h4toh
.
data
=
experts
.
h4toh
.
weight
.
data
.
clone
()
moe_raw
.
weight_h4toh
.
data
=
experts
.
h4toh
.
weight
.
data
.
clone
()
moe_raw
.
bias_h4toh
.
data
=
experts
.
h4toh
.
bias
.
data
.
clone
()
else
:
else
:
weight_htoh4_array
=
[
weight_htoh4_array
=
[
torch
.
empty_like
(
experts
.
htoh4
.
weight
.
data
)
for
_
in
range
(
world_size
)
torch
.
empty_like
(
experts
.
htoh4
.
weight
.
data
)
for
_
in
range
(
world_size
)
]
]
bias_htoh4_array
=
[
torch
.
empty_like
(
experts
.
htoh4
.
bias
.
data
)
for
_
in
range
(
world_size
)
]
torch
.
distributed
.
all_gather
(
weight_htoh4_array
,
experts
.
htoh4
.
weight
.
data
)
torch
.
distributed
.
all_gather
(
weight_htoh4_array
,
experts
.
htoh4
.
weight
.
data
)
torch
.
distributed
.
all_gather
(
bias_htoh4_array
,
experts
.
htoh4
.
bias
.
data
)
moe_raw
.
weight_htoh4
.
data
=
torch
.
cat
(
weight_htoh4_array
,
dim
=
0
)
moe_raw
.
weight_htoh4
.
data
=
torch
.
cat
(
weight_htoh4_array
,
dim
=
0
)
moe_raw
.
bias_htoh4
.
data
=
torch
.
cat
(
bias_htoh4_array
,
dim
=
0
)
weight_h4toh_array
=
[
weight_h4toh_array
=
[
torch
.
empty_like
(
experts
.
h4toh
.
weight
.
data
)
for
_
in
range
(
world_size
)
torch
.
empty_like
(
experts
.
h4toh
.
weight
.
data
)
for
_
in
range
(
world_size
)
]
]
bias_h4toh_array
=
[
torch
.
empty_like
(
experts
.
h4toh
.
bias
.
data
)
for
_
in
range
(
world_size
)
]
torch
.
distributed
.
all_gather
(
weight_h4toh_array
,
experts
.
h4toh
.
weight
.
data
)
torch
.
distributed
.
all_gather
(
weight_h4toh_array
,
experts
.
h4toh
.
weight
.
data
)
torch
.
distributed
.
all_gather
(
bias_h4toh_array
,
experts
.
h4toh
.
bias
.
data
)
moe_raw
.
weight_h4toh
.
data
=
torch
.
cat
(
weight_h4toh_array
,
dim
=
0
)
moe_raw
.
weight_h4toh
.
data
=
torch
.
cat
(
weight_h4toh_array
,
dim
=
0
)
moe_raw
.
bias_h4toh
.
data
=
torch
.
cat
(
bias_h4toh_array
,
dim
=
0
)
moe_out
,
raw_out
=
_perform_forward
(
moe_out
,
raw_out
=
_perform_forward
(
moe
,
moe_raw
,
batch_size
,
d_model
,
top_k
,
rank
,
mp_group
moe
,
moe_raw
,
batch_size
,
d_model
,
top_k
,
rank
,
mp_group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment