Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
38b334cc
Commit
38b334cc
authored
May 13, 2021
by
Rich Ho
Browse files
test switch gate
parent
ddfaaf49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
44 deletions
+74
-44
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+3
-18
fmoe/gates/switch_gate.py
fmoe/gates/switch_gate.py
+25
-24
fmoe/gates/utils.py
fmoe/gates/utils.py
+24
-0
tests/test_gates.py
tests/test_gates.py
+22
-2
No files found.
fmoe/gates/gshard_gate.py
View file @
38b334cc
...
...
@@ -5,8 +5,7 @@ import math
import
torch
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
from
fmoe.functions
import
count_by_gate
import
fmoe_cuda
as
fmoe_native
from
.utils
import
limit_by_capacity
class
GShardGate
(
NaiveGate
):
...
...
@@ -32,21 +31,7 @@ class GShardGate(NaiveGate):
self
.
set_loss
(
loss
)
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
capacity
=
torch
.
ones
(
self
.
num_expert
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
capacity
*=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
pos
,
lec
,
gec
=
count_by_gate
(
topk_idx
.
reshape
(
-
1
),
self
.
num_expert
,
self
.
world_size
)
new_gec
,
=
fmoe_native
.
limit_by_capacity
(
gec
,
capacity
,
self
.
num_expert
,
self
.
world_size
)
if
self
.
world_size
>
1
:
new_lec
=
fmoe_native
.
expert_exchange
(
new_gec
,
self
.
num_expert
,
self
.
world_size
)
else
:
new_lec
=
new_gec
fmoe_native
.
prune_gate_by_capacity
(
topk_idx
,
new_lec
.
to
(
torch
.
int32
),
self
.
num_expert
,
self
.
world_size
)
capacity
=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
limit_by_capacity
(
topk_idx
,
self
.
num_expert
,
self
.
world_size
,
capacity
)
return
topk_idx
,
topk_val
fmoe/gates/switch_gate.py
View file @
38b334cc
r
"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
from
.utils
import
limit_by_capacity
class
SwitchGate
(
NaiveGate
):
r
"""
...
...
@@ -13,7 +17,6 @@ class SwitchGate(NaiveGate):
def
__init__
(
self
,
d_model
,
num_expert
,
world_size
,
switch_eps
=
.
1
,
capacity
=
(
1.2
,
2.4
)):
super
().
__init__
(
d_model
,
num_expert
,
world_size
,
top_k
=
1
)
self
.
gate
=
nn
.
Linear
(
d_model
,
num_expert
*
world_size
)
self
.
switch_eps
=
switch_eps
self
.
capacity
=
capacity
...
...
@@ -21,37 +24,35 @@ class SwitchGate(NaiveGate):
r
"""
The switch firstly conduct softmax and then calculates the top-1
"""
gate
=
super
().
forward
(
inp
)
score
=
self
.
gate
(
inp
)
if
self
.
training
:
# random uniform number from [1-eps, 1+eps]
noise
=
torch
.
rand_like
(
gat
e
)
noise
=
torch
.
rand_like
(
scor
e
)
noise
=
noise
*
2
*
self
.
switch_eps
+
1.0
-
self
.
switch_eps
gat
e
+=
noise
scor
e
+=
noise
# fp32 softmax for numerical stability
gate_
score
=
F
.
softmax
(
gat
e
.
float
(),
dim
=-
1
)
score
=
F
.
softmax
(
scor
e
.
float
(),
dim
=-
1
)
gate
_score
_
top1
,
gate_idx_top1
=
torch
.
topk
(
gate_score_clip
,
k
=
1
,
dim
=-
1
,
largest
=
True
top1
_score
,
top1
_idx
=
torch
.
topk
(
score
,
k
=
1
,
dim
=-
1
,
largest
=
True
)
# [.. x top_k]
gate_score
=
gate_score
.
to
(
dtype
=
inp
.
dtype
)
gate_score_top1
=
gate_score_top1
.
to
(
dtype
=
inp
.
dtype
)
gate_score_top1
=
gate_score_top1
.
unsqueeze
(
1
)
gate_idx_top1
=
gate_idx_top1
.
view
(
-
1
)
# (BxLxtop_k)
top1_score
=
top1_score
.
to
(
dtype
=
inp
.
dtype
)
top1_score
=
top1_score
.
to
(
dtype
=
inp
.
dtype
)
# TODO: capacity limit
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
capacity
=
math
.
ceil
(
cap_rate
*
inp
.
shape
[
0
])
limit_by_capacity
(
top1_idx
,
self
.
num_expert
,
self
.
world_size
,
capacity
)
# TODO: not testd, the following code is super dangerous!!!!!!
gate_updated
=
gate_idx_top1
gate_updated
=
gate_updated
[
gate_updated
>
-
1
]
valid_idx
=
top1_idx
[
top1_idx
>
-
1
]
fraction_expert
=
torch
.
scatter_add
(
torch
.
zeros
(
self
.
tot_expert
,
device
=
gate_updated
.
device
),
torch
.
zeros
(
self
.
tot_expert
,
device
=
valid_idx
.
device
),
0
,
gate_updated
,
torch
.
ones_like
(
gate_updated
,
dtype
=
torch
.
float
),
)
/
gate_updated
.
view
(
-
1
).
size
(
0
)
prob_expert
=
gate_
score
.
sum
(
dim
=
0
)
/
gate_updated
.
view
(
-
1
).
size
(
0
)
switch_aux_
loss
=
(
fraction_expert
*
prob_expert
).
sum
()
*
self
.
tot_expert
self
.
set_loss
(
switch_aux_
loss
)
return
gate
_idx
_
top1
,
gate
_score
_top1
valid_idx
,
torch
.
ones_like
(
valid_idx
,
dtype
=
torch
.
float
),
)
/
valid_idx
.
numel
(
)
prob_expert
=
score
.
sum
(
dim
=
0
)
/
valid_idx
.
numel
(
)
loss
=
(
fraction_expert
*
prob_expert
).
sum
()
*
self
.
tot_expert
self
.
set_loss
(
loss
)
return
top1
_idx
,
top1_score
fmoe/gates/utils.py
0 → 100644
View file @
38b334cc
r
"""
Utilities that may be used in the gates
"""
import
torch
from
fmoe.functions
import
count_by_gate
import
fmoe_cuda
as
fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    r"""
    Drop expert assignments that exceed a per-expert capacity.

    Expands the scalar ``capacity`` to a per-expert int32 vector, counts
    how many tokens each expert received, clips the global counts with
    the native ``limit_by_capacity`` kernel, synchronises the clipped
    counts across workers, and prunes ``topk_idx`` in place so that no
    expert ends up over capacity.

    Returns:
        (new_lec, new_gec): local and global per-expert token counts
        after clipping.
    """
    # Same scalar capacity for every local expert, as an int32 vector on
    # the same device as the indices.
    cap_vec = capacity * torch.ones(
        num_expert, dtype=torch.int32, device=topk_idx.device)
    _, _, gec = count_by_gate(topk_idx.reshape(-1), num_expert, world_size)
    (new_gec,) = fmoe_native.limit_by_capacity(
        gec, cap_vec, num_expert, world_size)
    # NOTE(review): presumably exchanges clipped counts between workers;
    # single-process runs simply reuse the global counts.
    new_lec = (fmoe_native.expert_exchange(new_gec, num_expert, world_size)
               if world_size > 1 else new_gec)
    # Over-capacity entries of topk_idx are rewritten in place by the
    # native pruning kernel.
    fmoe_native.prune_gate_by_capacity(
        topk_idx, new_lec.to(torch.int32), num_expert, world_size)
    return new_lec, new_gec
tests/test_gates.py
View file @
38b334cc
...
...
@@ -3,7 +3,7 @@ import os
import
math
import
torch
import
torch.distributed
as
dist
from
fmoe.gates
import
GShardGate
from
fmoe.gates
import
GShardGate
,
SwitchGate
def
_ensure_initialized
():
...
...
@@ -37,6 +37,26 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
assert
(
i
<=
real_cap
)
@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    """SwitchGate must never route more than ceil(cap * batch) tokens to one expert."""
    _ensure_initialized()
    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    # Tally assignments per expert; -1 marks tokens dropped by the
    # capacity limit and is not counted.
    counts = [0] * n_expert
    for idx in topk_idx.cpu().view(-1).numpy():
        if idx != -1:
            counts[idx] += 1
    real_cap = math.ceil(cap * batch_size)
    for c in counts:
        assert c <= real_cap
if __name__ == '__main__':
    # Manual entry point for running the gate tests without pytest.
    _ensure_initialized()
    # test_gshard_gate(4096, 1024, 4, .2)
    test_switch_gate(4096, 1024, 4, .2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment