Commit a468db2b authored by Rick Ho

fix bugs to run on multiple gpus

parent 38b334cc
@@ -10,6 +10,12 @@ import fmoe_cuda
 from .utils import get_torch_default_comm
 
 
+def _ensure_nccl(t, comm=None):
+    if comm is None:
+        comm = get_torch_default_comm()
+    fmoe_cuda.ensure_nccl(comm, t)
+
+
 def count_by_gate(gate, num_expert, world_size):
     # TODO: support -1 in gate, which means ignore this input
     with torch.no_grad():
@@ -21,6 +27,7 @@ def count_by_gate(gate, num_expert, world_size):
         local_expert_count.index_put_((gate_idx.long(),), gate_count)
 
         if world_size > 1:
+            _ensure_nccl(gate)
             (global_expert_count,) = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
@@ -29,7 +36,6 @@ def count_by_gate(gate, num_expert, world_size):
     return pos, local_expert_count, global_expert_count
 
 
-
 def prepare_forward(gate, num_expert, world_size, comm=None):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -42,9 +48,7 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if world_size > 1:
-        if comm is None:
-            comm = get_torch_default_comm()
-        fmoe_cuda.ensure_nccl(comm, gate)
+        _ensure_nccl(gate, comm=comm)
     pos, local_expert_count, global_expert_count = count_by_gate(gate,
             num_expert, world_size)
...
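The helper introduced above deduplicates the NCCL bootstrap: prepare_forward forwards its (possibly None) comm, while count_by_gate, which previously had no way to initialize NCCL when called on its own, now calls the helper with no communicator and picks up the default one. A minimal sketch of that pattern, assuming the fmoe package layout shown in the hunks (fmoe_cuda bindings plus fmoe.utils.get_torch_default_comm):

    # Sketch only; fmoe_cuda.ensure_nccl and get_torch_default_comm are
    # taken from the fmoe package as shown in the hunks above.
    import fmoe_cuda
    from fmoe.utils import get_torch_default_comm

    def _ensure_nccl(t, comm=None):
        # Resolve the default process-group communicator lazily, so callers
        # without an explicit comm (count_by_gate) still initialize NCCL
        # before fmoe_cuda.expert_exchange runs.
        if comm is None:
            comm = get_torch_default_comm()
        fmoe_cuda.ensure_nccl(comm, t)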
@@ -21,7 +21,7 @@ class GShardGate(NaiveGate):
         top_k = topk_idx.shape[0] // gate_score.shape[0]
         top1_idx = topk_idx.view((-1, top_k))[:, 0]
         c_e = torch.scatter_add(
-            torch.zeros(self.num_expert, device=top1_idx.device),
+            torch.zeros(self.tot_expert, device=top1_idx.device),
             0,
             top1_idx,
             torch.ones_like(top1_idx, dtype=torch.float),
...
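The one-word change above (num_expert to tot_expert) is the core multi-GPU fix: the gate emits global expert indices in [0, num_expert * world_size), so a scatter_add target sized for the local num_expert goes out of bounds as soon as world_size > 1. A self-contained illustration with made-up sizes (tot_expert mirrors the num_expert * world_size convention suggested by the diff):

    # Made-up sizes; tot_expert = num_expert * world_size as in the gate.
    import torch

    num_expert, world_size = 4, 2
    tot_expert = num_expert * world_size   # 8 experts across all workers

    # Top-1 choices may point at any global expert, so indices reach 7.
    top1_idx = torch.tensor([0, 3, 5, 7])

    # Sized num_expert (4) this scatter_add would index out of bounds;
    # sized tot_expert (8) it counts tokens per global expert.
    c_e = torch.scatter_add(
        torch.zeros(tot_expert),
        0,
        top1_idx,
        torch.ones_like(top1_idx, dtype=torch.float),
    )
    print(c_e)  # tensor([1., 0., 0., 1., 0., 1., 0., 1.])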
@@ -14,7 +14,7 @@ def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
     new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
             num_expert, world_size)
     if world_size > 1:
-        new_lec = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
+        new_lec, = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
     else:
         new_lec = new_gec
...
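The added trailing comma is a tuple-unpacking fix rather than a typo: as with the (global_expert_count,) = fmoe_cuda.expert_exchange(...) call earlier in this commit, the native binding appears to return a one-element tuple, so new_lec, = ... binds the tensor itself instead of the wrapping tuple. A pure-Python sketch of the difference, with the binding stubbed out:

    # Stub standing in for the native binding, which returns a 1-tuple.
    def expert_exchange_stub(gec):
        return (gec,)

    gec = [1, 2, 3]
    wrapped = expert_exchange_stub(gec)   # ([1, 2, 3],)  -- still a tuple
    tensor, = expert_exchange_stub(gec)   # [1, 2, 3]     -- unpacked value
    assert tensor == gec and wrapped != gec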
@@ -28,7 +28,7 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
     topk_idx, topk_val = gate(x)
-    counts = [0 for _ in range(n_expert)]
+    counts = [0 for _ in range(n_expert * dist.get_world_size())]
     for v in topk_idx.cpu().view(-1).numpy():
         if v != -1:
             counts[v] += 1
@@ -47,7 +47,7 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
     topk_idx, topk_val = gate(x)
-    counts = [0 for _ in range(n_expert)]
+    counts = [0 for _ in range(n_expert * dist.get_world_size())]
     for v in topk_idx.cpu().view(-1).numpy():
         if v != -1:
             counts[v] += 1
...
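Both test fixes follow the same rule as the GShardGate change: topk_idx holds global expert indices, so the histogram needs n_expert slots per worker across the whole group. A sketch of the corrected counting loop, assuming torch.distributed is already initialized (e.g. launched via torchrun):

    # Sketch of the tests' counting loop; assumes dist.init_process_group
    # has already run (for example under torchrun).
    import torch
    import torch.distributed as dist

    def count_per_global_expert(topk_idx, n_expert):
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        counts = [0 for _ in range(n_expert * world_size)]
        for v in topk_idx.cpu().view(-1).numpy():
            if v != -1:            # -1 marks tokens dropped by capacity
                counts[v] += 1
        return counts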