OpenDAS / FastMoE · Commits

Commit 4c90e6e8, authored Apr 26, 2021 by Rick Ho
Parent 82103edb

    add gshard and switch gate with loss

Showing 5 changed files with 95 additions and 3 deletions (+95 −3).
    fmoe/gates/__init__.py      +2   −0
    fmoe/gates/gshard_gate.py   +32  −0
    fmoe/gates/naive_gate.py    +3   −2
    fmoe/gates/switch_gate.py   +57  −0
    tests/test_numerical.py     +1   −1
fmoe/gates/__init__.py

@@ -5,3 +5,5 @@ from .zero_gate import ZeroGate
 from .naive_gate import NaiveGate
 from .noisy_gate import NoisyGate
+from .gshard_gate import GShardGate
+from .switch_gate import SwitchGate
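With these exports in place, both balanced gates can be pulled in from the package root. A minimal, hypothetical usage sketch; the `FMoE(gate=...)` constructor argument is an assumption for illustration and does not appear in this diff:

    from fmoe.gates import GShardGate, SwitchGate

    # hypothetical: plugging a balanced gate into an FMoE layer
    # (constructor argument names are assumed, not shown in this commit)
    # layer = FMoE(num_expert=4, d_model=512, gate=GShardGate)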
fmoe/gates/gshard_gate.py · new file (mode 100644)

    r"""
    Balanced gate with GShard's policy (Google, 2020)
    """
    import torch
    import torch.nn.functional as F
    from .naive_gate import NaiveGate


    class GShardGate(NaiveGate):
        def __init__(self, d_model, num_expert, world_size,
                     capacity=(1.2, 2.4)):
            super().__init__(d_model, num_expert, world_size, top_k=2)
            self.capacity = capacity

        def forward(self, x):
            # expects the base gate to return (top-k indices, top-k values,
            # full gate scores); note the NaiveGate in this same commit
            # returns only two values
            topk_idx, topk_val, gate_score = super().forward(x)
            S = gate_score.shape[0]
            top_k = topk_idx.shape[0] // gate_score.shape[0]
            top1_idx = topk_idx.view((-1, top_k))[:, 0]
            # c_e: fraction of tokens whose top-1 choice is expert e
            c_e = torch.scatter_add(
                torch.zeros(self.num_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
            ) / S
            # m_e: mean softmax probability mass routed to expert e
            m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
            loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
            self.set_loss(loss)
            # TODO: capacity limit
            return topk_idx, topk_val
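As a sanity check on the loss above, here is a self-contained sketch of the same c_e / m_e balance computation on a toy batch; the sizes and random inputs are illustrative assumptions, only the formulas come from the commit:

    import torch
    import torch.nn.functional as F

    num_expert, S = 4, 8                      # toy sizes (assumed)
    gate_score = torch.randn(S, num_expert)   # stand-in for the gate scores
    top1_idx = gate_score.argmax(dim=-1)      # top-1 expert per token

    # c_e: fraction of tokens whose top-1 choice is each expert
    c_e = torch.scatter_add(
        torch.zeros(num_expert), 0, top1_idx,
        torch.ones_like(top1_idx, dtype=torch.float),
    ) / S
    # m_e: mean softmax probability mass per expert
    m_e = F.softmax(gate_score, dim=1).mean(dim=0)
    # perfectly uniform routing drives this toward 1.0
    loss = torch.mean(c_e * m_e) * num_expert ** 2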
fmoe/gates/naive_gate.py

@@ -19,7 +19,7 @@ class NaiveGate(BaseGate):
     """

     def __init__(self, d_model, num_expert, world_size, top_k=2):
-        super().__init__()
+        super().__init__(num_expert, world_size)
         self.gate = nn.Linear(d_model, self.tot_expert)
         self.top_k = top_k
@@ -38,5 +38,6 @@ class NaiveGate(BaseGate):
         gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
         gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

-        return gate_top_k_idx, gate_score
+        # TODO: capacity
+        return gate_top_k_idx, gate_score
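The hunk only shows the tail of NaiveGate.forward. For context, a sketch of what that tail computes, assuming (the selection code is not in this diff) a torch.topk over the linear gate's logits immediately before the lines shown:

    import torch
    import torch.nn.functional as F

    batch, tot_expert, top_k = 3, 4, 2
    logits = torch.randn(batch, tot_expert)   # stand-in for self.gate(inp)

    # assumed top-k selection preceding the hunk
    gate_top_k_val, gate_top_k_idx = torch.topk(logits, k=top_k, dim=-1)

    # the lines from the hunk: softmax over only the selected logits,
    # then flatten the indices to (batch * top_k,)
    gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
    gate_top_k_idx = gate_top_k_idx.view(-1)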
fmoe/gates/switch_gate.py · new file (mode 100644)

    r"""
    Balanced gate with Switch Transformer's policy (Google, 2021)
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from .naive_gate import NaiveGate


    class SwitchGate(NaiveGate):
        r"""
        A switch gate implementation
        """

        def __init__(self, d_model, num_expert, world_size,
                     switch_eps=.1, capacity=(1.2, 2.4)):
            super().__init__(d_model, num_expert, world_size, top_k=1)
            self.gate = nn.Linear(d_model, num_expert * world_size)
            self.switch_eps = switch_eps
            self.capacity = capacity

        def forward(self, inp):
            r"""
            The switch first applies softmax, then picks the top-1 expert.
            """
            gate = self.gate(inp)  # raw gate logits from the linear layer

            if self.training:
                # random uniform number from [1-eps, 1+eps]
                noise = torch.rand_like(gate)
                noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
                gate += noise

            # fp32 softmax for numerical stability
            gate_score = F.softmax(gate.float(), dim=-1)

            gate_score_top1, gate_idx_top1 = torch.topk(
                gate_score, k=1, dim=-1, largest=True
            )  # [.. x top_k]

            gate_score = gate_score.to(dtype=inp.dtype)
            gate_score_top1 = gate_score_top1.to(dtype=inp.dtype)
            gate_score_top1 = gate_score_top1.unsqueeze(1)
            gate_idx_top1 = gate_idx_top1.view(-1)  # (BxLxtop_k)

            # TODO: capacity limit
            # TODO: not tested, the following code is super dangerous!!!!!!
            gate_updated = gate_idx_top1
            gate_updated = gate_updated[gate_updated > -1]
            # fraction of tokens dispatched to each expert
            fraction_expert = torch.scatter_add(
                torch.zeros(self.tot_expert, device=gate_updated.device),
                0,
                gate_updated,
                torch.ones_like(gate_updated, dtype=torch.float),
            ) / gate_updated.view(-1).size(0)
            # mean gate probability assigned to each expert
            prob_expert = gate_score.sum(dim=0) / gate_updated.view(-1).size(0)
            switch_aux_loss = (fraction_expert * prob_expert).sum() \
                * self.tot_expert
            self.set_loss(switch_aux_loss)

            return gate_idx_top1, gate_score_top1
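To see the auxiliary loss in isolation, a toy, self-contained rendering of the jitter plus loss path; sizes and inputs are illustrative assumptions. One caveat worth flagging: the Switch Transformer paper applies the jitter multiplicatively, while this commit adds it to the logits:

    import torch
    import torch.nn.functional as F

    switch_eps, tot_expert, S = 0.1, 4, 8     # toy sizes (assumed)
    gate = torch.randn(S, tot_expert)         # stand-in for the gate logits

    # training-time jitter in [1 - eps, 1 + eps], added as in the commit
    noise = torch.rand_like(gate) * 2 * switch_eps + 1.0 - switch_eps
    gate = gate + noise

    gate_score = F.softmax(gate.float(), dim=-1)          # fp32 softmax
    _, gate_idx_top1 = torch.topk(gate_score, k=1, dim=-1)
    gate_idx_top1 = gate_idx_top1.view(-1)

    # fraction of tokens per expert times mean probability per expert
    fraction_expert = torch.scatter_add(
        torch.zeros(tot_expert), 0, gate_idx_top1,
        torch.ones_like(gate_idx_top1, dtype=torch.float),
    ) / S
    prob_expert = gate_score.sum(dim=0) / S
    switch_aux_loss = (fraction_expert * prob_expert).sum() * tot_expert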
tests/test_numerical.py

@@ -38,7 +38,7 @@ def _perform_forward(
     inp.requires_grad = True
     inp_raw.requires_grad = True
-    gate_idx, gate_score, _ = moe.gate(inp_raw)
+    gate_idx, gate_score = moe.gate(inp_raw)
     inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
     moe_out = moe(inp)
     raw_out = moe_raw(inp_repeated, gate_idx, gate_score)