Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
59913cca
Commit
59913cca
authored
Jul 08, 2021
by
Rick Ho
Browse files
resolve loss reduction with customized gates
parent
18a4395c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
11 deletions
+14
-11
fmoe/functions.py
fmoe/functions.py
+2
-6
fmoe/layers.py
fmoe/layers.py
+3
-1
fmoe/megatron/balance.py
fmoe/megatron/balance.py
+5
-4
fmoe/megatron/distributed.py
fmoe/megatron/distributed.py
+4
-0
No files found.
fmoe/functions.py
View file @
59913cca
...
@@ -10,13 +10,13 @@ import fmoe_cuda
...
@@ -10,13 +10,13 @@ import fmoe_cuda
from
.utils
import
get_torch_default_comm
from
.utils
import
get_torch_default_comm
def
_
ensure_
nccl
(
t
,
comm
):
def
ensure_
comm
(
t
,
comm
):
if
comm
is
None
:
if
comm
is
None
:
comm
=
get_torch_default_comm
()
comm
=
get_torch_default_comm
()
fmoe_cuda
.
ensure_nccl
(
comm
,
t
)
fmoe_cuda
.
ensure_nccl
(
comm
,
t
)
def
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
,
require_pos
=
True
):
def
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
=
None
,
require_pos
=
True
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
local_expert_count
=
torch
.
zeros
(
local_expert_count
=
torch
.
zeros
(
num_expert
*
world_size
,
device
=
gate
.
device
,
dtype
=
torch
.
int32
num_expert
*
world_size
,
device
=
gate
.
device
,
dtype
=
torch
.
int32
...
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
...
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
local_expert_count
=
local_expert_count
.
long
()
local_expert_count
=
local_expert_count
.
long
()
if
world_size
>
1
:
if
world_size
>
1
:
_ensure_nccl
(
gate
,
comm
)
global_expert_count
=
fmoe_cuda
.
expert_exchange
(
global_expert_count
=
fmoe_cuda
.
expert_exchange
(
local_expert_count
,
num_expert
,
world_size
local_expert_count
,
num_expert
,
world_size
)
)
...
@@ -52,9 +51,6 @@ def prepare_forward(gate, num_expert, world_size, comm):
...
@@ -52,9 +51,6 @@ def prepare_forward(gate, num_expert, world_size, comm):
world_size: number of workers that hold different experts.
world_size: number of workers that hold different experts.
comm: the communicator of all workers in the expert-parallel group.
comm: the communicator of all workers in the expert-parallel group.
"""
"""
if
world_size
>
1
:
_ensure_nccl
(
gate
,
comm
=
comm
)
pos
,
local_expert_count
,
global_expert_count
=
count_by_gate
(
gate
,
pos
,
local_expert_count
,
global_expert_count
=
count_by_gate
(
gate
,
num_expert
,
world_size
,
comm
)
num_expert
,
world_size
,
comm
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
...
fmoe/layers.py
View file @
59913cca
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
math
import
math
from
.functions
import
prepare_forward
from
.functions
import
prepare_forward
,
ensure_comm
from
.functions
import
MOEScatter
,
MOEGather
,
MOELinear
from
.functions
import
MOEScatter
,
MOEGather
,
MOELinear
from
.functions
import
AllGather
,
Slice
from
.functions
import
AllGather
,
Slice
from
.gates
import
NaiveGate
from
.gates
import
NaiveGate
...
@@ -212,6 +212,8 @@ class FMoE(nn.Module):
...
@@ -212,6 +212,8 @@ class FMoE(nn.Module):
according to the gate. The score of the selected gate given by the
according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight.
expert is multiplied to the experts' output tensors as a weight.
"""
"""
if
self
.
world_size
>
1
:
ensure_comm
(
inp
,
self
.
moe_group
)
if
self
.
mp_size
>
1
:
if
self
.
mp_size
>
1
:
inp
=
Slice
.
apply
(
inp
,
self
.
mp_rank
,
self
.
mp_size
,
self
.
mp_group
)
inp
=
Slice
.
apply
(
inp
,
self
.
mp_rank
,
self
.
mp_size
,
self
.
mp_group
)
...
...
fmoe/megatron/balance.py
View file @
59913cca
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
fmoe.balance
import
reset_balance_profile
from
fmoe.balance
import
reset_balance_profile
from
fmoe.balance
import
update_balance_profile
from
fmoe.balance
import
update_balance_profile
from
fmoe.utils
import
get_torch_default_comm
from
fmoe.utils
import
get_torch_default_comm
from
.distributed
import
get_moe_group
balance_dict
=
{}
balance_dict
=
{}
...
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
...
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
)
)
# average across
world
group
# average across
moe
group
world
_group
=
get_
torch_default_comm
()
moe
_group
=
get_
moe_group
()
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
world
_group
)
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
moe
_group
)
averaged_bal_loss
=
bal_loss
.
clone
().
detach
()
averaged_bal_loss
=
bal_loss
.
clone
().
detach
()
torch
.
distributed
.
all_reduce
(
averaged_bal_loss
,
group
=
world
_group
)
torch
.
distributed
.
all_reduce
(
averaged_bal_loss
,
group
=
moe
_group
)
averaged_bal_loss
/=
world_size
averaged_bal_loss
/=
world_size
loss
+=
bal_loss
loss
+=
bal_loss
...
...
fmoe/megatron/distributed.py
View file @
59913cca
...
@@ -12,6 +12,10 @@ def set_moe_group(moe_group):
...
@@ -12,6 +12,10 @@ def set_moe_group(moe_group):
_moe_group
=
moe_group
_moe_group
=
moe_group
def
get_moe_group
():
return
_moe_group
class
DistributedDataParallel
(
DistributedGroupedDataParallel
):
class
DistributedDataParallel
(
DistributedGroupedDataParallel
):
r
"""
r
"""
A wrapper that is used to replace the DDP module provided by Megatron, which
A wrapper that is used to replace the DDP module provided by Megatron, which
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment