OpenDAS / FastMoE · Commits

Commit 89d6c794
Add gshard and noisy gate balance strategy
Authored Mar 22, 2021 by Sengxian
Parent: a3b2eb62

Showing 8 changed files with 349 additions and 15 deletions
  cuda/moe.cpp             +2    -2
  fmoe/balance.py          +41   -0
  fmoe/gates.py            +135  -5
  fmoe/layers.py           +5    -1
  fmoe/megatron.py         +162  -5
  fmoe/transformer.py      +2    -0
  tests/benchmark_mlp.py   +1    -1
  tests/test_numerical.py  +1    -1
cuda/moe.cpp
@@ -167,10 +167,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
-    m.def("global_fused_forward", &moe_global_fused_forward,
+    m.def("global_fused_forward", &moe_global_fused_forward,
             "MoE global gather (CUDA)");
     m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
 #endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
-}
+}
\ No newline at end of file
fmoe/balance.py (new file, mode 100644)
+import torch
+import torch.nn.functional as F
+
+metrics = {
+    "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
+    "Lmax_div_Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
+    "Lmax_div_Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
+}
+
+
+def reset_balance_profile(balance_dict, num_layers, balance_strategy):
+    for key in metrics:
+        balance_dict[key] = [None for _ in range(num_layers)]
+    if balance_strategy:
+        balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)]
+
+
+def update_balance_profile(
+    balance_dict,
+    gate_top_k_idx,
+    _gate_score_top_k,
+    gate_state_dict,
+    layer_idx,
+    num_expert,
+    balance_strategy,
+):
+    c_e = torch.scatter_add(
+        torch.zeros(num_expert, device=gate_top_k_idx.device),
+        0,
+        gate_top_k_idx,
+        torch.ones_like(gate_top_k_idx, dtype=torch.float),
+    )
+    for key in metrics:
+        balance_dict[key][layer_idx] = metrics[key](c_e)
+    S = gate_top_k_idx.shape[0]
+    if balance_strategy == "gshard":
+        gate_score_all = gate_state_dict
+        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
+        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
+    elif balance_strategy == "noisy":
+        balance_dict["noisy_loss"][layer_idx] = gate_state_dict
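
The GShard-style term above follows the code: with c_e the per-expert token count and m_e the mean softmax gate probability per expert over the S routed entries, the recorded loss is sum(c_e * m_e) / num_expert / S. Below is a minimal sketch of driving these helpers for one layer, assuming the fmoe package from this commit is installed; the shapes and random inputs are illustrative, not taken from the commit.

import torch
from fmoe.balance import reset_balance_profile, update_balance_profile

num_expert, num_layers, top_k, batch = 4, 1, 2, 8
balance_dict = {}
reset_balance_profile(balance_dict, num_layers, "gshard")

# Illustrative gate outputs: flattened top-k expert indices, plus the raw
# per-token gate scores that the updated gates return as their third value.
gate_top_k_idx = torch.randint(0, num_expert, (batch * top_k,))
gate_score_all = torch.randn(batch, num_expert)

update_balance_profile(
    balance_dict, gate_top_k_idx, None, gate_score_all,
    layer_idx=0, num_expert=num_expert, balance_strategy="gshard",
)
print(balance_dict["coefficient-variation"][0], balance_dict["gshard_loss"][0])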
fmoe/gates.py
@@ -5,6 +5,7 @@ The `NaiveGate` is the reference to implement any other gate.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributions.normal import Normal


 class ZeroGate(nn.Module):
@@ -12,8 +13,9 @@ class ZeroGate(nn.Module):
     Guide all input samples to gate 0.
     """

-    def __init__(self, _1, _2, _3, top_k=2):
+    def __init__(self, _1, num_expert, _3, top_k=2):
         super().__init__()
+        self.num_expert = num_expert
         self.top_k = top_k

     def forward(self, inp):
@@ -23,9 +25,12 @@ class ZeroGate(nn.Module):
         idx = torch.zeros(
             inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
         )
-        score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
-        return idx, score.reshape(-1, 1, self.top_k)
+        gate_score = (
+            torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
+        )
+        gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
+        gate_score_all[:, 0] = 1
+        return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all


 class NaiveGate(nn.Module):
@@ -58,4 +63,129 @@ class NaiveGate(nn.Module):
         gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
         gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

-        return gate_top_k_idx, gate_score
+        return gate_top_k_idx, gate_score, gate
+
+
+class NoisyGate(nn.Module):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
+        super().__init__()
+        self.num_expert = num_expert * world_size
+        self.w_gate = nn.Parameter(
+            torch.zeros(d_model, num_expert * world_size), requires_grad=True
+        )
+        self.w_noise = nn.Parameter(
+            torch.zeros(d_model, num_expert * world_size), requires_grad=True
+        )
+        self.top_k = top_k
+        self.softplus = nn.Softplus()
+        self.softmax = nn.Softmax(1)
+        self.noise_epsilon = 1e-2
+
+    def _gates_to_load(self, gates):
+        """Compute the true load per expert, given the gates.
+        The load is the number of examples for which the corresponding gate is >0.
+        Args:
+            gates: a `Tensor` of shape [batch_size, n]
+        Returns:
+            a float32 `Tensor` of shape [n]
+        """
+        return (gates > 0).sum(0)
+
+    def _prob_in_top_k(
+        self, clean_values, noisy_values, noise_stddev, noisy_top_values
+    ):
+        """Helper function to NoisyTopKGating.
+        Computes the probability that value is in top k, given different random noise.
+        This gives us a way of backpropagating from a loss that balances the number
+        of times each expert is in the top k experts per example.
+        In the case of no noise, pass in None for noise_stddev, and the result will
+        not be differentiable.
+        Args:
+            clean_values: a `Tensor` of shape [batch, n].
+            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
+                normally distributed noise with standard deviation noise_stddev.
+            noise_stddev: a `Tensor` of shape [batch, n], or None
+            noisy_top_values: a `Tensor` of shape [batch, m].
+                "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
+        Returns:
+            a `Tensor` of shape [batch, n].
+        """
+        batch = clean_values.size(0)
+        m = noisy_top_values.size(1)
+        top_values_flat = noisy_top_values.flatten()
+        threshold_positions_if_in = (
+            torch.arange(batch, device=clean_values.device) * m + self.top_k
+        )
+        threshold_if_in = torch.unsqueeze(
+            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
+        )
+        is_in = torch.gt(noisy_values, threshold_if_in)
+        threshold_positions_if_out = threshold_positions_if_in - 1
+        threshold_if_out = torch.unsqueeze(
+            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
+        )
+        # is each value currently in the top k.
+        normal = Normal(
+            torch.tensor([0.0], device=clean_values.device),
+            torch.tensor([1.0], device=clean_values.device),
+        )
+        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
+        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
+        prob = torch.where(is_in, prob_if_in, prob_if_out)
+        return prob
+
+    def cv_squared(self, x):
+        """The squared coefficient of variation of a sample.
+        Useful as a loss to encourage a positive distribution to be more uniform.
+        Epsilons added for numerical stability.
+        Returns 0 for an empty Tensor.
+        Args:
+            x: a `Tensor`.
+        Returns:
+            a `Scalar`.
+        """
+        eps = 1e-10
+        # if only num_expert = 1
+        if x.shape[0] == 1:
+            return torch.Tensor([0])
+        return x.float().var() / (x.float().mean() ** 2 + eps)
+
+    def forward(self, inp):
+        clean_logits = inp @ self.w_gate
+        raw_noise_stddev = inp @ self.w_noise
+        noise_stddev = (
+            self.softplus(raw_noise_stddev) + self.noise_epsilon
+        ) * self.training
+        noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
+        logits = noisy_logits
+
+        # calculate topk + 1 that will be needed for the noisy gates
+        top_logits, top_indices = logits.topk(
+            min(self.top_k + 1, self.num_expert), dim=1
+        )
+        top_k_logits = top_logits[:, : self.top_k]
+        top_k_indices = top_indices[:, : self.top_k]
+        top_k_gates = self.softmax(top_k_logits)
+
+        zeros = torch.zeros_like(logits, requires_grad=True)
+        gates = zeros.scatter(1, top_k_indices, top_k_gates)
+
+        if self.top_k < self.num_expert:
+            load = (
+                self._prob_in_top_k(
+                    clean_logits, noisy_logits, noise_stddev, top_logits
+                )
+            ).sum(0)
+        else:
+            load = self._gates_to_load(gates)
+
+        importance = gates.sum(0)
+        loss = self.cv_squared(importance) + self.cv_squared(load)
+
+        return (
+            top_k_indices.contiguous().view(-1),
+            top_k_gates.contiguous().unsqueeze(1),
+            loss,
+        )
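
A quick smoke test of the new NoisyGate (a sketch, assuming the fmoe package from this commit is importable; the dimensions are made up). Note that the gates now return three values, the last of which is the auxiliary loss that fmoe/balance.py records under "noisy_loss".

import torch
from fmoe.gates import NoisyGate

d_model, num_expert, world_size, top_k = 16, 4, 1, 2
gate = NoisyGate(d_model, num_expert, world_size, top_k=top_k)
gate.train()  # the noise term is zeroed in eval mode via the `* self.training` factor

idx, score, loss = gate(torch.randn(8, d_model))
print(idx.shape)    # flattened top-k expert indices: (8 * top_k,)
print(score.shape)  # per-token gate weights: (8, 1, top_k)
print(loss)         # cv_squared(importance) + cv_squared(load)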
fmoe/layers.py
@@ -151,6 +151,7 @@ class FMoE(nn.Module):
         top_k=2,
         gate=NaiveGate,
         expert=None,
+        gate_hook=None,
     ):
         super().__init__()
         self.num_expert = num_expert
@@ -171,6 +172,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
         else:
             self.experts_fused = True
+        self.gate_hook = gate_hook

     def expert_fn(self, inp, fwd_expert_count):
         r"""
@@ -212,7 +214,9 @@ class FMoE(nn.Module):
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)

-        gate_top_k_idx, gate_score = self.gate(inp)
+        gate_top_k_idx, gate_score, gate_state_dict = self.gate(inp)
+        if self.gate_hook:
+            self.gate_hook(gate_top_k_idx, gate_score, gate_state_dict)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_general_global_forward(
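
The hook contract introduced here is simply gate_hook(gate_top_k_idx, gate_score, gate_state_dict), called once per FMoE.forward. A sketch that exercises it against NaiveGate directly, so it runs without the CUDA forward path; the constructor arguments mirror the gates in this commit, and the logging hook itself is hypothetical.

import torch
from fmoe.gates import NaiveGate

def log_gate_hook(gate_top_k_idx, gate_score, gate_state_dict):
    # Same signature as the call added in FMoE.forward above.
    print("routed entries:", gate_top_k_idx.shape[0])

gate = NaiveGate(16, 4, 1, top_k=2)           # d_model, num_expert, world_size
idx, score, state = gate(torch.randn(8, 16))  # gates now return a third value
log_gate_hook(idx, score, state)
# In a model, the hook is passed as FMoE(..., gate_hook=log_gate_hook).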
fmoe/megatron.py
@@ -11,6 +11,8 @@ import torch.nn.functional as F
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
+from .balance import update_balance_profile, reset_balance_profile
+from .utils import get_torch_default_comm


 class _FakeMegatronMLP(nn.Module):
@@ -69,22 +71,167 @@ def _random_init_weight(self, rng):
     self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


+balance_dict = {}
+num_layers = 0
+
+
+def reset_gate_hook():
+    from megatron import get_args
+
+    global balance_dict, num_layers
+    reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)
+
+
+def get_balance_profile():
+    global balance_dict
+    return balance_dict
+
+
+def generate_megatron_gate_hook(layer_idx, num_expert_global):
+    from megatron import get_args
+
+    balance_strategy = get_args().balance_strategy
+
+    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_state_dict):
+        global balance_dict
+        update_balance_profile(
+            balance_dict,
+            gate_top_k_idx,
+            gate_score_top_k,
+            gate_state_dict,
+            layer_idx,
+            num_expert_global,
+            balance_strategy,
+        )
+
+    return megatron_gate_hook
+
+
+def add_fmoe_args(parser):
+    group = parser.add_argument_group(title="fastmoe")
+    group.add_argument("--fmoefy", action="store_true")
+    group.add_argument("--num-experts", type=int, default=None)
+    group.add_argument("--top-k", type=int, default=2)
+    group.add_argument("--balance-loss-weight", type=float, default=1)
+    group.add_argument("--balance-strategy", type=str, default=None)
+    return parser
+
+
+def add_balance_log(writer, iteration):
+    from megatron import is_last_rank
+
+    balance_dict_tensor = torch.vstack(
+        [torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
+    ).detach()
+    world_group = get_torch_default_comm()
+    world_size = torch.distributed.get_world_size(group=world_group)
+    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
+    balance_dict_tensor /= world_size
+
+    if writer and is_last_rank():
+        for idx, metric_name in enumerate(balance_dict):
+            for layer_id, val in enumerate(balance_dict_tensor[idx]):
+                writer.add_scalar(
+                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
+                )
+            writer.add_scalar(
+                f"balance-{metric_name}/all",
+                balance_dict_tensor[idx].mean().item(),
+                iteration,
+            )
+
+    reset_gate_hook()
+
+
+def patch_forward_step(forward_step_func):
+    r"""
+    Patch model's forward_step_func to support balance loss
+    """
+    from megatron.mpu import is_pipeline_last_stage
+    from megatron import get_args
+
+    if not get_args().balance_strategy:
+        return forward_step_func
+
+    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
+        args = get_args()
+        output = forward_step_func(data_iterator, model, input_tensor)
+
+        if is_pipeline_last_stage():
+            loss_name = args.balance_strategy + "_loss"
+            (loss, state_dict), bal_loss = (
+                output,
+                (
+                    torch.tensor(
+                        balance_dict[loss_name],
+                        device=balance_dict[loss_name][0].device,
+                    ).mean()
+                    * args.balance_loss_weight
+                ).float(),
+            )
+
+            # average across world group
+            world_group = get_torch_default_comm()
+            world_size = torch.distributed.get_world_size(group=world_group)
+            averaged_bal_loss = bal_loss.clone().detach()
+            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
+            averaged_bal_loss /= world_size
+
+            loss += bal_loss
+            state_dict[loss_name] = averaged_bal_loss
+
+            return loss, state_dict
+        else:
+            return output
+
+    return forward_step_with_balance_loss
+
+
+def patch_model_provider(model_provider):
+    from megatron import get_args
+
+    def fmoefied_model_provider():
+        args = get_args()
+        return fmoefy(
+            model_provider(),
+            num_experts=args.num_experts,
+            hidden_hidden_size=4 * args.hidden_size // args.top_k,
+            top_k=args.top_k,
+        )
+
+    return fmoefied_model_provider
+
+
 class MegatronMLP(FMoETransformerMLP):
     r"""
     Make the FMoETransformerMLP layer that distributes experts across
     communication group `group` to replace the original MLP layer in Megatron.
     """

-    def __init__(self, args, group):
+    def __init__(self, args, group, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
             world_size = 1
         else:
             world_size = args.world_size
+
+        gate = None
+        if not args.balance_strategy or args.balance_strategy == "gshard":
+            from .gates import NaiveGate
+
+            gate = NaiveGate
+        elif args.balance_strategy == "noisy":
+            from .gates import NoisyGate
+
+            gate = NoisyGate
+        else:
+            assert False, "Undefined balance strategy {}".format(args.balance_strategy)
+
         super().__init__(
             args.num_experts,
             top_k=args.top_k,
@@ -93,6 +240,10 @@ class MegatronMLP(FMoETransformerMLP):
             world_size=world_size,
             mp_group=group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
+            gate_hook=generate_megatron_gate_hook(
+                layer_idx, args.num_experts * world_size
+            ),
+            gate=gate,
         )
         self.hidden_size = args.hidden_size
         if args.distributed_experts:
@@ -166,8 +317,14 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts

-    for l in model.language_model.transformer.layers:
-        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
+    for idx, l in enumerate(model.language_model.transformer.layers):
+        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
+
+    # initialize gate hook
+    global num_layers, balance_dict
+    num_layers = len(model.language_model.transformer.layers)
+    reset_gate_hook()
+
     return model
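
For orientation, a sketch of the new command-line switches registered by add_fmoe_args, exercised with a standalone parser (assuming the fmoe package from this commit is installed; the values are illustrative). Inside Megatron these flags would be combined with patch_forward_step and patch_model_provider as defined above.

import argparse
from fmoe.megatron import add_fmoe_args

parser = argparse.ArgumentParser()
add_fmoe_args(parser)
args = parser.parse_args(
    ["--fmoefy", "--num-experts", "8", "--top-k", "2",
     "--balance-strategy", "gshard", "--balance-loss-weight", "0.1"]
)
print(args.num_experts, args.balance_strategy, args.balance_loss_weight)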
fmoe/transformer.py
@@ -50,6 +50,7 @@ class FMoETransformerMLP(FMoE):
         gate=NaiveGate,
         top_k=2,
         expert_dp_comm="none",
+        gate_hook=None,
     ):
         super().__init__(
             num_expert=num_expert,
@@ -58,6 +59,7 @@ class FMoETransformerMLP(FMoE):
             top_k=top_k,
             world_size=world_size,
             mp_group=mp_group,
+            gate_hook=gate_hook,
         )
         self.experts = _Expert(
             num_expert, d_model, d_hidden, activation, rank=self.mp_rank
tests/benchmark_mlp.py
@@ -40,7 +40,7 @@ class BruteForceMoE(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        gate_top_k_idx, gate_score = self.gate(inp)
+        gate_top_k_idx, gate_score, _ = self.gate(inp)
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = self.mlp(inp, gate_top_k_idx, gate_score)
         if not self.pre_lnorm:
tests/test_numerical.py
@@ -38,7 +38,7 @@ def _perform_forward(
     inp.requires_grad = True
     inp_raw.requires_grad = True
-    gate_idx, gate_score = moe.gate(inp_raw)
+    gate_idx, gate_score, _ = moe.gate(inp_raw)
     inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
     moe_out = moe(inp)
     raw_out = moe_raw(inp_repeated, gate_idx, gate_score)