Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
6cb550fd
Commit
6cb550fd
authored
Dec 14, 2022
by
Rick Ho
Browse files
diverge gshard gate
parent
cdc140f1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
3 deletions
+56
-3
fmoe/gates/__init__.py
fmoe/gates/__init__.py
+1
-0
fmoe/gates/dc_gate.py
fmoe/gates/dc_gate.py
+48
-0
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+6
-3
tests/test_ddp.py
tests/test_ddp.py
+1
-0
No files found.
fmoe/gates/__init__.py
View file @
6cb550fd
...
@@ -7,5 +7,6 @@ from .noisy_gate import NoisyGate
...
@@ -7,5 +7,6 @@ from .noisy_gate import NoisyGate
from
.gshard_gate
import
GShardGate
from
.gshard_gate
import
GShardGate
from
.switch_gate
import
SwitchGate
from
.switch_gate
import
SwitchGate
from
.dc_gate
import
DCGate
from
.swipe_gate
import
SwipeGate
from
.swipe_gate
import
SwipeGate
fmoe/gates/dc_gate.py
0 → 100644
View file @
6cb550fd
r
"""
Distributed Capacity gate, extended from GShard gate.
Instead of setting capacity based on local batch size and expert count,
the global load of each expert is calculated, and then the experts make
decisions of capacities on each worker.
"""
import
math
import
torch
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
from
.utils
import
limit_by_capacity
class DCGate(NaiveGate):
    r"""Distributed-capacity gate, derived from the GShard top-2 gate.

    Routes each token to its two best experts, applies a GShard-style
    load-balancing loss, caps per-expert load via ``limit_by_capacity``,
    and optionally drops the second expert at random for low-confidence
    tokens.
    """

    def __init__(self, d_model, num_expert, world_size,
                 topk=2, capacity=(1.2, 2.4), random_routing=True):
        """Create the gate.

        Args:
            d_model: input feature dimension, forwarded to ``NaiveGate``.
            num_expert: number of local experts per worker.
            world_size: number of distributed workers.
            topk: must be 2 (GShard-style routing is fixed at top-2).
            capacity: pair of capacity factors ``(train, eval)``.
            random_routing: if True, randomly drop the 2nd expert for
                tokens whose 2nd gate score is small.
        """
        assert topk == 2, 'topk should be 2 in gshard'
        super().__init__(d_model, num_expert, world_size, top_k=2)
        self.capacity = capacity
        self.random_routing = random_routing

    def forward(self, x):
        """Gate a batch ``x``; return ``(topk_idx, topk_val)``.

        Side effect: stores the load-balancing loss via ``set_loss``.
        Dropped assignments are marked with expert index ``-1``.
        """
        topk_idx, topk_val, gate_score = super().forward(
            x, return_all_scores=True)
        n_tokens = gate_score.shape[0]

        # GShard balance loss: fraction of tokens whose top-1 choice is
        # expert e (c_e) times the mean gate probability of e (m_e).
        top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
        counts = torch.scatter_add(
            torch.zeros(self.tot_expert, device=top1_idx.device),
            0,
            top1_idx,
            torch.ones_like(top1_idx, dtype=torch.float),
        )
        c_e = counts / n_tokens
        m_e = F.softmax(gate_score, dim=1).mean(dim=0)
        self.set_loss((c_e * m_e).mean() * (self.num_expert ** 2))

        # Capacity factor differs between training and evaluation.
        cap_rate = self.capacity[0] if self.training else self.capacity[1]
        cap = math.ceil(cap_rate * x.shape[0])
        _new_lec, _new_gec, topk_idx = limit_by_capacity(
            topk_idx, self.num_expert, self.world_size, cap)

        if self.random_routing:
            # Drop the 2nd expert when 2 * its gate score loses a coin
            # flip against a uniform sample (low-confidence 2nd choice).
            drop_draw = torch.rand(gate_score.size(0), device=x.device)
            drop_mask = 2 * topk_val[:, 1] < drop_draw
            topk_idx[:, 1].masked_fill_(drop_mask, -1)

        return topk_idx, topk_val
fmoe/gates/gshard_gate.py
View file @
6cb550fd
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
from
.naive_gate
import
NaiveGate
from
.utils
import
limit_by_capacity
from
.utils
import
limit_by_capacity
import
fmoe_cuda
as
fmoe_native
class
GShardGate
(
NaiveGate
):
class
GShardGate
(
NaiveGate
):
...
@@ -33,9 +34,11 @@ class GShardGate(NaiveGate):
...
@@ -33,9 +34,11 @@ class GShardGate(NaiveGate):
self
.
set_loss
(
loss
)
self
.
set_loss
(
loss
)
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
capacity
=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
capacity
=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
//
self
.
world_size
_new_lec
,
_new_gec
,
topk_idx
=
limit_by_capacity
(
capacity
=
torch
.
ones
(
self
.
num_expert
*
self
.
world_size
,
topk_idx
,
self
.
num_expert
,
self
.
world_size
,
capacity
)
dtype
=
torch
.
int32
,
device
=
topk_idx
.
device
)
*
capacity
topk_idx
=
fmoe_native
.
prune_gate_by_capacity
(
topk_idx
,
capacity
,
self
.
num_expert
,
self
.
world_size
)
if
self
.
random_routing
:
if
self
.
random_routing
:
rand_routing_prob
=
torch
.
rand
(
gate_score
.
size
(
0
),
device
=
x
.
device
)
rand_routing_prob
=
torch
.
rand
(
gate_score
.
size
(
0
),
device
=
x
.
device
)
...
...
tests/test_ddp.py
View file @
6cb550fd
...
@@ -35,6 +35,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
...
@@ -35,6 +35,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
env
[
"MASTER_ADDR"
]
=
"localhost"
env
[
"MASTER_ADDR"
]
=
"localhost"
env
[
"MASTER_PORT"
]
=
str
(
random
.
randint
(
50000
,
60000
))
env
[
"MASTER_PORT"
]
=
str
(
random
.
randint
(
50000
,
60000
))
env
[
"OMPI_COMM_WORLD_SIZE"
]
=
str
(
world_size
)
env
[
"OMPI_COMM_WORLD_SIZE"
]
=
str
(
world_size
)
env
[
"LD_LIBRARY_PATH"
]
=
os
.
environ
.
get
(
"LD_LIBRARY_PATH"
)
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
env
[
"OMPI_COMM_WORLD_RANK"
]
=
str
(
i
)
env
[
"OMPI_COMM_WORLD_RANK"
]
=
str
(
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment