Unverified commit 537679a8 authored by Rick Ho, committed by GitHub

Merge pull request #60 from laekov/remove-comm

Remove unnecessary dependencies on comm
parents 50a9aa94 c8483d42
@@ -24,18 +24,5 @@ def update_balance_profile(
     num_expert,
     balance_strategy,
 ):
-    c_e = torch.scatter_add(
-        torch.zeros(num_expert, device=gate_top_k_idx.device),
-        0,
-        gate_top_k_idx,
-        torch.ones_like(gate_top_k_idx, dtype=torch.float),
-    )
-    for key in metrics:
-        balance_dict[key][layer_idx] = metrics[key](c_e)
-    S = gate_top_k_idx.shape[0]
-    if balance_strategy == "gshard":
-        gate_score_all = gate_context
-        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
-        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
-    elif balance_strategy == "noisy":
-        balance_dict["noisy_loss"][layer_idx] = gate_context
+    # Fill in this function to conduct balance related jobs
+    pass
@@ -16,7 +16,7 @@ def ensure_comm(t, comm):
     fmoe_cuda.ensure_nccl(comm, t)


-def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -40,7 +40,7 @@ def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     return pos, local_expert_count, global_expert_count


-def prepare_forward(gate, num_expert, world_size, comm):
+def prepare_forward(gate, num_expert, world_size):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -52,7 +52,7 @@ def prepare_forward(gate, num_expert, world_size, comm):
         comm: the communicator of all workers in the expert-parallel group.
     """
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size, comm)
+            num_expert, world_size)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
...
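The counting performed by `count_by_gate` can be illustrated without the `fmoe_cuda` extension. This is a reference-only sketch of the comm-free signature introduced here; the real function also builds the `pos` permutation and exchanges counts across workers through the NCCL communicator registered by `ensure_comm`, both omitted below. The helper name and the single-worker shortcut are assumptions for illustration.

```python
import torch

def count_by_gate_reference(gate, num_expert, world_size):
    # Count how many samples are routed to each (worker, expert) slot.
    with torch.no_grad():
        local_expert_count = torch.bincount(
            gate.view(-1), minlength=num_expert * world_size
        ).to(torch.int32)
        # With a single worker there is nothing to exchange, so the global
        # count equals the local one.
        global_expert_count = local_expert_count.clone()
    return local_expert_count, global_expert_count

gate = torch.randint(4, (16,))            # 16 samples, 4 local experts
local, global_ = count_by_gate_reference(gate, num_expert=4, world_size=1)
```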
@@ -74,8 +74,7 @@ def mark_module_parallel_comm(module, comm):
         setattr(p, "dp_comm", comm)


-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
-        comm=None):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -93,7 +92,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size, comm)
+    ) = prepare_forward(gate, num_expert, world_size)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -219,6 +218,9 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)

+        if self.gate_hook is not None:
+            self.gate_hook(gate_top_k_idx, gate_score, None)
+
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
             mask = self.mask.view(-1)
@@ -227,9 +229,8 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]

         fwd = _fmoe_general_global_forward(
-            inp,
-            gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size, self.moe_group
+            inp, gate_top_k_idx,
+            self.expert_fn, self.num_expert, self.world_size
         )
         # recover deleted tensors
...
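The new `gate_hook` call added in the `FMoE.forward` hunk receives the selected expert indices, the gate scores, and (currently) `None`. A minimal hook with that signature might simply record the routing decisions, as sketched below; how the hook gets attached to the layer is not shown in this diff, so the attribute assignment in the final comment is an assumption for illustration.

```python
import torch

recorded_routing = []

def record_gate_hook(gate_top_k_idx, gate_score, gate_context):
    # Matches the call self.gate_hook(gate_top_k_idx, gate_score, None):
    # keep a CPU copy of the per-token expert assignments.
    recorded_routing.append(gate_top_k_idx.detach().cpu())

# Standalone demonstration with fake gate output (2 tokens, top-2 routing).
record_gate_hook(torch.tensor([[0, 3], [1, 2]]), torch.rand(2, 2), None)

# Assumed usage (not part of this diff): moe_layer.gate_hook = record_gate_hook
```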