OpenDAS / FastMoE · commit 537679a8 (unverified)

Authored Jul 08, 2021 by Rick Ho; committed by GitHub on Jul 08, 2021

Merge pull request #60 from laekov/remove-comm

Remove unnecessary dependencies on comm

Parents: 50a9aa94, c8483d42
Changes: 3 changed files with 12 additions and 24 deletions (+12 −24)
fmoe/balance.py    +2   −15
fmoe/functions.py  +3   −3
fmoe/layers.py     +7   −6
fmoe/balance.py (view file @ 537679a8)
@@ -24,18 +24,5 @@ def update_balance_profile(
     num_expert,
     balance_strategy,
 ):
-    c_e = torch.scatter_add(
-        torch.zeros(num_expert, device=gate_top_k_idx.device),
-        0,
-        gate_top_k_idx,
-        torch.ones_like(gate_top_k_idx, dtype=torch.float),
-    )
-    for key in metrics:
-        balance_dict[key][layer_idx] = metrics[key](c_e)
-    S = gate_top_k_idx.shape[0]
-    if balance_strategy == "gshard":
-        gate_score_all = gate_context
-        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
-        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
-    elif balance_strategy == "noisy":
-        balance_dict["noisy_loss"][layer_idx] = gate_context
+    # Fill in this function to conduct balance related jobs
+    pass
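For context, the body removed here profiled expert load: `c_e` counts tokens routed to each expert, and the "gshard" branch combines it with the mean softmax gate score `m_e`. Below is a minimal standalone sketch of that same computation, with hypothetical inputs standing in for the gate tensors the real function receives (`gate_top_k_idx`, `gate_context`):

# Sketch of the removed balance profiling; example inputs are hypothetical.
import torch
import torch.nn.functional as F

num_expert = 4
# top-1 expert index chosen by the gate for each of 8 tokens
gate_top_k_idx = torch.tensor([0, 1, 1, 2, 0, 3, 1, 0])

# c_e: how many tokens were routed to each expert
c_e = torch.scatter_add(
    torch.zeros(num_expert, device=gate_top_k_idx.device),
    0,
    gate_top_k_idx,
    torch.ones_like(gate_top_k_idx, dtype=torch.float),
)
S = gate_top_k_idx.shape[0]

# "gshard" strategy: mean softmax gate score per expert (m_e) weighted by
# the per-expert token count (c_e), averaged as in the removed code
gate_score_all = torch.randn(S, num_expert)  # stand-in for gate_context
m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
gshard_loss = torch.sum(c_e * m_e) / num_expert / S
print(c_e, gshard_loss)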
fmoe/functions.py (view file @ 537679a8)
@@ -16,7 +16,7 @@ def ensure_comm(t, comm):
     fmoe_cuda.ensure_nccl(comm, t)
 
 
-def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -40,7 +40,7 @@ def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     return pos, local_expert_count, global_expert_count
 
 
-def prepare_forward(gate, num_expert, world_size, comm):
+def prepare_forward(gate, num_expert, world_size):
     r"""
     Prepare necessary information from gate output for MoE computation.
 
@@ -52,7 +52,7 @@ def prepare_forward(gate, num_expert, world_size, comm):
         comm: the communicator of all workers in the expert-parallel group.
     """
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size, comm)
+            num_expert, world_size)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
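After this change the counting helpers no longer take a communicator; NCCL setup stays in `ensure_comm` alone. As a rough illustration of the counting itself, here is a pure-PyTorch sketch of the `local_expert_count` tensor built above for the single-worker case. `local_count` is a hypothetical helper name; the `fmoe_cuda` position computation and the cross-worker exchange that produces `global_expert_count` are not reproduced:

# Pure-PyTorch illustration of count_by_gate's local counting
# (world_size == 1, no communicator involved).
import torch

def local_count(gate, num_expert, world_size=1):
    with torch.no_grad():
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.int32
        )
        ones = torch.ones_like(gate.view(-1), dtype=torch.int32)
        # accumulate one count per token at its chosen expert's slot
        local_expert_count.scatter_add_(0, gate.view(-1), ones)
    return local_expert_count

gate = torch.tensor([0, 2, 1, 1, 3, 0])  # hypothetical top-1 expert indices
print(local_count(gate, num_expert=4))   # tensor([2, 2, 1, 1], dtype=torch.int32)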
fmoe/layers.py (view file @ 537679a8)
@@ -74,8 +74,7 @@ def mark_module_parallel_comm(module, comm):
             setattr(p, "dp_comm", comm)
 
 
-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
-        comm=None):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -93,7 +92,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size, comm)
+    ) = prepare_forward(gate, num_expert, world_size)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -219,6 +218,9 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
 
+        if self.gate_hook is not None:
+            self.gate_hook(gate_top_k_idx, gate_score, None)
+
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
             mask = self.mask.view(-1)
@@ -227,9 +229,8 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
 
         fwd = _fmoe_general_global_forward(
             inp,
             gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size
-        )
+            self.expert_fn, self.num_expert, self.world_size, self.moe_group)
 
         # recover deleted tensors
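Taken together, callers no longer thread a communicator through the forward helpers. A minimal single-process smoke test of the forward path touched above, assuming the `FMoETransformerMLP` wrapper from this codebase (not part of this diff) and `world_size == 1`, where no communicator is needed:

# Single-process sketch of the MoE forward path; constructor arguments
# shown here are assumptions, not taken from this diff.
import torch
from fmoe.transformer import FMoETransformerMLP

moe = FMoETransformerMLP(num_expert=4, d_model=16, d_hidden=32, top_k=2)
x = torch.randn(8, 16)  # 8 tokens, d_model == 16
y = moe(x)              # MoE output keeps the input shape
print(y.shape)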