Commit 59913cca authored by Rick Ho

resolve loss reduction with customized gates

parent 18a4395c
@@ -10,13 +10,13 @@ import fmoe_cuda
 from .utils import get_torch_default_comm
 
 
-def _ensure_nccl(t, comm):
+def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, t)
 
 
-def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -25,7 +25,6 @@ def count_by_gate(gate, num_expert, world_size, comm, require_pos=True):
         local_expert_count = local_expert_count.long()
 
         if world_size > 1:
-            _ensure_nccl(gate, comm)
             global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -52,9 +51,6 @@ def prepare_forward(gate, num_expert, world_size, comm):
         world_size: number of workers that hold different experts.
         comm: the communicator of all workers in the expert-parallel group.
     """
-    if world_size > 1:
-        _ensure_nccl(gate, comm=comm)
-
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
             num_expert, world_size, comm)
     with torch.no_grad():
......
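The hunks above rename `_ensure_nccl` to `ensure_comm` and drop the internal calls from `count_by_gate` and `prepare_forward`, so the caller is now responsible for preparing the communicator before any routing happens. Below is a minimal single-process sketch of that calling pattern; `ensure_comm_sketch` and `count_by_gate_sketch` are hypothetical stand-ins (a Python set replaces `fmoe_cuda`/NCCL), not fastmoe's real implementation.

```python
# Sketch: one-time communicator setup is hoisted out of the helper and into
# the caller, which runs it once before counting samples per expert.
import torch

_initialized_comms = set()  # tracks which communicators have been prepared


def ensure_comm_sketch(tensor, comm):
    """Stand-in for ensure_comm: prepare `comm` once for tensors like `tensor`."""
    key = id(comm) if comm is not None else "default"
    if key not in _initialized_comms:
        # the real function calls fmoe_cuda.ensure_nccl(comm, tensor) here
        _initialized_comms.add(key)


def count_by_gate_sketch(gate, num_expert, world_size):
    """Count how many samples were routed to each (worker, expert) slot."""
    with torch.no_grad():
        local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
        local_expert_count.scatter_add_(0, gate, torch.ones_like(gate))
    return local_expert_count


gate = torch.tensor([0, 1, 1, 3])    # hypothetical routing decisions
ensure_comm_sketch(gate, comm=None)  # caller-side setup, done once up front
print(count_by_gate_sketch(gate, num_expert=4, world_size=1))  # tensor([1, 2, 0, 1])
```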
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import math
-from .functions import prepare_forward
+from .functions import prepare_forward, ensure_comm
 from .functions import MOEScatter, MOEGather, MOELinear
 from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -212,6 +212,8 @@ class FMoE(nn.Module):
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
+        if self.world_size > 1:
+            ensure_comm(inp, self.moe_group)
         if self.mp_size > 1:
             inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
......
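The added lines make `FMoE.forward` call `ensure_comm(inp, self.moe_group)` before any slicing or scattering when experts span more than one worker. A toy, single-process sketch of where that guard sits in a forward pass, with a naive top-1 gate weighting the selected expert's output as the docstring describes; `ToySingleProcessMoE` is illustrative only, not fastmoe's API.

```python
# Sketch: the communicator guard sits at the very top of forward(), before
# routing; a softmax gate picks one expert per token and scales its output.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySingleProcessMoE(nn.Module):
    def __init__(self, d_model=8, num_expert=4, world_size=1, moe_group=None):
        super().__init__()
        self.world_size = world_size
        self.moe_group = moe_group
        self.gate = nn.Linear(d_model, num_expert)
        self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_expert))

    def forward(self, inp):
        if self.world_size > 1:
            # fastmoe would call ensure_comm(inp, self.moe_group) here
            pass
        score = F.softmax(self.gate(inp), dim=-1)   # (tokens, num_expert)
        top_score, top_idx = score.max(dim=-1)      # top-1 routing
        out = torch.stack([self.experts[i](x) for i, x in zip(top_idx.tolist(), inp)])
        return out * top_score.unsqueeze(-1)        # weight by the gate score


print(ToySingleProcessMoE()(torch.randn(5, 8)).shape)  # torch.Size([5, 8])
```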
@@ -5,6 +5,7 @@ import torch
 from fmoe.balance import reset_balance_profile
 from fmoe.balance import update_balance_profile
 from fmoe.utils import get_torch_default_comm
+from .distributed import get_moe_group
 
 balance_dict = {}
@@ -101,11 +102,11 @@ def patch_forward_step(forward_step_func):
             torch.cat(loss_list).mean() * args.balance_loss_weight
         )
-        # avarage across world group
-        world_group = get_torch_default_comm()
-        world_size = torch.distributed.get_world_size(group=world_group)
+        # avarage across moe group
+        moe_group = get_moe_group()
+        world_size = torch.distributed.get_world_size(group=moe_group)
         averaged_bal_loss = bal_loss.clone().detach()
-        torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
+        torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
         averaged_bal_loss /= world_size
         loss += bal_loss
......
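The hunk above is the core of the commit: the balance loss is now averaged over the MoE process group returned by `get_moe_group()` instead of the default world group. A minimal sketch of averaging a scalar loss over a specific process group with `torch.distributed.all_reduce`, assuming a gloo backend and a launch via `torchrun`; the `moe_group` below is a hypothetical subgroup covering all ranks, not Megatron's actual group.

```python
# Sketch: all-reduce a detached copy of the loss over one process group and
# divide by that group's size, mirroring the averaged_bal_loss lines above.
import torch
import torch.distributed as dist


def average_over_group(loss: torch.Tensor, group) -> torch.Tensor:
    """Return the mean of `loss` across the ranks of `group`."""
    group_size = dist.get_world_size(group=group)
    averaged = loss.clone().detach()
    dist.all_reduce(averaged, group=group)
    averaged /= group_size
    return averaged


if __name__ == "__main__":
    dist.init_process_group("gloo")
    # hypothetical MoE group: here simply all ranks; in Megatron it would be
    # the subgroup stored via set_moe_group() / get_moe_group()
    moe_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    bal_loss = torch.tensor(float(dist.get_rank()))
    print(dist.get_rank(), average_over_group(bal_loss, moe_group))
```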
@@ -12,6 +12,10 @@ def set_moe_group(moe_group):
     _moe_group = moe_group
 
 
+def get_moe_group():
+    return _moe_group
+
+
 class DistributedDataParallel(DistributedGroupedDataParallel):
     r"""
     A wrapper that is used to replace the DDP module provided by Megatron, which
......
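The last hunk adds a module-level `get_moe_group()` to pair with the existing `set_moe_group()`, so the balance-loss patch can look up the group at reduction time. A minimal sketch of that setter/getter pattern, storing a placeholder string where the real module keeps a `torch.distributed` process group.

```python
# Sketch: module-level state set once at initialization and read back later.
_moe_group = None


def set_moe_group(moe_group):
    global _moe_group
    _moe_group = moe_group


def get_moe_group():
    return _moe_group


set_moe_group("hypothetical-moe-group")
assert get_moe_group() == "hypothetical-moe-group"
```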