Unverified commit 537679a8, authored by Rick Ho and committed by GitHub

Merge pull request #60 from laekov/remove-comm

Remove unnecessary dependencies on comm
parents 50a9aa94 c8483d42
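
With this change, the NCCL communicator is no longer threaded through `prepare_forward` and `_fmoe_general_global_forward`; it is bound once through `ensure_comm`, and the forward helpers only need `num_expert` and `world_size`. A minimal sketch of the resulting call pattern, assuming the helpers live in `fmoe.functions` and that a `torch.distributed` process group plays the role of the communicator (the wrapper name `route_tokens` is illustrative, not part of this patch):

```python
# Hedged usage sketch; module path and wrapper are assumptions.
from fmoe.functions import ensure_comm, prepare_forward

def route_tokens(inp, gate_top_k_idx, num_expert, world_size, moe_group=None):
    if world_size > 1:
        # Bind the NCCL communicator once; later helpers no longer take `comm`.
        ensure_comm(inp, moe_group)
    return prepare_forward(gate_top_k_idx, num_expert, world_size)
```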
@@ -24,18 +24,5 @@ def update_balance_profile(
     num_expert,
     balance_strategy,
 ):
-    c_e = torch.scatter_add(
-        torch.zeros(num_expert, device=gate_top_k_idx.device),
-        0,
-        gate_top_k_idx,
-        torch.ones_like(gate_top_k_idx, dtype=torch.float),
-    )
-    for key in metrics:
-        balance_dict[key][layer_idx] = metrics[key](c_e)
-    S = gate_top_k_idx.shape[0]
-    if balance_strategy == "gshard":
-        gate_score_all = gate_context
-        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
-        balance_dict["gshard_loss"][layer_idx] = torch.sum(c_e * m_e) / num_expert / S
-    elif balance_strategy == "noisy":
-        balance_dict["noisy_loss"][layer_idx] = gate_context
+    # Fill in this function to conduct balance related jobs
+    pass
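
For reference, the body that this hunk stubs out computed the per-expert assignment count `c_e` with a scatter-add and, under the `"gshard"` strategy, the auxiliary loss `sum(c_e * m_e) / (num_expert * S)`. A self-contained sketch of that computation, with names kept from the removed code; flattening `gate_top_k_idx` (assumed to be an integer index tensor) is added here so the sketch also covers top-k > 1 routing:

```python
import torch
import torch.nn.functional as F

def gshard_balance_loss(gate_top_k_idx, gate_score_all, num_expert):
    # S: number of tokens in the batch (first dim of the routing indices)
    S = gate_top_k_idx.shape[0]
    idx = gate_top_k_idx.view(-1)                 # flatten for top-k > 1
    # c_e: how many assignments each expert received
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=idx.device),
        0,
        idx,
        torch.ones_like(idx, dtype=torch.float),
    )
    # m_e: mean softmax gate score per expert
    m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
    return torch.sum(c_e * m_e) / num_expert / S

# toy check: 4 tokens, 3 experts, top-1 routing
idx = torch.tensor([[0], [1], [1], [2]])
loss = gshard_balance_loss(idx, torch.randn(4, 3), num_expert=3)
```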
@@ -16,7 +16,7 @@ def ensure_comm(t, comm):
     fmoe_cuda.ensure_nccl(comm, t)
-def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
+def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
             num_expert * world_size, device=gate.device, dtype=torch.int32
@@ -40,7 +40,7 @@ def count_by_gate(gate, num_expert, world_size, comm=None, require_pos=True):
     return pos, local_expert_count, global_expert_count
-def prepare_forward(gate, num_expert, world_size, comm):
+def prepare_forward(gate, num_expert, world_size):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -52,7 +52,7 @@ def prepare_forward(gate, num_expert, world_size, comm):
         comm: the communicator of all workers in the expert-parallel group.
     """
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
-            num_expert, world_size, comm)
+            num_expert, world_size)
     with torch.no_grad():
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
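
As the hunk above shows, `prepare_forward` reduces `global_expert_count` over the worker dimension to obtain `fwd_expert_count`, the number of tokens each local expert will process. A toy illustration of that reshape-and-sum; the worker-major layout of `global_expert_count` and the `fwd_batch_size` line are assumptions based on the variable names, not code shown in this hunk:

```python
import torch

num_expert, world_size = 2, 2
# counts from worker 0 followed by counts from worker 1 (assumed layout)
global_expert_count = torch.tensor([3, 1, 2, 4], dtype=torch.int32)

fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
print(fwd_expert_count)        # tensor([5, 5])

# presumably the total number of tokens this worker runs through its experts
fwd_batch_size = int(fwd_expert_count.sum())
print(fwd_batch_size)          # 10
```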
@@ -74,8 +74,7 @@ def mark_module_parallel_comm(module, comm):
         setattr(p, "dp_comm", comm)
-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
-        comm=None):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -93,7 +92,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size,
         global_expert_count,
         fwd_expert_count,
         fwd_batch_size,
-    ) = prepare_forward(gate, num_expert, world_size, comm)
+    ) = prepare_forward(gate, num_expert, world_size)
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
@@ -219,6 +218,9 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
+        if self.gate_hook is not None:
+            self.gate_hook(gate_top_k_idx, gate_score, None)
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
             mask = self.mask.view(-1)
@@ -227,9 +229,8 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
         fwd = _fmoe_general_global_forward(
-            inp,
-            gate_top_k_idx,
-            self.expert_fn, self.num_expert, self.world_size, self.moe_group
+            inp, gate_top_k_idx,
+            self.expert_fn, self.num_expert, self.world_size
         )
         # recover deleted tensors
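
The forward pass above drops tokens with a nonzero mask value before `_fmoe_general_global_forward` and restores them afterwards (the `# recover deleted tensors` branch falls outside this hunk). A hedged sketch of that delete-and-recover pattern, assuming `mask_dict` maps a mask value to a fixed output vector; the helper name `masked_route` and the zero-initialised output buffer are illustrative, not the module's actual code:

```python
import torch

def masked_route(inp, gate_top_k_idx, mask, mask_dict, run_forward):
    # keep only tokens whose mask value is 0, as in FMoE.forward above;
    # run_forward stands in for _fmoe_general_global_forward
    mask = mask.view(-1)
    kept_out = run_forward(inp[mask == 0, :], gate_top_k_idx[mask == 0, :])
    # recover deleted tensors: routed rows come from the experts,
    # masked rows come from the per-mask-value vectors in mask_dict
    out = torch.zeros(mask.shape[0], kept_out.shape[-1],
                      device=kept_out.device, dtype=kept_out.dtype)
    out[mask == 0] = kept_out
    for value, filler in mask_dict.items():
        out[mask == value] = filler
    return out
```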