OpenDAS / FastMoE — commit 3c2a5979 (unverified)

Authored Mar 22, 2021 by Rick Ho; committed by GitHub on Mar 22, 2021.

Merge pull request #19 from laekov/balance: Add balance strategy

Parents: c1e67585, e028f2ec
Showing 8 changed files with 349 additions and 15 deletions (+349 −15):
cuda/moe.cpp             +2    −2
fmoe/balance.py          +41   −0
fmoe/gates.py            +135  −5
fmoe/layers.py           +5    −1
fmoe/megatron.py         +162  −5
fmoe/transformer.py      +2    −0
tests/benchmark_mlp.py   +1    −1
tests/test_numerical.py  +1    −1
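In summary, the commit adds per-layer load-balance metrics and two optional balance-loss strategies ("gshard" and "noisy") to the gates, and threads a `gate_hook` through the layers so the Megatron integration can profile routing. As a quick orientation, here is a minimal sketch (not part of the diff) of how the new flags registered by `add_fmoe_args` in fmoe/megatron.py could be exercised on a plain `argparse` parser standing in for Megatron's:

```python
import argparse

from fmoe.megatron import add_fmoe_args  # added by this commit

parser = argparse.ArgumentParser()
add_fmoe_args(parser)  # registers --fmoefy, --num-experts, --top-k,
                       # --balance-loss-weight and --balance-strategy
args = parser.parse_args(["--fmoefy", "--num-experts", "4",
                          "--balance-strategy", "gshard"])
print(args.balance_strategy, args.top_k)  # gshard 2
```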
cuda/moe.cpp (+2 −2) — view file @ 3c2a5979

The single hunk is a whitespace cleanup of the pybind11 bindings: the `global_fused_forward` binding is reflowed and the missing newline at the end of the file is added (the removed and added lines are otherwise identical):

```diff
@@ -167,10 +167,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
-    m.def("global_fused_forward", &moe_global_fused_forward,
-            "MoE global gather (CUDA)");
+    m.def("global_fused_forward", &moe_global_fused_forward,
+            "MoE global gather (CUDA)");
     m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
 #endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
-}
\ No newline at end of file
+}
```
fmoe/balance.py (new file, 0 → 100644, +41 −0) — view file @ 3c2a5979

```python
import torch
import torch.nn.functional as F

metrics = {
    "coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
    "Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
    "Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
}


def reset_balance_profile(balance_dict, num_layers, balance_strategy):
    for key in metrics:
        balance_dict[key] = [None for _ in range(num_layers)]
    if balance_strategy:
        balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)]


def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=gate_top_k_idx.device),
        0,
        gate_top_k_idx,
        torch.ones_like(gate_top_k_idx, dtype=torch.float),
    )
    for key in metrics:
        balance_dict[key][layer_idx] = metrics[key](c_e)
    S = gate_top_k_idx.shape[0]
    if balance_strategy == "gshard":
        gate_score_all = gate_context
        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
    elif balance_strategy == "noisy":
        balance_dict["noisy_loss"][layer_idx] = gate_context
```
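To make the bookkeeping concrete: `c_e` counts how many routed tokens landed on each expert, and the three metrics measure how uneven that count is (the `+ 1` in `Lmax-over-Lmin` guards against empty experts). A small self-contained sketch with made-up toy routing:

```python
import torch

# Toy routing: 6 tokens over 4 experts, flattened top-k indices as in gate_top_k_idx.
gate_top_k_idx = torch.tensor([0, 0, 1, 3, 3, 3])
num_expert = 4

c_e = torch.scatter_add(
    torch.zeros(num_expert),
    0,
    gate_top_k_idx,
    torch.ones_like(gate_top_k_idx, dtype=torch.float),
)
print(c_e)                                          # tensor([2., 1., 0., 3.])
print(torch.std(c_e) / torch.mean(c_e))             # coefficient-variation
print((torch.max(c_e) + 1) / (torch.min(c_e) + 1))  # Lmax-over-Lmin -> 4.0
print(torch.max(c_e) / torch.mean(c_e))             # Lmax-over-Lmean -> 2.0
```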
fmoe/gates.py (+135 −5) — view file @ 3c2a5979

All gates now return a third value (context for the balance profiler), and `ZeroGate` learns its expert count:

```diff
@@ -5,6 +5,7 @@ The `NaiveGate` is the reference to implement any other gate.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributions.normal import Normal
 
 
 class ZeroGate(nn.Module):
@@ -12,8 +13,9 @@ class ZeroGate(nn.Module):
     Guide all input samples to gate 0.
     """
 
-    def __init__(self, _1, _2, _3, top_k=2):
+    def __init__(self, _1, num_expert, _3, top_k=2):
         super().__init__()
+        self.num_expert = num_expert
         self.top_k = top_k
 
     def forward(self, inp):
@@ -23,9 +25,12 @@ class ZeroGate(nn.Module):
         idx = torch.zeros(
             inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
         )
-        score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
-        return idx, score.reshape(-1, 1, self.top_k)
+        gate_score = (
+            torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
+        )
+        gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
+        gate_score_all[:, 0] = 1
+        return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
@@ -58,4 +63,129 @@ class NaiveGate(nn.Module):
         gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
         gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-        return gate_top_k_idx, gate_score
+        return gate_top_k_idx, gate_score, gate
```

The same hunk appends the new `NoisyGate` class, a noisy top-k gate whose forward pass returns its load-balancing loss as the third output:

```python
class NoisyGate(nn.Module):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        self.num_expert = num_expert * world_size
        self.w_gate = nn.Parameter(
            torch.zeros(d_model, num_expert * world_size), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(d_model, num_expert * world_size), requires_grad=True
        )
        self.top_k = top_k
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        self.noise_epsilon = 1e-2

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
            gates: a `Tensor` of shape [batch_size, n]
        Returns:
            a float32 `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
                normally distributed noise with standard deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n], or None
            noisy_top_values: a `Tensor` of shape [batch, m].
                "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
        Returns:
            a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.top_k
        )
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
        )
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
        )
        # is each value currently in the top k.
        normal = Normal(
            torch.tensor([0.0], device=clean_values.device),
            torch.tensor([1.0], device=clean_values.device),
        )
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty Tensor.
        Args:
            x: a `Tensor`.
        Returns:
            a `Scalar`.
        """
        eps = 1e-10
        # if only num_expert = 1
        if x.shape[0] == 1:
            return torch.Tensor([0])
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def forward(self, inp):
        clean_logits = inp @ self.w_gate
        raw_noise_stddev = inp @ self.w_noise
        noise_stddev = (
            self.softplus(raw_noise_stddev) + self.noise_epsilon
        ) * self.training
        noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
        logits = noisy_logits

        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(
            min(self.top_k + 1, self.num_expert), dim=1
        )
        top_k_logits = top_logits[:, : self.top_k]
        top_k_indices = top_indices[:, : self.top_k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.top_k < self.num_expert:
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_logits
                )
            ).sum(0)
        else:
            load = self._gates_to_load(gates)

        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)

        return (
            top_k_indices.contiguous().view(-1),
            top_k_gates.contiguous().unsqueeze(1),
            loss,
        )
```
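A hedged usage sketch for the new gate (shapes inferred from the code above; all sizes are made up). Note that the gate should be in training mode here: `noise_stddev` is multiplied by `self.training`, so in eval mode it is zero and the division inside `_prob_in_top_k` degenerates.

```python
import torch

from fmoe.gates import NoisyGate  # added by this commit

gate = NoisyGate(d_model=16, num_expert=4, world_size=1, top_k=2)
gate.train()  # keep the gating noise on; see the note above

inp = torch.randn(8, 16)  # 8 tokens with d_model = 16
idx, score, loss = gate(inp)

print(idx.shape)    # torch.Size([16]): 8 tokens x top_k = 2, flattened
print(score.shape)  # torch.Size([8, 1, 2])
print(loss)         # cv_squared(importance) + cv_squared(load), a scalar
```

Unlike `NaiveGate`, the third return value is the balance loss itself; this is the `gate_context` that `update_balance_profile` stores under `"noisy_loss"`.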
fmoe/layers.py (+5 −1) — view file @ 3c2a5979

`FMoE` gains an optional `gate_hook`, and its forward path unpacks the gate's new third return value:

```diff
@@ -151,6 +151,7 @@ class FMoE(nn.Module):
         top_k=2,
         gate=NaiveGate,
         expert=None,
+        gate_hook=None,
     ):
         super().__init__()
         self.num_expert = num_expert
@@ -171,6 +172,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
         else:
             self.experts_fused = True
+        self.gate_hook = gate_hook
 
     def expert_fn(self, inp, fwd_expert_count):
         r"""
@@ -212,7 +214,9 @@ class FMoE(nn.Module):
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
-        gate_top_k_idx, gate_score = self.gate(inp)
+        gate_top_k_idx, gate_score, gate_state_dict = self.gate(inp)
+        if self.gate_hook:
+            self.gate_hook(gate_top_k_idx, gate_score, gate_state_dict)
 
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_general_global_forward(
```
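The hook contract is visible in the forward path above: when `gate_hook` is set, it is called once per forward pass with the flattened top-k expert indices, the gate scores, and whatever the gate returned as its third value. A minimal illustrative hook (the name `count_hook` and the wiring below are hypothetical):

```python
import torch


def count_hook(gate_top_k_idx, gate_score, gate_state_dict):
    # Log how many tokens each expert received in this forward pass.
    counts = torch.bincount(gate_top_k_idx)
    print("tokens per expert:", counts.tolist())


# moe = FMoE(num_expert=4, d_model=16, gate_hook=count_hook)  # hypothetical wiring
```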
fmoe/megatron.py (+162 −5) — view file @ 3c2a5979

New imports:

```diff
@@ -15,6 +15,8 @@ import torch.nn.functional as F
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
+from .balance import update_balance_profile, reset_balance_profile
+from .utils import get_torch_default_comm
 
 
 class _FakeMegatronMLP(nn.Module):
```

The main hunk (@@ -73,22 +75,167 @@, following `_random_init_weight`) adds module-level balance bookkeeping, the fastmoe argument group, TensorBoard logging, and patch helpers for Megatron's forward step and model provider:

```python
balance_dict = {}
num_layers = 0


def reset_gate_hook():
    from megatron import get_args

    global balance_dict, num_layers
    reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)


def get_balance_profile():
    global balance_dict
    return balance_dict


def generate_megatron_gate_hook(layer_idx, num_expert_global):
    from megatron import get_args

    balance_strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
        global balance_dict
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_context,
            layer_idx,
            num_expert_global,
            balance_strategy,
        )

    return megatron_gate_hook


def add_fmoe_args(parser):
    group = parser.add_argument_group(title="fastmoe")
    group.add_argument("--fmoefy", action="store_true")
    group.add_argument("--num-experts", type=int, default=None)
    group.add_argument("--top-k", type=int, default=2)
    group.add_argument("--balance-loss-weight", type=float, default=1)
    group.add_argument("--balance-strategy", type=str, default=None)
    return parser


def add_balance_log(writer, iteration):
    from megatron import is_last_rank

    balance_dict_tensor = torch.vstack(
        [torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
    ).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
    balance_dict_tensor /= world_size

    if writer and is_last_rank():
        for idx, metric_name in enumerate(balance_dict):
            for layer_id, val in enumerate(balance_dict_tensor[idx]):
                writer.add_scalar(
                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
                )
            writer.add_scalar(
                f"balance-{metric_name}/all",
                balance_dict_tensor[idx].mean().item(),
                iteration,
            )
    reset_gate_hook()


def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """
    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if is_pipeline_last_stage():
            loss_name = args.balance_strategy + "_loss"
            (loss, state_dict), bal_loss = (
                output,
                (
                    torch.tensor(
                        balance_dict[loss_name],
                        device=balance_dict[loss_name][0].device,
                    ).mean()
                    * args.balance_loss_weight
                ).float(),
            )

            # avarage across world group
            world_group = get_torch_default_comm()
            world_size = torch.distributed.get_world_size(group=world_group)
            averaged_bal_loss = bal_loss.clone().detach()
            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
            averaged_bal_loss /= world_size

            loss += bal_loss
            state_dict[loss_name] = averaged_bal_loss

            return loss, state_dict
        else:
            return output

    return forward_step_with_balance_loss


def patch_model_provider(model_provider):
    from megatron import get_args

    def fmoefied_model_provider():
        args = get_args()
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
            hidden_hidden_size=4 * args.hidden_size // args.top_k,
            top_k=args.top_k,
        )

    return fmoefied_model_provider
```

It then threads a `layer_idx` through `MegatronMLP` and selects the gate according to the balance strategy:

```diff
 class MegatronMLP(FMoETransformerMLP):
     r"""
     Make the FMoETransformerMLP layer that distributes experts across
     communication group `group` to replace the original MLP layer in Megatron.
     """
 
-    def __init__(self, args, group):
+    def __init__(self, args, group, layer_idx):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
             world_size = 1
         else:
             world_size = args.world_size
+
+        gate = None
+        if not args.balance_strategy or args.balance_strategy == "gshard":
+            from .gates import NaiveGate
+
+            gate = NaiveGate
+        elif args.balance_strategy == "noisy":
+            from .gates import NoisyGate
+
+            gate = NoisyGate
+        else:
+            assert False, "Undefined balance strategy {}" % (args.balance_strategy)
+
         super().__init__(
             args.num_experts,
             top_k=args.top_k,
@@ -97,6 +244,10 @@ class MegatronMLP(FMoETransformerMLP):
             world_size=world_size,
             mp_group=group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
+            gate_hook=generate_megatron_gate_hook(
+                layer_idx, args.num_experts * world_size
+            ),
+            gate=gate,
         )
         self.hidden_size = args.hidden_size
         if args.distributed_experts:
@@ -170,8 +321,14 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
 
-    for l in model.language_model.transformer.layers:
-        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
+    for idx, l in enumerate(model.language_model.transformer.layers):
+        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
+
+    # initialize gate hook
+    global num_layers, balance_dict
+    num_layers = len(model.language_model.transformer.layers)
+    reset_gate_hook()
 
     return model
```
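The diff does not show how these patches are applied in a training script; assuming the usual Megatron-LM entry-point conventions (`pretrain`, `model_provider`, `forward_step`), the wiring would plausibly look like:

```python
from fmoe.megatron import add_fmoe_args, patch_forward_step, patch_model_provider

# Hypothetical Megatron-LM entry point; the names below follow Megatron
# conventions, not this diff:
#
# pretrain(
#     train_valid_test_datasets_provider,
#     patch_model_provider(model_provider),
#     patch_forward_step(forward_step),
#     extra_args_provider=add_fmoe_args,
# )
```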
fmoe/transformer.py (+2 −0) — view file @ 3c2a5979

`FMoETransformerMLP` simply forwards the new `gate_hook` argument to `FMoE`:

```diff
@@ -48,6 +48,7 @@ class FMoETransformerMLP(FMoE):
         gate=NaiveGate,
         top_k=2,
         expert_dp_comm="none",
+        gate_hook=None,
     ):
         super().__init__(
             num_expert=num_expert,
@@ -56,6 +57,7 @@ class FMoETransformerMLP(FMoE):
             top_k=top_k,
             world_size=world_size,
             mp_group=mp_group,
+            gate_hook=gate_hook,
         )
         self.experts = _Expert(
             num_expert, d_model, d_hidden, activation, rank=self.mp_rank
```
tests/benchmark_mlp.py (+1 −1) — view file @ 3c2a5979

The tests are updated for the gates' new three-value return:

```diff
@@ -40,7 +40,7 @@ class BruteForceMoE(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        gate_top_k_idx, gate_score = self.gate(inp)
+        gate_top_k_idx, gate_score, _ = self.gate(inp)
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = self.mlp(inp, gate_top_k_idx, gate_score)
         if not self.pre_lnorm:
```
tests/test_numerical.py (+1 −1) — view file @ 3c2a5979

```diff
@@ -38,7 +38,7 @@ def _perform_forward(
     inp.requires_grad = True
     inp_raw.requires_grad = True
-    gate_idx, gate_score = moe.gate(inp_raw)
+    gate_idx, gate_score, _ = moe.gate(inp_raw)
     inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
     moe_out = moe(inp)
     raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
```