Unverified Commit 3c2a5979 authored by Rick Ho, committed by GitHub

Merge pull request #19 from laekov/balance

Add balance strategy
parents c1e67585 e028f2ec
@@ -167,10 +167,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward,
m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
import torch
import torch.nn.functional as F
metrics = {
"coefficient-variation": lambda c_e: torch.std(c_e) / torch.mean(c_e),
"Lmax-over-Lmin": lambda c_e: (torch.max(c_e) + 1) / (torch.min(c_e) + 1),
"Lmax-over-Lmean": lambda c_e: torch.max(c_e) / torch.mean(c_e),
}
def reset_balance_profile(balance_dict, num_layers, balance_strategy):
for key in metrics:
balance_dict[key] = [None for _ in range(num_layers)]
if balance_strategy:
balance_dict[f"{balance_strategy}_loss"] = [None for _ in range(num_layers)]
def update_balance_profile(
balance_dict,
gate_top_k_idx,
_gate_score_top_k,
gate_context,
layer_idx,
num_expert,
balance_strategy,
):
c_e = torch.scatter_add(
torch.zeros(num_expert, device=gate_top_k_idx.device),
0,
gate_top_k_idx,
torch.ones_like(gate_top_k_idx, dtype=torch.float),
)
for key in metrics:
balance_dict[key][layer_idx] = metrics[key](c_e)
S = gate_top_k_idx.shape[0]
if balance_strategy == "gshard":
gate_score_all = gate_context
m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
elif balance_strategy == "noisy":
balance_dict["noisy_loss"][layer_idx] = gate_context
@@ -5,6 +5,7 @@ The `NaiveGate` is the reference to implement any other gate.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
class ZeroGate(nn.Module):
@@ -12,8 +13,9 @@ class ZeroGate(nn.Module):
Route all input samples to expert 0.
"""
def __init__(self, _1, num_expert, _3, top_k=2):
super().__init__()
self.num_expert = num_expert
self.top_k = top_k
def forward(self, inp):
@@ -23,9 +25,12 @@ class ZeroGate(nn.Module):
idx = torch.zeros(
inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
)
gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
)
gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
gate_score_all[:, 0] = 1
return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
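
A quick check of the new `ZeroGate` contract (a sketch): every token goes to expert 0, and the added third output is the dense `[batch, num_expert]` score matrix that the balance hooks and GShard loss consume.

```python
import torch

g = ZeroGate(None, 4, None, top_k=2)  # positional placeholders as in __init__
idx, gate_score, gate_score_all = g(torch.randn(3, 8))
print(idx)                # tensor([0, 0, 0, 0, 0, 0])
print(gate_score.shape)   # torch.Size([3, 1, 2]); every entry is 1/top_k
print(gate_score_all[0])  # tensor([1., 0., 0., 0.])
```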
class NaiveGate(nn.Module):
@@ -58,4 +63,129 @@ class NaiveGate(nn.Module):
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
return gate_top_k_idx, gate_score, gate
class NoisyGate(nn.Module):
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__()
self.num_expert = num_expert * world_size
self.w_gate = nn.Parameter(
torch.zeros(d_model, num_expert * world_size), requires_grad=True
)
self.w_noise = nn.Parameter(
torch.zeros(d_model, num_expert * world_size), requires_grad=True
)
self.top_k = top_k
self.softplus = nn.Softplus()
self.softmax = nn.Softmax(1)
self.noise_epsilon = 1e-2
def _gates_to_load(self, gates):
"""Compute the true load per expert, given the gates.
The load is the number of examples for which the corresponding gate is >0.
Args:
gates: a `Tensor` of shape [batch_size, n]
Returns:
a float32 `Tensor` of shape [n]
"""
return (gates > 0).sum(0)
def _prob_in_top_k(
self, clean_values, noisy_values, noise_stddev, noisy_top_values
):
"""Helper function to NoisyTopKGating.
Computes the probability that value is in top k, given different random noise.
This gives us a way of backpropagating from a loss that balances the number
of times each expert is in the top k experts per example.
Note: unlike the TensorFlow original, this port does not accept
noise_stddev=None; a noise tensor is always required.
Args:
clean_values: a `Tensor` of shape [batch, n].
noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
normally distributed noise with standard deviation noise_stddev.
noise_stddev: a `Tensor` of shape [batch, n], or None
noisy_top_values: a `Tensor` of shape [batch, m], the "values" output
of topk(noisy_values, m), where m >= k+1.
Returns:
a `Tensor` of shape [batch, n].
"""
batch = clean_values.size(0)
m = noisy_top_values.size(1)
top_values_flat = noisy_top_values.flatten()
threshold_positions_if_in = (
torch.arange(batch, device=clean_values.device) * m + self.top_k
)
threshold_if_in = torch.unsqueeze(
torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
)
is_in = torch.gt(noisy_values, threshold_if_in)
threshold_positions_if_out = threshold_positions_if_in - 1
threshold_if_out = torch.unsqueeze(
torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
)
# probability that each value stays in the top k under the Gaussian noise
normal = Normal(
torch.tensor([0.0], device=clean_values.device),
torch.tensor([1.0], device=clean_values.device),
)
prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
prob = torch.where(is_in, prob_if_in, prob_if_out)
return prob
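
To make the threshold logic above concrete, a hypothetical sanity check: with `top_k=1`, two experts, unit noise, and the sampled noise pretended to be zero, the expert whose clean logit sits above the threshold keeps probability Φ(2) ≈ 0.98 of staying in the top-1, while the other drops to Φ(-2) ≈ 0.02.

```python
import torch

gate = NoisyGate(d_model=4, num_expert=2, world_size=1, top_k=1)
clean = torch.tensor([[2.0, 0.0]])
noisy = clean.clone()                  # pretend the sampled noise was zero
noise_std = torch.full_like(clean, 1.0)
top_vals, _ = noisy.topk(2, dim=1)     # m = top_k + 1 columns
print(gate._prob_in_top_k(clean, noisy, noise_std, top_vals))
# ≈ tensor([[0.9772, 0.0228]])
```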
def cv_squared(self, x):
"""The squared coefficient of variation of a sample.
Useful as a loss to encourage a positive distribution to be more uniform.
Epsilons added for numerical stability.
Returns 0 for an empty Tensor.
Args:
x: a `Tensor`.
Returns:
a `Scalar`.
"""
eps = 1e-10
# if only num_expert = 1
if x.shape[0] == 1:
return torch.Tensor([0])
return x.float().var() / (x.float().mean() ** 2 + eps)
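
`cv_squared` is zero for a perfectly uniform load and grows quickly with skew, which is what makes it usable as a balance loss; a quick numeric check (a sketch):

```python
import torch

gate = NoisyGate(d_model=4, num_expert=4, world_size=1)
print(gate.cv_squared(torch.tensor([25.0, 25.0, 25.0, 25.0])))  # tensor(0.)
print(gate.cv_squared(torch.tensor([97.0, 1.0, 1.0, 1.0])))     # ≈ tensor(3.6864)
```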
def forward(self, inp):
clean_logits = inp @ self.w_gate
raw_noise_stddev = inp @ self.w_noise
noise_stddev = (
self.softplus(raw_noise_stddev) + self.noise_epsilon
) * self.training
noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
logits = noisy_logits
# take the top (k + 1) values; the extra one is needed for the noisy load estimate
top_logits, top_indices = logits.topk(
min(self.top_k + 1, self.num_expert), dim=1
)
top_k_logits = top_logits[:, : self.top_k]
top_k_indices = top_indices[:, : self.top_k]
top_k_gates = self.softmax(top_k_logits)
zeros = torch.zeros_like(logits, requires_grad=True)
gates = zeros.scatter(1, top_k_indices, top_k_gates)
if self.top_k < self.num_expert:
load = (
self._prob_in_top_k(
clean_logits, noisy_logits, noise_stddev, top_logits
)
).sum(0)
else:
load = self._gates_to_load(gates)
importance = gates.sum(0)
loss = self.cv_squared(importance) + self.cv_squared(load)
return (
top_k_indices.contiguous().view(-1),
top_k_gates.contiguous().unsqueeze(1),
loss,
)
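
End to end, the gate returns flattened top-k indices, gate scores shaped the way `FMoE` expects, and a scalar balance loss; a hypothetical smoke test:

```python
import torch

gate = NoisyGate(d_model=16, num_expert=4, world_size=1, top_k=2)
idx, score, loss = gate(torch.randn(8, 16))
print(idx.shape)    # torch.Size([16])       (batch * top_k) flat expert ids
print(score.shape)  # torch.Size([8, 1, 2])  softmax over each token's top-2
print(loss)         # 0-dim tensor: cv²(importance) + cv²(load)
```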
@@ -151,6 +151,7 @@ class FMoE(nn.Module):
top_k=2,
gate=NaiveGate,
expert=None,
gate_hook=None,
):
super().__init__()
self.num_expert = num_expert
@@ -171,6 +172,7 @@ class FMoE(nn.Module):
self.experts_fused = False
else:
self.experts_fused = True
self.gate_hook = gate_hook
def expert_fn(self, inp, fwd_expert_count):
r"""
@@ -212,7 +214,9 @@ class FMoE(nn.Module):
if self.mp_size > 1:
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
gate_top_k_idx, gate_score, gate_state_dict = self.gate(inp)
if self.gate_hook:
self.gate_hook(gate_top_k_idx, gate_score, gate_state_dict)
# to: (BxLxtop_k) x d_model
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = _fmoe_general_global_forward(
......
@@ -15,6 +15,8 @@ import torch.nn.functional as F
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .balance import update_balance_profile, reset_balance_profile
from .utils import get_torch_default_comm
class _FakeMegatronMLP(nn.Module):
@@ -73,22 +75,167 @@ def _random_init_weight(self, rng):
self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)
balance_dict = {}
num_layers = 0
def reset_gate_hook():
from megatron import get_args
global balance_dict, num_layers
reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)
def get_balance_profile():
global balance_dict
return balance_dict
def generate_megatron_gate_hook(layer_idx, num_expert_global):
from megatron import get_args
balance_strategy = get_args().balance_strategy
def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
global balance_dict
update_balance_profile(
balance_dict,
gate_top_k_idx,
gate_score_top_k,
gate_context,
layer_idx,
num_expert_global,
balance_strategy,
)
return megatron_gate_hook
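
The hook mechanism itself is framework-agnostic: `FMoE.forward` simply calls `gate_hook(gate_top_k_idx, gate_score, gate_context)` after gating, so any callable with that signature works outside Megatron too. A minimal sketch (`record_hook` and the layer parameters are illustrative; fastmoe's fused kernels assume a CUDA build):

```python
import torch
from fmoe.transformer import FMoETransformerMLP  # assumed import path

expert_counts = []

def record_hook(gate_top_k_idx, gate_score_top_k, gate_context):
    # per-expert load of this forward pass
    expert_counts.append(torch.bincount(gate_top_k_idx, minlength=4))

layer = FMoETransformerMLP(
    num_expert=4, d_model=16, d_hidden=64, gate_hook=record_hook
).cuda()
layer(torch.randn(8, 16, device="cuda"))
print(expert_counts[-1])
```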
def add_fmoe_args(parser):
group = parser.add_argument_group(title="fastmoe")
group.add_argument("--fmoefy", action="store_true")
group.add_argument("--num-experts", type=int, default=None)
group.add_argument("--top-k", type=int, default=2)
group.add_argument("--balance-loss-weight", type=float, default=1)
group.add_argument("--balance-strategy", type=str, default=None)
return parser
def add_balance_log(writer, iteration):
from megatron import is_last_rank
balance_dict_tensor = torch.vstack(
[torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
).detach()
world_group = get_torch_default_comm()
world_size = torch.distributed.get_world_size(group=world_group)
torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
balance_dict_tensor /= world_size
if writer and is_last_rank():
for idx, metric_name in enumerate(balance_dict):
for layer_id, val in enumerate(balance_dict_tensor[idx]):
writer.add_scalar(
f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
)
writer.add_scalar(
f"balance-{metric_name}/all",
balance_dict_tensor[idx].mean().item(),
iteration,
)
reset_gate_hook()
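
`add_balance_log` is meant to be called from the training loop alongside the usual TensorBoard logging; a hypothetical call site (Megatron's `writer`, `args.log_interval`, and `iteration` names assumed):

```python
# every rank must call this (it all-reduces); only the last rank writes
if iteration % args.log_interval == 0:
    add_balance_log(writer, iteration)
```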
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if is_pipeline_last_stage():
loss_name = args.balance_strategy + "_loss"
loss, state_dict = output
# stack (rather than torch.tensor) keeps the autograd graph and device
# of the per-layer balance losses
bal_loss = (
torch.stack(balance_dict[loss_name]).mean()
* args.balance_loss_weight
).float()
# average across world group
world_group = get_torch_default_comm()
world_size = torch.distributed.get_world_size(group=world_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
else:
return output
return forward_step_with_balance_loss
def patch_model_provider(model_provider):
from megatron import get_args
def fmoefied_model_provider():
args = get_args()
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=4 * args.hidden_size // args.top_k,
top_k=args.top_k,
)
return fmoefied_model_provider
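
Together, the two patch functions wrap a standard Megatron-LM pretraining entry point; a sketch of the intended wiring (the `pretrain` import and provider names follow Megatron-LM's conventions and are not part of this diff):

```python
from megatron.training import pretrain  # Megatron-LM entry point

forward_step = patch_forward_step(forward_step)
model_provider = patch_model_provider(model_provider)
pretrain(
    train_valid_test_datasets_provider,
    model_provider,
    forward_step,
    extra_args_provider=add_fmoe_args,  # registers --fmoefy, --balance-strategy, ...
)
```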
class MegatronMLP(FMoETransformerMLP):
r"""
Make the FMoETransformerMLP layer that distributes experts across
communication group `group` to replace the original MLP layer in Megatron.
"""
def __init__(self, args, group, layer_idx):
assert (
args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
== 0
), "Batch size x sequence length should be a multiple of mp size"
if not args.distributed_experts:
world_size = 1
else:
world_size = args.world_size
gate = None
if not args.balance_strategy or args.balance_strategy == "gshard":
from .gates import NaiveGate
gate = NaiveGate
elif args.balance_strategy == "noisy":
from .gates import NoisyGate
gate = NoisyGate
else:
assert False, f"Undefined balance strategy {args.balance_strategy}"
super().__init__(
args.num_experts,
top_k=args.top_k,
@@ -97,6 +244,10 @@ class MegatronMLP(FMoETransformerMLP):
world_size=world_size,
mp_group=group,
expert_dp_comm="none" if args.distributed_experts else "dp",
gate_hook=generate_megatron_gate_hook(
layer_idx, args.num_experts * world_size
),
gate=gate,
)
self.hidden_size = args.hidden_size
if args.distributed_experts:
@@ -170,8 +321,14 @@ def fmoefy(
if distributed_experts is not None:
args.distributed_experts = distributed_experts
for idx, l in enumerate(model.language_model.transformer.layers):
l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)
# initialize gate hook
global num_layers, balance_dict
num_layers = len(model.language_model.transformer.layers)
reset_gate_hook()
return model
......
@@ -48,6 +48,7 @@ class FMoETransformerMLP(FMoE):
gate=NaiveGate,
top_k=2,
expert_dp_comm="none",
gate_hook=None,
):
super().__init__(
num_expert=num_expert,
@@ -56,6 +57,7 @@ class FMoETransformerMLP(FMoE):
top_k=top_k,
world_size=world_size,
mp_group=mp_group,
gate_hook=gate_hook,
)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
......
@@ -40,7 +40,7 @@ class BruteForceMoE(nn.Module):
def forward(self, inp):
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate_top_k_idx, gate_score, _ = self.gate(inp)
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = self.mlp(inp, gate_top_k_idx, gate_score)
if not self.pre_lnorm:
......
@@ -38,7 +38,7 @@ def _perform_forward(
inp.requires_grad = True
inp_raw.requires_grad = True
gate_idx, gate_score, _ = moe.gate(inp_raw)
inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp)
raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
......