Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
e028f2ec
Commit
e028f2ec
authored
Mar 22, 2021
by
Sengxian
Browse files
Revise variable name
parent
69121432
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
7 deletions
+7
-7
fmoe/balance.py
fmoe/balance.py
+5
-5
fmoe/megatron.py
fmoe/megatron.py
+2
-2
No files found.
fmoe/balance.py
View file @
e028f2ec
...
@@ -3,8 +3,8 @@ import torch.nn.functional as F
...
@@ -3,8 +3,8 @@ import torch.nn.functional as F
metrics
=
{
metrics
=
{
"coefficient-variation"
:
lambda
c_e
:
torch
.
std
(
c_e
)
/
torch
.
mean
(
c_e
),
"coefficient-variation"
:
lambda
c_e
:
torch
.
std
(
c_e
)
/
torch
.
mean
(
c_e
),
"Lmax
_div_
Lmin"
:
lambda
c_e
:
(
torch
.
max
(
c_e
)
+
1
)
/
(
torch
.
min
(
c_e
)
+
1
),
"Lmax
-over-
Lmin"
:
lambda
c_e
:
(
torch
.
max
(
c_e
)
+
1
)
/
(
torch
.
min
(
c_e
)
+
1
),
"Lmax
_div_
Lmean"
:
lambda
c_e
:
torch
.
max
(
c_e
)
/
torch
.
mean
(
c_e
),
"Lmax
-over-
Lmean"
:
lambda
c_e
:
torch
.
max
(
c_e
)
/
torch
.
mean
(
c_e
),
}
}
...
@@ -19,7 +19,7 @@ def update_balance_profile(
...
@@ -19,7 +19,7 @@ def update_balance_profile(
balance_dict
,
balance_dict
,
gate_top_k_idx
,
gate_top_k_idx
,
_gate_score_top_k
,
_gate_score_top_k
,
gate_
state_dic
t
,
gate_
contex
t
,
layer_idx
,
layer_idx
,
num_expert
,
num_expert
,
balance_strategy
,
balance_strategy
,
...
@@ -34,8 +34,8 @@ def update_balance_profile(
...
@@ -34,8 +34,8 @@ def update_balance_profile(
balance_dict
[
key
][
layer_idx
]
=
metrics
[
key
](
c_e
)
balance_dict
[
key
][
layer_idx
]
=
metrics
[
key
](
c_e
)
S
=
gate_top_k_idx
.
shape
[
0
]
S
=
gate_top_k_idx
.
shape
[
0
]
if
balance_strategy
==
"gshard"
:
if
balance_strategy
==
"gshard"
:
gate_score_all
=
gate_
state_dic
t
gate_score_all
=
gate_
contex
t
m_e
=
torch
.
sum
(
F
.
softmax
(
gate_score_all
,
dim
=
1
),
dim
=
0
)
/
S
m_e
=
torch
.
sum
(
F
.
softmax
(
gate_score_all
,
dim
=
1
),
dim
=
0
)
/
S
balance_dict
[
"gshard_loss"
][
layer_idx
]
=
torch
.
sum
(
c_e
*
m_e
)
/
num_expert
/
S
balance_dict
[
"gshard_loss"
][
layer_idx
]
=
torch
.
sum
(
c_e
*
m_e
)
/
num_expert
/
S
elif
balance_strategy
==
"noisy"
:
elif
balance_strategy
==
"noisy"
:
balance_dict
[
"noisy_loss"
][
layer_idx
]
=
gate_
state_dic
t
balance_dict
[
"noisy_loss"
][
layer_idx
]
=
gate_
contex
t
fmoe/megatron.py
View file @
e028f2ec
...
@@ -96,13 +96,13 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
...
@@ -96,13 +96,13 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
balance_strategy
=
get_args
().
balance_strategy
balance_strategy
=
get_args
().
balance_strategy
def
megatron_gate_hook
(
gate_top_k_idx
,
gate_score_top_k
,
gate_
state_dic
t
):
def
megatron_gate_hook
(
gate_top_k_idx
,
gate_score_top_k
,
gate_
contex
t
):
global
balance_dict
global
balance_dict
update_balance_profile
(
update_balance_profile
(
balance_dict
,
balance_dict
,
gate_top_k_idx
,
gate_top_k_idx
,
gate_score_top_k
,
gate_score_top_k
,
gate_
state_dic
t
,
gate_
contex
t
,
layer_idx
,
layer_idx
,
num_expert_global
,
num_expert_global
,
balance_strategy
,
balance_strategy
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment