OpenDAS / FastMoE · Commits

Commit 6e1fcca1, authored Mar 28, 2022 by Rick Ho

    faster gate develop

parent 98e24e24

Showing 1 changed file with 116 additions and 0 deletions.

fmoe/gates/faster_gate.py (new file, mode 100644): +116 −0
r"""
The example topology-aware gate for a two-layer tree-like topology, proposed
by the PPoPP'22 paper, FasterMoE. A limited number of tokens are sent across
the upper-level slow connection, and the other ones are redirected to experts
in the local network.
The number of GPUs that form such a local network is defined by the
environment variable `FMOE_TOPO_GPUS_PER_NODE`, and it is by default `8`.
The fraction of tokens that are allowed to be sent across nodes is defined by
the environment variable `FMOE_TOPO_OUTGOING_FRACTION`, and it is by default
`0.14`. Users are supposed to set a proper value in their own environment,
guided by some performance model, to achieve maximum throughput.
"""
from .naive_gate import NaiveGate

import os
import sys
import torch
import torch.nn.functional as F

from .utils import limit_by_capacity
import fmoe_cuda
from fmoe.functions import count_by_gate


nw_per_node = 8
try:
    nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE'])
except Exception:
    pass


class FasterGate(NaiveGate):
    def __init__(self, d_model, n_expert, world_size, node_rank):
        super().__init__(d_model, n_expert, world_size, top_k=2)
        self.ne_per_node = nw_per_node * n_expert
        self.ogn_ratio = .14
        try:
            self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION'])
        except Exception:
            pass
        self.node_rank = node_rank

        # mask[i] == 1 iff expert i lives on another node, i.e. routing a
        # token to it crosses the slow upper-level connection.
        mask = [1] * world_size * n_expert
        for i in range(n_expert * world_size):
            if i // self.ne_per_node == self.node_rank:
                mask[i] = 0
        self.mask = torch.Tensor(mask).bool()
        self.policy_fn = None
        print('node rank {} mask {}'.format(node_rank, mask))

    def forward(self, inp):
        if self.mask.device != inp.device:
            self.mask = self.mask.to(inp.device)

        gate_score = self.gate(inp)
        lim_mask = self.mask

        top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1)
        S = gate_score.shape[0]
        top_k = 2

        with torch.no_grad():
            top1_idx = top2_idx.view((-1, top_k))[:, 0]
            top1_val = top2_val.view((-1, top_k))[:, 0]

        # Load-balance loss: fraction of tokens whose top-1 choice is each
        # expert, times that expert's mean gate probability.
        c_e = torch.scatter_add(
                torch.zeros(self.tot_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
                ) / S
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        with torch.no_grad():
            if self.policy_fn is None:
                stored_models = torch.zeros(self.num_expert * self.world_size,
                        dtype=torch.bool)
            else:
                # TODO: Fix this after expert shadowing is ported
                _, lec, aec, gec, agec = count_by_gate(top2_idx,
                        self.num_expert, self.world_size, require_pos=False)
                stored_models = self.policy_fn(aec, agec,
                        self.num_expert, self.world_size, inp.shape[-1], True)
            # Experts shadowed on this node no longer count as remote.
            lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device)

            # Tokens whose top-1 expert is remote, and the budget of tokens
            # that are allowed to actually leave the node.
            ogn_mask = lim_mask[top1_idx]
            ogn_thres = int(inp.shape[0] * self.ogn_ratio)

        # Within budget: fall back to plain top-2 gating.
        if ogn_mask.sum().item() < ogn_thres:
            topk_val, topk_idx = torch.topk(gate_score, k=self.top_k)
            topk_val = F.softmax(topk_val, dim=-1)
            return topk_idx, topk_val

        with torch.no_grad():
            # Keep only the ogn_thres outgoing tokens with the highest top-1
            # scores; every other token is restricted to local experts.
            top1_val[~ogn_mask] = float('-inf')
            _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres)
            cand = gate_score.clone()
            cand[:, lim_mask] = float('-inf')
            _, topk_idx = torch.topk(cand, k=self.top_k)
            # The selected outgoing tokens keep their original remote top-1
            # expert in their second slot.
            topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn]

        # Re-gather the scores of the chosen experts and renormalize.
        idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2)
        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)

        topk_val = F.softmax(topk_val, dim=-1)

        return topk_idx, topk_val


def gen_faster_gate(rank):
    def _gen(d_model, n_expert, world_size, top_k=2):
        assert top_k == 2
        return FasterGate(d_model, n_expert, world_size, rank // nw_per_node)
    return _gen
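For context, a minimal usage sketch of the new gate. It assumes the usual FastMoE wiring, where a layer such as FMoETransformerMLP forwards its `gate` keyword to FMoE, which instantiates it as gate(d_model, num_expert, world_size, top_k); the sizes, device setup, and distributed initialization below are illustrative, not part of this commit.

# A hypothetical usage sketch (not part of this commit), assuming the
# standard FastMoE layer API described above.
import os

# faster_gate.py reads FMOE_TOPO_GPUS_PER_NODE at import time and
# FMOE_TOPO_OUTGOING_FRACTION at gate construction, so set them first.
os.environ.setdefault('FMOE_TOPO_GPUS_PER_NODE', '8')
os.environ.setdefault('FMOE_TOPO_OUTGOING_FRACTION', '0.14')

import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP
from fmoe.gates.faster_gate import gen_faster_gate

dist.init_process_group(backend='nccl')
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % 8)  # assumes 8 GPUs per node

# gen_faster_gate(rank) closes over this worker's node index
# (rank // nw_per_node) and returns a constructor with the
# (d_model, n_expert, world_size, top_k) signature FMoE expects.
ffn = FMoETransformerMLP(
    num_expert=4,                  # experts hosted on each worker
    d_model=1024,
    d_hidden=4096,
    world_size=world_size,
    top_k=2,                       # FasterGate asserts top_k == 2
    gate=gen_faster_gate(rank),
).cuda()

y = ffn(torch.randn(16, 1024, device='cuda'))

Since `_gen` derives the node rank from the global rank via `nw_per_node`, the same `FMOE_TOPO_GPUS_PER_NODE` value must be set consistently on every worker in the job.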