Unverified commit baae8fb9, authored by Rick Ho and committed by GitHub

Merge pull request #31 from laekov/gate

Reconstruct gate and add gshard / switch
parents 3c42c892 8d14dd29
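The pull request reworks the gate interface: a gate's forward now returns just (gate_idx, gate_score) instead of a three-tuple ending in None, with gate_idx of shape (batch, top_k) and gate_score of shape (batch, 1, top_k), as the ConstantGate change further down shows. As a rough orientation, a minimal sketch of a gate that follows this convention is given below; the class NaiveRandomGate and everything inside it are illustrative assumptions, not code from this pull request, and the GShard / Switch gates the merge actually adds are more involved.

import torch

class NaiveRandomGate(torch.nn.Module):
    """Illustrative only: mimics the gate call convention used by the tests
    in this change set, i.e. forward(inp) -> (gate_idx, gate_score)."""

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        self.num_expert = num_expert
        self.world_size = world_size
        self.top_k = top_k

    def forward(self, inp):
        batch = inp.shape[0]
        # (batch, top_k) int64 expert indices, chosen uniformly at random
        gate_idx = torch.randint(self.num_expert * self.world_size,
                                 (batch, self.top_k), device=inp.device)
        # (batch, 1, top_k) combination weights, uniform over the top_k picks
        gate_score = torch.full((batch, 1, self.top_k), 1.0 / self.top_k,
                                device=inp.device)
        return gate_idx, gate_score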
import sys
from collections import OrderedDict
from typing import List, Type, Union
import pytest
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from fmoe.functions import MOEGather, MOEScatter, count_by_gate
from test_numerical import _assert_numerical

@pytest.mark.parametrize("n_expert", [1, 4, 8])
@pytest.mark.parametrize("topk", [1, 2])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("d_model", [6])
@pytest.mark.parametrize("world_size", [1])
def test_scatter(n_expert, topk, batch_size, d_model, world_size):
    # Random gate indices in [-1, n_expert); -1 presumably marks tokens that
    # are not routed to any expert.
    gate_idx = torch.randint(n_expert + 1, (batch_size, topk)) - 1
    gate_idx = gate_idx.long().cuda()
    pos, lec, gec = count_by_gate(gate_idx, n_expert, world_size)
    fbs = int(gec.sum().item())

    inp = torch.rand(batch_size, d_model).cuda()
    inp.requires_grad = True
    out = MOEScatter.apply(inp, pos % batch_size, lec, gec, fbs, world_size)
    out.sum().backward()

    # Build the reference output by copying the selected input rows one at a time.
    inp_raw = inp.data.clone()
    out_raw = torch.empty(pos.shape[0], d_model,
                          device=inp.device, dtype=inp.dtype)
    # out_raw.sum().backward()
    for i, f in enumerate(pos.cpu()):
        out_raw[i] = inp[f % batch_size]

    _assert_numerical(['out'], [out], [out_raw], 0)
    # TODO: check grad


if __name__ == '__main__':
    test_scatter(4, 2, 8, 6, 1)
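The row-by-row loop above is the reference that MOEScatter is checked against. Stated outside the test, the same check is just advanced indexing on the first dimension; the snippet below is a standalone illustration of that equivalence and is not part of the repository.

import torch

# Standalone sketch: gathering rows of a (batch, d_model) tensor at positions
# given by a flat index tensor equals the explicit Python loop used in the test.
batch_size, d_model, topk = 4, 6, 2
inp = torch.rand(batch_size, d_model)
pos = torch.randint(batch_size * topk, (batch_size * topk,))
out_vectorized = inp[pos % batch_size]
out_loop = torch.stack([inp[int(f) % batch_size] for f in pos])
assert torch.equal(out_vectorized, out_loop)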
@@ -24,7 +24,7 @@ def _perform_forward(
     inp = torch.rand(batch_size, d_model).type(data_type).cuda()
-    if mp_group:
+    if mp_group is not None:
         group_sender = rank // mp_group.size() * mp_group.size()
         torch.distributed.broadcast(inp, group_sender, group=mp_group)
         torch.distributed.broadcast(
@@ -38,10 +38,9 @@ def _perform_forward(
     inp.requires_grad = True
     inp_raw.requires_grad = True
-    gate_idx, gate_score, _ = moe.gate(inp_raw)
-    inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
+    gate_idx, gate_score = moe.gate(inp_raw)
     moe_out = moe(inp)
-    raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
+    raw_out = moe_raw(inp_raw, gate_idx, gate_score)
     raw_out.mean().backward()
     moe_out.mean().backward()
@@ -51,7 +50,7 @@ def _perform_forward(
 def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
-        err = (mo - ro).abs().sum()
+        err = (mo - ro).abs().max()
         print("Rank {} {} abs err {}".format(rank, name, err))
         if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
...
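The middle hunk shows the new gate contract on the consumer side: moe.gate returns only indices and scores, and the brute-force reference now takes the raw input instead of a repeat_interleave-expanded copy, so the per-token expansion happens inside the reference itself. A hypothetical per-token reference forward that is consistent with this calling convention is sketched below; the function name, the experts list, and the loop structure are assumptions for illustration, and the repository's own reference module may differ.

import torch

def naive_moe_forward(experts, inp, gate_idx, gate_score):
    # Hypothetical reference, assuming each expert maps (1, d_model) -> (1, d_model).
    # inp:        (batch, d_model)
    # gate_idx:   (batch, top_k) int64 expert indices
    # gate_score: (batch, 1, top_k) combination weights
    batch, top_k = gate_idx.shape
    out = torch.zeros_like(inp)
    for b in range(batch):
        for k in range(top_k):
            e = int(gate_idx[b, k])
            expert_out = experts[e](inp[b].unsqueeze(0)).squeeze(0)
            out[b] += gate_score[b, 0, k] * expert_out
    return out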
+import os
+import sys
+import json
 import torch
 
 from fmoe.layers import _fmoe_general_global_forward
 from fmoe import FMoETransformerMLP
+from test_ddp import _run_distributed
 
 
 class ConstantGate(torch.nn.Module):
     def __init__(self, d_model, num_expert, world_size, top_k=1):
@@ -9,13 +14,24 @@ class ConstantGate(torch.nn.Module):
         self.top_k = top_k
 
     def forward(self, inp):
-        idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
+        idx = torch.zeros((inp.shape[0], self.top_k), dtype=torch.int64,
                           device=inp.device)
         score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
-        return idx, score, None
+        return idx, score
 
 
 def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_fwd',
+            1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__
+    )
+
+
+def _test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     gate = torch.zeros(batch_size, dtype=torch.int64).cuda()
     x = _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert,
@@ -23,6 +39,17 @@ def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
 
 def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_transformer',
+            1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__
+    )
+
+
+def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
                                gate=ConstantGate).cuda()
@@ -30,9 +57,16 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
 
 if __name__ == '__main__':
-    torch.distributed.init_process_group(backend="nccl")
-    torch.cuda.set_device(torch.distributed.get_rank())
-    # test_zero_fwd(world_size=torch.distributed.get_world_size())
-    test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
-                          world_size=torch.distributed.get_world_size())
-    print('done')
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
+        torch.distributed.init_process_group(backend="nccl")
+        args['world_size'] = torch.distributed.get_world_size()
+        locals()[sys.argv[1]](**args)
+    else:
+        # test_zero_fwd(world_size=torch.distributed.get_world_size())
+        test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
+                              world_size=1)
+    print('done')
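The rewritten __main__ block turns the test module into a small self-dispatching worker: when relaunched (for example by the _run_distributed helper) with a function name in argv[1] and a JSON dict of keyword arguments in argv[2], it sets up NCCL from the OpenMPI rank variables and then calls that function. At module level locals() and globals() are the same mapping, which is why the locals()[sys.argv[1]] lookup resolves the test functions. A minimal standalone sketch of the dispatch idiom, without any distributed setup, might look like this (the file name and _hello function are made up for illustration):

import sys
import json

def _hello(name, times=1):
    # Placeholder worker; in the real tests this would be _test_zero_fwd etc.
    for _ in range(times):
        print(f"hello, {name}")

if __name__ == '__main__':
    # e.g.  python dispatch_sketch.py _hello '{"name": "fmoe", "times": 2}'
    if len(sys.argv) >= 3:
        kwargs = json.loads(sys.argv[2])
        globals()[sys.argv[1]](**kwargs)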