Commit 59913cca authored by Rick Ho

resolve loss reduction with customized gates

parent 18a4395c
@@ -10,13 +10,13 @@ import fmoe_cuda
 from .utils import get_torch_default_comm
 
 
-def _ensure_nccl(t, comm):
+def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, t)
 
 
-def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
         local_expert_count = local_expert_count.long()
 
         if world_size > 1:
-            _ensure_nccl(gate, comm)
             global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -52,9 +51,6 @@ def prepare_forward(gate, num_expert, world_size, comm):
     world_size: number of workers that hold different experts.
     comm: the communicator of all workers in the expert-parallel group.
     """
-    if world_size > 1:
-        _ensure_nccl(gate, comm=comm)
-
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
             num_expert, world_size, comm)
     with torch.no_grad():
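With this change, count_by_gate and prepare_forward no longer set up NCCL themselves; callers are expected to run ensure_comm first (the FMoE layer now does this, see the next hunk). A minimal sketch of the single-worker case, where no communicator is needed at all; the expert counts and shapes are illustrative, not taken from the repository:

import torch
from fmoe.functions import count_by_gate

num_expert, world_size = 4, 1          # one worker holding four experts
gate = torch.randint(0, num_expert * world_size, (8,), device="cuda")

# comm now defaults to None; count_by_gate no longer initialises NCCL with
# it, so a single-worker call needs no communicator at all.
pos, local_count, global_count = count_by_gate(gate, num_expert, world_size)

# local_count[i] is the number of tokens this worker routes to expert i;
# with world_size == 1, global_count carries the same information.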
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import math
 
-from .functions import prepare_forward
+from .functions import prepare_forward, ensure_comm
 from .functions import MOEScatter, MOEGather, MOELinear
 from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -212,6 +212,8 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        if self.world_size > 1:
+            ensure_comm(inp, self.moe_group)
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
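Because FMoE.forward now calls ensure_comm on the raw input before slicing and gating, a model with a customized gate no longer has to initialise the communicator itself. A usage sketch, assuming a single-node, one-process-per-GPU NCCL launch; the FMoETransformerMLP constructor arguments are the usual fastmoe ones but should be treated as assumptions here:

import torch
import torch.distributed as dist
from fmoe.transformer import FMoETransformerMLP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# The layer keeps the expert-parallel group in self.moe_group and, after
# this change, runs ensure_comm(inp, self.moe_group) at the top of forward,
# so no manual NCCL setup is needed before the first step. A customized
# gate class could be supplied via the layer's gate argument.
layer = FMoETransformerMLP(
    num_expert=2,
    d_model=64,
    d_hidden=256,
    world_size=dist.get_world_size(),
).cuda()

out = layer(torch.randn(16, 64, device="cuda"))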
@@ -5,6 +5,7 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
+from .distributed import get_moe_group
 
 
 balance_dict = {}
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
                torch.cat(loss_list).mean() * args.balance_loss_weight
            )
 
-            # avarage across world group
-            world_group = get_torch_default_comm()
-            world_size = torch.distributed.get_world_size(group=world_group)
+            # avarage across moe group
+            moe_group = get_moe_group()
+            world_size = torch.distributed.get_world_size(group=moe_group)
             averaged_bal_loss = bal_loss.clone().detach()
-            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
+            torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
             averaged_bal_loss /= world_size
 
             loss += bal_loss
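The reduction above, isolated as a standalone helper for clarity (average_over_group is not a function in the repository): an all_reduce sum over the MoE group followed by division by that group's size gives the mean balance loss across the expert-parallel workers only, rather than across the whole world group.

import torch
import torch.distributed as dist

def average_over_group(value: torch.Tensor, group=None) -> torch.Tensor:
    # Mean of a scalar tensor across one process group (sketch).
    averaged = value.clone().detach()
    dist.all_reduce(averaged, group=group)        # default op is SUM
    averaged /= dist.get_world_size(group=group)
    return averaged

# e.g. averaged_bal_loss = average_over_group(bal_loss, group=get_moe_group())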
@@ -12,6 +12,10 @@ def set_moe_group(moe_group):
     _moe_group = moe_group
 
 
+def get_moe_group():
+    return _moe_group
+
+
 class DistributedDataParallel(DistributedGroupedDataParallel):
     r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
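get_moe_group simply exposes the module-level _moe_group that set_moe_group stores, so the patched forward step can find the expert-parallel group without it being threaded through every call. A sketch of the pairing; the import path and the group construction are assumptions for illustration:

import torch.distributed as dist
from fmoe.megatron.distributed import set_moe_group, get_moe_group

# Register the expert-parallel group once, right after the process groups
# are built (here a single group spanning all ranks, purely for illustration).
dist.init_process_group(backend="nccl")
moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))
set_moe_group(moe_group)

# Later, e.g. when averaging the balance loss, the same group is retrieved.
assert dist.get_world_size(group=get_moe_group()) == dist.get_world_size()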