Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
ddfaaf49
Commit
ddfaaf49
authored
May 13, 2021
by
Rich Ho
Browse files
gshard gate test
parent
5a0ba835
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
18 deletions
+43
-18
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+5
-7
fmoe/gates/naive_gate.py
fmoe/gates/naive_gate.py
+4
-2
fmoe/layers.py
fmoe/layers.py
+1
-0
tests/moe.py
tests/moe.py
+2
-0
tests/test_gates.py
tests/test_gates.py
+31
-9
No files found.
fmoe/gates/gshard_gate.py
View file @
ddfaaf49
...
...
@@ -15,7 +15,8 @@ class GShardGate(NaiveGate):
self
.
capacity
=
capacity
def
forward
(
self
,
x
):
topk_idx
,
gate_score
=
super
().
forward
(
x
)
naive_outs
=
super
().
forward
(
x
,
return_all_scores
=
True
)
topk_idx
,
topk_val
,
gate_score
=
naive_outs
S
=
gate_score
.
shape
[
0
]
top_k
=
topk_idx
.
shape
[
0
]
//
gate_score
.
shape
[
0
]
...
...
@@ -31,22 +32,19 @@ class GShardGate(NaiveGate):
self
.
set_loss
(
loss
)
cap_rate
=
self
.
capacity
[
0
if
self
.
training
else
1
]
capacity
=
torch
.
ones
(
self
.
num_expert
,
dtype
=
torch
.
int32
)
capacity
=
torch
.
ones
(
self
.
num_expert
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)
capacity
*=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
print
(
topk_idx
)
pos
,
lec
,
gec
=
count_by_gate
(
gate_score
,
self
.
num_expert
,
pos
,
lec
,
gec
=
count_by_gate
(
topk_idx
.
reshape
(
-
1
),
self
.
num_expert
,
self
.
world_size
)
print
(
topk_idx
)
new_gec
,
=
fmoe_native
.
limit_by_capacity
(
gec
,
capacity
,
self
.
num_expert
,
self
.
world_size
)
print
(
topk_idx
)
if
self
.
world_size
>
1
:
new_lec
=
fmoe_native
.
expert_exchange
(
new_gec
,
self
.
num_expert
,
self
.
world_size
)
else
:
new_lec
=
new_gec
print
(
topk_idx
)
fmoe_native
.
prune_gate_by_capacity
(
topk_idx
,
new_lec
.
to
(
torch
.
int32
),
self
.
num_expert
,
self
.
world_size
)
...
...
fmoe/gates/naive_gate.py
View file @
ddfaaf49
...
...
@@ -23,7 +23,7 @@ class NaiveGate(BaseGate):
self
.
gate
=
nn
.
Linear
(
d_model
,
self
.
tot_expert
)
self
.
top_k
=
top_k
def
forward
(
self
,
inp
):
def
forward
(
self
,
inp
,
return_all_scores
=
False
):
r
"""
The naive implementation simply calculates the top-k of a linear layer's
output.
...
...
@@ -38,4 +38,6 @@ class NaiveGate(BaseGate):
gate_score
=
F
.
softmax
(
gate_top_k_val
,
dim
=-
1
)
gate_top_k_idx
=
gate_top_k_idx
.
view
(
-
1
)
# (BxLxtop_k)
return
gate_top_k_idx
,
gate
if
return_all_scores
:
return
gate_top_k_idx
,
gate_top_k_val
,
gate
return
gate_top_k_idx
,
gate_top_k_val
fmoe/layers.py
View file @
ddfaaf49
...
...
@@ -225,6 +225,7 @@ class FMoE(nn.Module):
# to: (BxL) x top_k x d_model
x
=
x
.
view
(
-
1
,
self
.
top_k
,
self
.
d_model
)
# to: (BxL) x d_model
gate_score
=
gate_score
.
unsqueeze
(
1
)
x
=
torch
.
bmm
(
gate_score
,
x
).
reshape
(
-
1
,
self
.
d_model
)
if
self
.
mp_size
>
1
:
...
...
tests/moe.py
View file @
ddfaaf49
...
...
@@ -40,6 +40,7 @@ class BruteForceMoELinear(nn.Module):
x
=
x
@
self
.
weight_h4toh
[
i
].
t
()
x
=
x
+
self
.
bias_h4toh
[
i
]
o
[
idx
]
=
x
gate_score
=
gate_score
.
unsqueeze
(
1
)
x
=
torch
.
bmm
(
gate_score
,
o
.
view
(
-
1
,
self
.
top_k
,
self
.
d_model
)).
reshape
(
-
1
,
self
.
d_model
)
...
...
@@ -60,6 +61,7 @@ class BruteForceMoE(nn.Module):
x
=
inp
.
new_zeros
((
batch_size
,
self
.
d_model
))
for
i
in
range
(
batch_size
):
x
[
i
]
=
self
.
experts
[
gate_long
[
i
]](
inp
[
i
])
gate_score
=
gate_score
.
unsqueeze
(
1
)
x
=
torch
.
bmm
(
gate_score
,
x
.
view
(
-
1
,
self
.
top_k
,
self
.
d_model
)).
reshape
(
-
1
,
self
.
d_model
)
...
...
tests/test_gates.py
View file @
ddfaaf49
import
pytest
import
os
import
math
import
torch
import
torch.distributed
as
dist
from
fmoe.gates
import
GShardGate
def
test_gshard_gate
(
d_model
,
batch_size
,
n_expert
):
gate
=
GShardGate
(
d_model
,
n_expert
,
dist
.
get_world_size
()).
cuda
()
def _ensure_initialized():
    """Lazily set up the torch.distributed process group for the gate tests.

    Fills the rendezvous environment variables from their OpenMPI
    counterparts (defaulting to a single-process run), then initializes
    an NCCL process group. Does nothing if a group already exists.
    """
    if dist.is_initialized():
        return
    env = os.environ
    env["RANK"] = env.get("OMPI_COMM_WORLD_RANK", "0")
    env["WORLD_SIZE"] = env.get("OMPI_COMM_WORLD_SIZE", "1")
    # Pin each rank to its own GPU by restricting visibility.
    env["CUDA_VISIBLE_DEVICES"] = env["RANK"]
    env["MASTER_ADDR"] = env.get("MASTER_ADDR", "localhost")
    env["MASTER_PORT"] = env.get("MASTER_PORT", "12211")
    dist.init_process_group(backend="nccl")
@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Check GShardGate's capacity limiting.

    No expert may receive more than ceil(cap * batch_size) tokens; tokens
    dropped by the capacity limit are marked with index -1 in the returned
    top-k index tensor.
    """
    _ensure_initialized()
    if dist.get_world_size() * n_expert < 2:
        # Gating over a single expert is degenerate; skip those combinations.
        pytest.skip("Not enough experts")
    gate = GShardGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    print('rank {} idx {}'.format(dist.get_rank(), topk_idx))
    print('rank {} val {}'.format(dist.get_rank(), topk_val))
    # Tally how many tokens were routed to each expert, ignoring the
    # -1 markers for tokens dropped by the capacity limit.
    # NOTE(review): this assumes topk_idx holds ids in [0, n_expert);
    # with world_size > 1, global expert ids would overflow `counts` —
    # confirm against GShardGate's output convention.
    counts = [0 for _ in range(n_expert)]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for i in counts:
        assert i <= real_cap
if __name__ == '__main__':
    # Allow running this test directly (e.g. under mpirun) without pytest.
    # Environment and process-group setup is delegated to
    # _ensure_initialized() rather than duplicated inline here.
    _ensure_initialized()
    test_gshard_gate(4096, 1024, 4, .2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment