Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
38b334cc
Commit
38b334cc
authored
May 13, 2021
by
Rich Ho
Browse files
test switch gate
parent
ddfaaf49
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
44 deletions
+74
-44
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+3
-18
fmoe/gates/switch_gate.py
fmoe/gates/switch_gate.py
+25
-24
fmoe/gates/utils.py
fmoe/gates/utils.py
+24
-0
tests/test_gates.py
tests/test_gates.py
+22
-2
No files found.
fmoe/gates/gshard_gate.py
View file @
38b334cc
...
@@ -5,8 +5,7 @@ import math
...
@@ -5,8 +5,7 @@ import math
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
from
.naive_gate
import
NaiveGate
from
fmoe.functions
import
count_by_gate
from
.utils
import
limit_by_capacity
import
fmoe_cuda
as
fmoe_native
class
GShardGate
(
NaiveGate
):
class
GShardGate
(
NaiveGate
):
...
@@ -32,21 +31,7 @@ class GShardGate(NaiveGate):
...
@@ -32,21 +31,7 @@ class GShardGate(NaiveGate):
self
.
set_loss
(
loss
)
self
.
set_loss
(
loss
)
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
capacity
=
torch
.
ones
(
self
.
num_expert
,
dtype
=
torch
.
int32
,
capacity
=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
device
=
x
.
device
)
limit_by_capacity
(
topk_idx
,
self
.
num_expert
,
self
.
world_size
,
capacity
)
capacity
*=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
pos
,
lec
,
gec
=
count_by_gate
(
topk_idx
.
reshape
(
-
1
),
self
.
num_expert
,
self
.
world_size
)
new_gec
,
=
fmoe_native
.
limit_by_capacity
(
gec
,
capacity
,
self
.
num_expert
,
self
.
world_size
)
if
self
.
world_size
>
1
:
new_lec
=
fmoe_native
.
expert_exchange
(
new_gec
,
self
.
num_expert
,
self
.
world_size
)
else
:
new_lec
=
new_gec
fmoe_native
.
prune_gate_by_capacity
(
topk_idx
,
new_lec
.
to
(
torch
.
int32
),
self
.
num_expert
,
self
.
world_size
)
return
topk_idx
,
topk_val
return
topk_idx
,
topk_val
r"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity


class SwitchGate(NaiveGate):
    r"""
    A balanced top-1 gate following the Switch Transformer policy.

    Each token is routed to a single expert; routing logits are jittered
    with multiplicative uniform noise during training, and the number of
    tokens per expert is limited by a capacity factor.
    """

    def __init__(self, d_model, num_expert, world_size,
            switch_eps=.1, capacity=(1.2, 2.4)):
        # top_k=1: Switch routing always picks exactly one expert per token.
        super().__init__(d_model, num_expert, world_size, top_k=1)
        # Half-width of the multiplicative noise interval [1-eps, 1+eps].
        self.switch_eps = switch_eps
        # (train, eval) capacity factors, multiplied by the batch size.
        self.capacity = capacity

    def forward(self, inp):
        r"""
        The switch firstly conducts softmax and then calculates the top-1.

        Returns the selected expert index and its gate score for every
        token; indices of tokens dropped by the capacity limit are set
        to -1 by ``limit_by_capacity``.
        """
        score = self.gate(inp)

        if self.training:
            # random uniform number from [1-eps, 1+eps]
            noise = torch.rand_like(score)
            noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
            score += noise

        # fp32 softmax for numerical stability
        score = F.softmax(score.float(), dim=-1)

        top1_score, top1_idx = torch.topk(
            score, k=1, dim=-1, largest=True
        )  # [.. x top_k]
        top1_score = top1_score.to(dtype=inp.dtype)

        # Cap the number of tokens each expert may receive.
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * inp.shape[0])
        limit_by_capacity(top1_idx, self.num_expert, self.world_size,
                capacity)

        # Load-balancing auxiliary loss over the tokens that survived
        # the capacity pruning (pruned entries are marked -1).
        valid_idx = top1_idx[top1_idx > -1]
        fraction_expert = (
            torch.scatter_add(
                torch.zeros(self.tot_expert, device=valid_idx.device),
                0,
                valid_idx,
                torch.ones_like(valid_idx, dtype=torch.float),
            )
            / valid_idx.numel()
        )
        prob_expert = score.sum(dim=0) / valid_idx.numel()
        loss = (fraction_expert * prob_expert).sum() * self.tot_expert
        self.set_loss(loss)

        return top1_idx, top1_score
r"""
Utilities that may be used in the gates
"""
import torch

from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    r"""
    Limit the number of tokens assigned to each expert to ``capacity``.

    ``topk_idx`` holds the expert index chosen for each token; the native
    pruning kernel is invoked on it (entries beyond capacity are expected
    to be marked invalid by that kernel — behaviour defined in fmoe_cuda).
    Returns the per-expert counts after limiting: (local, global).
    """
    # One capacity entry per local expert, on the indices' device.
    per_expert_cap = capacity * torch.ones(
        num_expert, dtype=torch.int32, device=topk_idx.device
    )
    # Only the global expert count is needed here; positions and local
    # counts from count_by_gate are unused.
    _, _, gec = count_by_gate(
        topk_idx.reshape(-1), num_expert, world_size
    )
    new_gec, = fmoe_native.limit_by_capacity(
        gec, per_expert_cap, num_expert, world_size
    )
    if world_size > 1:
        # Exchange counts across workers to obtain the local view.
        new_lec = fmoe_native.expert_exchange(
            new_gec, num_expert, world_size
        )
    else:
        new_lec = new_gec
    fmoe_native.prune_gate_by_capacity(
        topk_idx, new_lec.to(torch.int32), num_expert, world_size
    )
    return new_lec, new_gec
tests/test_gates.py
View file @
38b334cc
...
@@ -3,7 +3,7 @@ import os
...
@@ -3,7 +3,7 @@ import os
import
math
import
math
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
fmoe.gates
import
GShardGate
from
fmoe.gates
import
GShardGate
,
SwitchGate
def
_ensure_initialized
():
def
_ensure_initialized
():
...
@@ -37,6 +37,26 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
...
@@ -37,6 +37,26 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
assert
(
i
<=
real_cap
)
assert
(
i
<=
real_cap
)
@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    r"""
    Check that SwitchGate never assigns more than ceil(cap * batch_size)
    tokens to any single expert; dropped tokens are marked with index -1.
    """
    _ensure_initialized()
    gate = SwitchGate(
        d_model, n_expert, dist.get_world_size(), capacity=(cap, cap)
    ).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(inp)
    # Tally tokens per expert, skipping capacity-dropped entries (-1).
    counts = [0] * n_expert
    for idx in topk_idx.cpu().view(-1).numpy():
        if idx != -1:
            counts[idx] += 1
    real_cap = math.ceil(cap * batch_size)
    for cnt in counts:
        assert cnt <= real_cap


if __name__ == '__main__':
    _ensure_initialized()
    # test_gshard_gate(4096, 1024, 4, .2)
    test_switch_gate(4096, 1024, 4, .2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment