Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
5a0ba835
"vscode:/vscode.git/clone" did not exist on "be6f6c2927dc03b6103af8d48a961562dd5d68d5"
Commit
5a0ba835
authored
May 13, 2021
by
Rick Ho
Browse files
add test but cannot pass
parent
56cb8c15
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
34 additions
and
8 deletions
+34
-8
cuda/balancing.cu
cuda/balancing.cu
+2
-0
fmoe/functions.py
fmoe/functions.py
+1
-1
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+9
-3
fmoe/gates/naive_gate.py
fmoe/gates/naive_gate.py
+2
-4
tests/test_gates.py
tests/test_gates.py
+20
-0
No files found.
cuda/balancing.cu
View file @
5a0ba835
...
...
@@ -7,6 +7,8 @@
std
::
vector
<
torch
::
Tensor
>
_limit_by_capacity
(
torch
::
Tensor
expert_count
,
torch
::
Tensor
capacity
,
long
n_expert
,
long
n_worker
)
{
CHECK_INPUT
(
expert_count
);
CHECK_INPUT
(
capacity
);
auto
expert_count_ack
=
torch
::
empty_like
(
expert_count
);
auto
smgr
=
getCudaStreamManager
(
expert_count
.
device
().
index
());
fmoe_cuda_limit_by_capacity_impl
(
...
...
fmoe/functions.py
View file @
5a0ba835
...
...
@@ -10,7 +10,7 @@ import fmoe_cuda
from
.utils
import
get_torch_default_comm
def
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
):
def
count_by_gate
(
gate
,
num_expert
,
world_size
):
# TODO: support -1 in gate, which means ignore this input
with
torch
.
no_grad
():
_
,
pos
=
torch
.
sort
(
gate
)
...
...
fmoe/gates/gshard_gate.py
View file @
5a0ba835
r
"""
Balanced gate with GShard's policy (Google, 2020)
"""
import
math
import
torch
import
torch.nn.functional
as
F
from
.naive_gate
import
NaiveGate
...
...
@@ -14,13 +15,13 @@ class GShardGate(NaiveGate):
self
.
capacity
=
capacity
def
forward
(
self
,
x
):
topk_idx
,
topk_val
,
gate_score
=
super
().
forward
(
x
)
topk_idx
,
gate_score
=
super
().
forward
(
x
)
S
=
gate_score
.
shape
[
0
]
top_k
=
topk_idx
.
shape
[
0
]
//
gate_score
.
shape
[
0
]
top1_idx
=
topk_idx
.
view
((
-
1
,
top_k
))[:,
0
]
c_e
=
torch
.
scatter_add
(
torch
.
zeros
(
self
.
num_expert
,
device
=
gate_
top
_
1_idx
.
device
),
torch
.
zeros
(
self
.
num_expert
,
device
=
top1_idx
.
device
),
0
,
top1_idx
,
torch
.
ones_like
(
top1_idx
,
dtype
=
torch
.
float
),
...
...
@@ -33,14 +34,19 @@ class GShardGate(NaiveGate):
capacity
=
torch
.
ones
(
self
.
num_expert
,
dtype
=
torch
.
int32
)
capacity
*=
math
.
ceil
(
cap_rate
*
x
.
shape
[
0
])
pos
,
lec
,
gec
=
count_by_gate
(
gate
,
self
.
num_expert
,
self
.
world_size
)
print
(
topk_idx
)
pos
,
lec
,
gec
=
count_by_gate
(
gate_score
,
self
.
num_expert
,
self
.
world_size
)
print
(
topk_idx
)
new_gec
,
=
fmoe_native
.
limit_by_capacity
(
gec
,
capacity
,
self
.
num_expert
,
self
.
world_size
)
print
(
topk_idx
)
if
self
.
world_size
>
1
:
new_lec
=
fmoe_native
.
expert_exchange
(
new_gec
,
self
.
num_expert
,
self
.
world_size
)
else
:
new_lec
=
new_gec
print
(
topk_idx
)
fmoe_native
.
prune_gate_by_capacity
(
topk_idx
,
new_lec
.
to
(
torch
.
int32
),
self
.
num_expert
,
self
.
world_size
)
...
...
fmoe/gates/naive_gate.py
View file @
5a0ba835
...
...
@@ -35,9 +35,7 @@ class NaiveGate(BaseGate):
gate_top_k_val
=
gate_top_k_val
.
view
(
-
1
,
self
.
top_k
)
# (BxL) x 1 x top_k
gate_score
=
F
.
softmax
(
gate_top_k_val
,
dim
=-
1
)
.
unsqueeze
(
1
)
gate_score
=
F
.
softmax
(
gate_top_k_val
,
dim
=-
1
)
gate_top_k_idx
=
gate_top_k_idx
.
view
(
-
1
)
# (BxLxtop_k)
# TODO: capacity
return
gate_top_k_idx
,
gate_score
return
gate_top_k_idx
,
gate
tests/test_gates.py
0 → 100644
View file @
5a0ba835
import
os
import
torch
import
torch.distributed
as
dist
from
fmoe.gates
import
GShardGate
def test_gshard_gate(d_model, batch_size, n_expert):
    """Smoke-test GShardGate on a random input batch and print each rank's result.

    Not a self-contained pytest case: it requires an initialized
    torch.distributed process group and a CUDA device (set up by the
    __main__ block of this script).

    Args:
        d_model: feature dimension of the gate input.
        batch_size: number of rows in the random input batch.
        n_expert: number of experts (per worker) passed to GShardGate.
    """
    # World size comes from the already-initialized process group.
    gate = GShardGate(d_model, n_expert, dist.get_world_size()).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    # NOTE(review): assumes the gate's forward returns
    # (top-k expert indices, top-k gate values) — confirm against
    # GShardGate.forward, which this commit is still changing.
    topk_idx, topk_val = gate(x)
    # Print per-rank so outputs from different workers can be told apart.
    print('rank {} idx {}'.format(dist.get_rank(), topk_idx))
    print('rank {} val {}'.format(dist.get_rank(), topk_val))
if __name__ == '__main__':
    # Map Open MPI launcher env vars onto the names torch.distributed reads,
    # defaulting to a single-process run when not launched via mpirun.
    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    # One GPU per process: pin each rank to the GPU whose index equals its rank.
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
    # NOTE(review): MASTER_ADDR / MASTER_PORT are not set here — presumably
    # supplied by the launcher environment; confirm before running multi-node.
    torch.distributed.init_process_group(backend="nccl")
    test_gshard_gate(4096, 1024, 4)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment