Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
a468db2b
"projects/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "14bd5e28e8a0c93a82ae4e2152e85150dfcde6c7"
Commit
a468db2b
authored
May 18, 2021
by
Rick Ho
Browse files
fix bugs to run on multiple gpus
parent
38b334cc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
8 deletions
+12
-8
fmoe/functions.py
fmoe/functions.py
+8
-4
fmoe/gates/gshard_gate.py
fmoe/gates/gshard_gate.py
+1
-1
fmoe/gates/utils.py
fmoe/gates/utils.py
+1
-1
tests/test_gates.py
tests/test_gates.py
+2
-2
No files found.
fmoe/functions.py
View file @
a468db2b
...
@@ -10,6 +10,12 @@ import fmoe_cuda
 from .utils import get_torch_default_comm


+def _ensure_nccl(t, comm=None):
+    if comm is None:
+        comm = get_torch_default_comm()
+    fmoe_cuda.ensure_nccl(comm, t)
+
+
 def count_by_gate(gate, num_expert, world_size):
     # TODO: support -1 in gate, which means ignore this input
     with torch.no_grad():
...
@@ -21,6 +27,7 @@ def count_by_gate(gate, num_expert, world_size):
         local_expert_count.index_put_((gate_idx.long(),), gate_count)

         if world_size > 1:
+            _ensure_nccl(gate)
             (global_expert_count,) = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
...
@@ -29,7 +36,6 @@ def count_by_gate(gate, num_expert, world_size):
     return pos, local_expert_count, global_expert_count


 def prepare_forward(gate, num_expert, world_size, comm=None):
     r"""
     Prepare necessary information from gate output for MoE computation.
...
@@ -42,9 +48,7 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if world_size > 1:
-        if comm is None:
-            comm = get_torch_default_comm()
-        fmoe_cuda.ensure_nccl(comm, gate)
+        _ensure_nccl(gate, comm=comm)
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
             num_expert, world_size)
...
...
fmoe/gates/gshard_gate.py
View file @
a468db2b
...
@@ -21,7 +21,7 @@ class GShardGate(NaiveGate):
         top_k = topk_idx.shape[0] // gate_score.shape[0]
         top1_idx = topk_idx.view((-1, top_k))[:, 0]
         c_e = torch.scatter_add(
-            torch.zeros(self.num_expert, device=top1_idx.device),
+            torch.zeros(self.tot_expert, device=top1_idx.device),
             0,
             top1_idx,
             torch.ones_like(top1_idx, dtype=torch.float),
...
...
fmoe/gates/utils.py
View file @
a468db2b
...
@@ -14,7 +14,7 @@ def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
     new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
             num_expert, world_size)
     if world_size > 1:
-        new_lec = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
+        new_lec, = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
     else:
         new_lec = new_gec
...
...
tests/test_gates.py
View file @
a468db2b
...
@@ -28,7 +28,7 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
     topk_idx, topk_val = gate(x)
-    counts = [0 for _ in range(n_expert)]
+    counts = [0 for _ in range(n_expert * dist.get_world_size())]
     for v in topk_idx.cpu().view(-1).numpy():
         if v != -1:
             counts[v] += 1
...
@@ -47,7 +47,7 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
     topk_idx, topk_val = gate(x)
-    counts = [0 for _ in range(n_expert)]
+    counts = [0 for _ in range(n_expert * dist.get_world_size())]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment