Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
59913cca
Commit
59913cca
authored
Jul 08, 2021
by
Rick Ho
Browse files
resolve loss reduction with customized gates
parent
18a4395c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
11 deletions
+14
-11
fmoe/functions.py
fmoe/functions.py
+2
-6
fmoe/layers.py
fmoe/layers.py
+3
-1
fmoe/megatron/balance.py
fmoe/megatron/balance.py
+5
-4
fmoe/megatron/distributed.py
fmoe/megatron/distributed.py
+4
-0
No files found.
fmoe/functions.py
View file @
59913cca
...
...
@@ -10,13 +10,13 @@ import fmoe_cuda
from
.utils
import
get_torch_default_comm
def
_
ensure_
nccl
(
t
,
comm
):
def
ensure_
comm
(
t
,
comm
):
if
comm
is
None
:
comm
=
get_torch_default_comm
()
fmoe_cuda
.
ensure_nccl
(
comm
,
t
)
def
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
,
require_pos
=
True
):
def
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
=
None
,
require_pos
=
True
):
with
torch
.
no_grad
():
local_expert_count
=
torch
.
zeros
(
num_expert
*
world_size
,
device
=
gate
.
device
,
dtype
=
torch
.
int32
...
...
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
local_expert_count
=
local_expert_count
.
long
()
if
world_size
>
1
:
_ensure_nccl
(
gate
,
comm
)
global_expert_count
=
fmoe_cuda
.
expert_exchange
(
local_expert_count
,
num_expert
,
world_size
)
...
...
@@ -52,9 +51,6 @@ def prepare_forward(gate, num_expert, world_size, comm):
world_size: number of workers that hold different experts.
comm: the communicator of all workers in the expert-parallel group.
"""
if
world_size
>
1
:
_ensure_nccl
(
gate
,
comm
=
comm
)
pos
,
local_expert_count
,
global_expert_count
=
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
)
with
torch
.
no_grad
():
...
...
fmoe/layers.py
View file @
59913cca
...
...
@@ -5,7 +5,7 @@ import torch
import
torch.nn
as
nn
import
math
from
.functions
import
prepare_forward
from
.functions
import
prepare_forward
,
ensure_comm
from
.functions
import
MOEScatter
,
MOEGather
,
MOELinear
from
.functions
import
AllGather
,
Slice
from
.gates
import
NaiveGate
...
...
@@ -212,6 +212,8 @@ class FMoE(nn.Module):
according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight.
"""
if
self
.
world_size
>
1
:
ensure_comm
(
inp
,
self
.
moe_group
)
if
self
.
mp_size
>
1
:
inp
=
Slice
.
apply
(
inp
,
self
.
mp_rank
,
self
.
mp_size
,
self
.
mp_group
)
...
...
fmoe/megatron/balance.py
View file @
59913cca
...
...
@@ -5,6 +5,7 @@ import torch
from
fmoe.balance
import
reset_balance_profile
from
fmoe.balance
import
update_balance_profile
from
fmoe.utils
import
get_torch_default_comm
from
.distributed
import
get_moe_group
balance_dict
=
{}
...
...
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
)
# average across world group
world
_group
=
get_
torch_default_comm
()
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
world
_group
)
# average across moe group
moe
_group
=
get_
moe_group
()
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
moe
_group
)
averaged_bal_loss
=
bal_loss
.
clone
().
detach
()
torch
.
distributed
.
all_reduce
(
averaged_bal_loss
,
group
=
world
_group
)
torch
.
distributed
.
all_reduce
(
averaged_bal_loss
,
group
=
moe
_group
)
averaged_bal_loss
/=
world_size
loss
+=
bal_loss
...
...
fmoe/megatron/distributed.py
View file @
59913cca
...
...
@@ -12,6 +12,10 @@ def set_moe_group(moe_group):
_moe_group
=
moe_group
def
get_moe_group
():
return
_moe_group
class
DistributedDataParallel
(
DistributedGroupedDataParallel
):
r
"""
A wrapper that is used to replace the DDP module provided by Megatron, which
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment