"vscode:/vscode.git/clone" did not exist on "50369291d2d491fd496ad7fd9feb3cbbe14ea021"
Commit 0f091a1d authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fastmoe project

parents
Pipeline #263 failed with stages
in 0 seconds
r"""
Utility in Megatron
"""
def add_fmoe_args(parser):
group = parser.add_argument_group(title="fastmoe")
group.add_argument("--fmoefy", action="store_true")
group.add_argument("--num-experts", type=int, default=None)
group.add_argument("--top-k", type=int, default=2)
group.add_argument("--balance-loss-weight", type=float, default=1)
group.add_argument("--balance-strategy", type=str, default=None)
group.add_argument("--hidden-hidden-size", type=int, default=None)
return parser
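# A minimal standalone sketch (an assumption for illustration, not Megatron's
# real argument pipeline) showing how the fastmoe flags above parse with plain
# argparse; the flag values below are arbitrary.
if __name__ == "__main__":
    import argparse

    demo_parser = argparse.ArgumentParser(description="fastmoe argument demo")
    demo_parser = add_fmoe_args(demo_parser)
    demo_args = demo_parser.parse_args(["--fmoefy", "--num-experts", "8", "--top-k", "2"])
    print(demo_args.fmoefy, demo_args.num_experts, demo_args.top_k)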
r"""
Adaption to act as the MLP layer using an MoE MLP layer in transformer.
"""
import torch
import torch.nn as nn
from .layers import FMoE
from .linear import FMoELinear
class _Expert(nn.Module):
r"""
An expert using 2 FMoELinear modules to speed up the computation of experts
within one worker.
"""
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.activation = activation
def forward(self, inp, fwd_expert_count):
r"""
First expand the input to 4h (the hidden size is configurable, but is called
h4 for convenience). Then apply the activation. Finally shrink back to h.
"""
x = self.htoh4(inp, fwd_expert_count)
x = self.activation(x)
x = self.h4toh(x, fwd_expert_count)
return x
class FMoETransformerMLP(FMoE):
r"""
A complete MoE MLP module in a Transformer block.
* `activation` is the activation function used in the MLP of each expert.
* `d_hidden` is the hidden dimension of each expert's MLP.
"""
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
activation=torch.nn.GELU(),
expert_dp_comm="none",
expert_rank=0,
**kwargs
):
super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=expert_rank
)
self.mark_parallel_comm(expert_dp_comm)
def forward(self, inp: torch.Tensor):
r"""
This module wraps the FMoE forward with a reshape so that inputs with any
leading dimensions are flattened to (-1, d_model) and restored afterwards.
"""
original_shape = inp.shape
inp = inp.reshape(-1, self.d_model)
output = super().forward(inp)
return output.reshape(original_shape)
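# A small usage sketch (assuming a CUDA device and the compiled fmoe_cuda
# extension are available; the sizes are illustrative): the layer keeps the
# outer shape of the input and only mixes along the last, d_model-sized
# dimension, so it can stand in for a dense transformer MLP.
if __name__ == "__main__":
    mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256, top_k=2).cuda()
    tokens = torch.rand(2, 16, 64).cuda()  # (batch, sequence, d_model)
    out = mlp(tokens)
    assert out.shape == tokens.shape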
r"""
Utils to play with PyTorch.
"""
import torch.distributed as dist
# pylint: disable=broad-except
# pylint: disable=protected-access
def get_torch_default_comm():
r"""
The NCCL communicator is needed so that FastMoE can perform its customized
communication operators in the C code. However, it is not a publicly
exposed variable. Therefore, a hacked subclass of `ProcessGroupNCCL` in
FastMoE's C code takes the default process group and digs the communicator
out of that object. As PyTorch's private interface varies over time, several
hacks are tried one by one to stay compatible with different PyTorch
versions.
"""
try:
comm = dist.distributed_c10d._get_default_group()
return comm
except Exception as _:
pass
try:
comm = dist.distributed_c10d._default_pg
if comm is not None:
return comm
except Exception as _:
pass
raise RuntimeError("Unsupported PyTorch version")
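# A hedged sketch of the intended call pattern: once the default process group
# has been initialized (the rendezvous environment variables such as
# MASTER_ADDR and MASTER_PORT are assumed to be set by the launcher), the
# returned group object is what the C extension unpacks to reach the NCCL
# communicator.
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    default_group = get_torch_default_comm()
    print(type(default_group))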
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
import torch
cxx_flags = []
ext_libs = []
authors = [
'Jiaao He',
'Jiezhong Qiu',
'Aohan Zeng',
'Tiago Antunes',
'Jinjun Peng',
'Qin Li',
]
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
if os.environ.get('USE_NCCL', '1') == '1':
cxx_flags.append('-DFMOE_USE_NCCL')
cxx_flags.append('-DUSE_C10D_NCCL')
if is_rocm_pytorch:
ext_libs.append('rccl')
else:
ext_libs.append('nccl')
if is_rocm_pytorch:
define_macros=[('FMOE_USE_HIP', None)]
else:
define_macros=[]
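# Illustrative build invocations (assumptions about the environment, not taken
# from the original sources); USE_NCCL is the switch read above:
#   USE_NCCL=1 python setup.py install   # multi-GPU build with NCCL (or RCCL on ROCm)
#   USE_NCCL=0 pip install -v -e .       # single-process build without NCCL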
if __name__ == '__main__':
setuptools.setup(
name='fastmoe',
version='0.3.0',
description='An efficient Mixture-of-Experts system for PyTorch',
author=', '.join(authors),
author_email='hja20@mails.tsinghua.edu.cn',
license='Apache-2',
url='https://github.com/laekov/fastmoe',
packages=['fmoe', 'fmoe.megatron', 'fmoe.gates'],
ext_modules=[
CUDAExtension(
name='fmoe_cuda',
sources=[
'cuda/stream_manager.cpp',
'cuda/local_exchange.cu',
'cuda/balancing.cu',
'cuda/global_exchange.cpp',
'cuda/parallel_linear.cu',
'cuda/fmoe_cuda.cpp',
],
define_macros=define_macros,
extra_compile_args={
'cxx': cxx_flags,
'nvcc': cxx_flags
},
libraries=ext_libs
)
],
cmdclass={
'build_ext': BuildExtension
})
import torch
import torch.nn as nn
from fmoe import FMoETransformerMLP
from fmoe.gates import NaiveGate
from moe import BruteForceMoELinear
import time
import sys
import os
rank = None
world_size = None
dev_name_default = "cuda:0"
class BruteForceMoE(nn.Module):
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate,
top_k=1,
pre_lnorm=False,
):
assert world_size == 1, "Distributed brute force is not supported"
super().__init__()
self.mlp = BruteForceMoELinear(
activation, num_expert, d_model, d_hidden, 1, top_k
)
self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k)
self.pre_lnorm = pre_lnorm
self.layer_norm = nn.LayerNorm(d_model)
self.d_model = d_model
def forward(self, inp):
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate_top_k_idx, gate_score = self.gate(inp)
# BruteForceMoELinear already expands the input by top_k internally, so the
# raw (un-repeated) batch is passed straight through.
x = self.mlp(inp, gate_top_k_idx, gate_score)
if not self.pre_lnorm:
x = self.layer_norm(x)
return x
def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if rank == 0:
print(
"Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format(
MOELayer.__name__,
batch_size,
in_feat,
hidden_feat,
world_size,
num_expert,
top_k,
)
)
if world_size > 1:
dev_name = "cuda"
else:
dev_name = dev_name_default
inp = torch.rand(batch_size, in_feat).cuda(dev_name)
inp.requires_grad = True
moe = MOELayer(
num_expert=num_expert,
d_model=in_feat,
d_hidden=hidden_feat,
world_size=world_size,
top_k=top_k,
).cuda(dev_name)
moe.train()
# warm up
for _ in range(4):
_ = moe(inp)
n_runs = 16
tott = 0.0
backt = 0.0
maxt = 0.0
sqtot = 0.0
for i in range(n_runs):
ts = time.time()
o = moe(inp)
te = time.time()
loss = o.sum()
bts = time.time()
loss.backward()
bte = time.time()
tott += te - ts
sqtot += (te - ts) ** 2
maxt = max(maxt, te - ts)
backt += bte - bts
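# FLOP accounting (a rough estimate added for clarity, not from the original
# sources): each forward pass performs two d_model x d_hidden GEMMs per
# selected expert copy (in_feat * hidden_feat * batch_size * top_k * 2
# multiply-accumulates, counted as 2 FLOPs each via the 2e-9 factor) plus the
# gate projection of roughly batch_size * in_feat * num_expert, averaged over
# the total forward time of all n_runs iterations.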
gflops = (
2e-9
* n_runs
* (
in_feat * hidden_feat * batch_size * top_k * 2
+ batch_size * in_feat * num_expert
)
/ tott
)
print(
"Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format(
tott * 1e3 / n_runs,
maxt * 1e3,
max(sqtot / n_runs - (tott / n_runs) ** 2, 0.0) ** 0.5 * 1e3,
backt * 1e3 / n_runs,
gflops,
)
)
if __name__ == "__main__":
if int(os.environ["WORLD_SIZE"]) > 1:
torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
rank = 0
world_size = 1
batch_size = int(os.environ.get("BATCH_SIZE", "4096"))
d_model = int(os.environ.get("D_MODEL", "1024"))
d_hidden = int(os.environ.get("D_HIDDEN", "4096"))
num_expert = int(os.environ.get("NUM_EXPERT", "64"))
top_k = int(os.environ.get("TOP_K", "2"))
benchmark_mlp(FMoETransformerMLP, batch_size, d_model, d_hidden, num_expert, top_k)
if world_size == 1:
benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k)
import math
from torch import nn
import torch
class BruteForceMoELinear(nn.Module):
def __init__(
self,
activation,
num_expert=32,
d_model=1024,
d_hidden=2048,
world_size=1,
top_k=2,
):
super(BruteForceMoELinear, self).__init__()
self.num_expert = num_expert
self.d_model = d_model
self.activation = activation
self.weight_htoh4 = nn.Parameter(
torch.Tensor(num_expert * world_size, d_hidden, d_model)
)
self.bias_htoh4 = nn.Parameter(torch.Tensor(num_expert * world_size, d_hidden))
self.weight_h4toh = nn.Parameter(
torch.Tensor(num_expert * world_size, d_model, d_hidden)
)
self.bias_h4toh = nn.Parameter(torch.Tensor(num_expert * world_size, d_model))
self.top_k = top_k
def forward(self, inp, gate_idx, gate_score):
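# Reference path (descriptive comments added): every token is duplicated top_k
# times, each copy is pushed through the dense weights of its selected expert
# in a plain per-expert loop, and the top_k expert outputs are finally mixed
# back into one token with the gate scores via a batched matrix multiply.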
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0)
o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
for i in range(self.weight_htoh4.shape[0]):
idx = gate_long == i
x = inp[idx]
x = x @ self.weight_htoh4[i].t()
x = x + self.bias_htoh4[i]
x = self.activation(x)
x = x @ self.weight_h4toh[i].t()
x = x + self.bias_h4toh[i]
o[idx] = x
gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model
)
return x
class BruteForceMoE(nn.Module):
def __init__(self, expert, num_expert=32, d_model=1024, world_size=1, top_k=2):
super(BruteForceMoE, self).__init__()
self.num_expert = num_expert
self.d_model = d_model
self.top_k = top_k
if type(expert) is list:
self.experts = [e(d_model) for e in expert]
self.num_expert = num_expert = len(expert)
else:
self.experts = [expert(d_model) for _ in range(num_expert * world_size)]
def forward(self, inp, gate_idx, gate_score):
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.d_model))
for i in range(batch_size):
x[i] = self.experts[gate_long[i]](inp[i])
gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model
)
return x
class NaiveExpert(nn.Module):
def __init__(self, d_model):
super(NaiveExpert, self).__init__()
self.linear = nn.Linear(d_model, d_model).cuda()
def forward(self, x):
return self.linear(x)
class LinearExpert(nn.Module):
def __init__(self, d_model):
super(LinearExpert, self).__init__()
self.model = nn.Sequential(
nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model),
).cuda()
def forward(self, x):
return self.model(x)
#!/bin/bash
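# Illustrative launch commands (assumptions, not from the original script; the
# script and benchmark file names are placeholders):
#   mpirun -np 2 ./run.sh python benchmark_mlp.py   # OpenMPI sets OMPI_COMM_WORLD_*
#   srun -N 1 -n 2 ./run.sh python benchmark_mlp.py # Slurm sets SLURM_PROCID etc.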
if [ -z "$MASTER_ADDR" ]
then
if [ -z "$SLURM_JOB_ID" ]
then
export MASTER_ADDR=localhost
else
export MASTER_ADDR=$(scontrol show JobId=$SLURM_JOB_ID | grep BatchHost | tr '=' ' ' | awk '{print $2}')
fi
fi
if [ -z "$MASTER_PORT" ]
then
export MASTER_PORT=12215
fi
if [ ! -z "$OMPI_COMM_WORLD_RANK" ]
then
export RANK=$OMPI_COMM_WORLD_RANK
localrank=$OMPI_COMM_WORLD_LOCAL_RANK
elif [ ! -z "$SLURM_PROCID" ]
then
export RANK=$SLURM_PROCID
export WORLD_SIZE=$SLURM_NPROCS
localrank=$SLURM_LOCALID
else
export RANK=0
localrank=0
export WORLD_SIZE=1
fi
export CUDA_VISIBLE_DEVICES=$localrank
exec "$@"
import json
import os
import sys
from typing import Dict
import pytest
import torch
import torch.distributed as dist
from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp
def _ensure_initialized():
if not dist.is_initialized():
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
dist.init_process_group(backend="nccl")
def _run_distributed(func, world_size, args: Dict, script=__file__):
if torch.cuda.device_count() < world_size:
pytest.skip("No enough GPU")
import subprocess
import os
ps = []
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "36666"
os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
for i in range(world_size):
os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
os.environ["HIP_VISIBLE_DEVICES"] = str(i)
p = subprocess.Popen(
[sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
)
ps.append(p)
for p in ps:
p.wait()
retc = p.poll()
assert retc == 0
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear_distributed(
num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
_run_distributed(
"_test_fmoe_linear",
mp_size * 2,
{
"num_expert": num_expert,
"top_k": top_k,
"batch_size": batch_size,
"d_model": d_model,
"d_hidden": d_hidden,
"mp_size": mp_size,
"data_type": data_type
},
)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
_run_distributed(
"_test_fmoe",
mp_size * 2,
{
"num_expert": num_expert,
"top_k": top_k,
"batch_size": batch_size,
"d_model": d_model,
"expert": expert,
"mp_size": mp_size,
},
)
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_local_ddp(mp_size):
_run_distributed(
_test_fmoe_local_ddp.__name__, mp_size * 2, {"mp_size": mp_size},
)
if __name__ == "__main__":
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
torch.distributed.init_process_group(backend="nccl")
args["rank"] = torch.distributed.get_rank()
args["world_size"] = torch.distributed.get_world_size()
args["mp_group"] = [
torch.distributed.new_group(
ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
backend="nccl",
)
for j in range(args["world_size"] // args["mp_size"])
][args["rank"] // args["mp_size"]]
args["dp_group"] = [
torch.distributed.new_group(
ranks=[
i * args["mp_size"] + j
for i in range(args["world_size"] // args["mp_size"])
],
backend="nccl",
)
for j in range(args["mp_size"])
][args["rank"] % args["mp_size"]]
args["world_group"] = torch.distributed.new_group(
ranks=list(range(args["world_size"])), backend="nccl",
)
del args["mp_size"]
locals()[sys.argv[1]](**args)
else:
test_fmoe_local_ddp(mp_size=2)
test_fmoe_linear_distributed(
num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
data_type="torch.HalfTensor"
)
import os
import pytest
import torch
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
n_devices = int(os.environ.get("N_GPUS", "2"))
class MyMoE(FMoE):
def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=1,
mp_group=None,
top_k=top_k,
)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_dp(
num_expert,
top_k,
batch_size,
d_model,
d_hidden,
activation=torch.nn.functional.gelu,
):
torch.manual_seed(42)
torch.cuda.manual_seed(42)
moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))
for i in range(5):
output = moe_dp(torch.rand(batch_size, d_model).cuda())
if __name__ == "__main__":
test_fmoe_dp(4, 2, 4, 16, 32)
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _run_distributed
def _ensure_initialized():
if not dist.is_initialized():
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
dist.init_process_group(backend="nccl")
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("cap", [.1, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
if 1 * n_expert < 2:
pytest.skip("No enough experts")
_run_distributed('_test_gshard_gate',
1,
{
'd_model': d_model,
'batch_size': batch_size,
'n_expert': n_expert,
'cap': cap
},
script=__file__
)
def _test_gshard_gate(d_model, batch_size, n_expert, cap):
_ensure_initialized()
gate = GShardGate(d_model, n_expert, dist.get_world_size(),
capacity=(cap, cap)).cuda()
x = torch.rand(batch_size, d_model).cuda()
topk_idx, topk_val = gate(x)
counts = [0 for _ in range(n_expert * dist.get_world_size())]
for v in topk_idx.cpu().view(-1).numpy():
if v != -1:
counts[v] += 1
real_cap = math.ceil(cap * batch_size)
for i in counts:
assert(i <= real_cap)
gate_score = gate.gate(x)
for i in range(batch_size):
for j in range(gate.top_k):
v = topk_idx[i, j]
if v != -1:
assert topk_val[i, j] == gate_score[i, v]
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096])
@pytest.mark.parametrize("n_expert", [1, 16])
@pytest.mark.parametrize("cap", [.1, .8])
def test_switch_gate(d_model, batch_size, n_expert, cap):
_run_distributed('_test_switch_gate',
1,
{
'd_model': d_model,
'batch_size': batch_size,
'n_expert': n_expert,
'cap': cap
},
script=__file__
)
#_test_switch_gate(d_model, batch_size, n_expert, cap)
def _test_switch_gate(d_model, batch_size, n_expert, cap):
_ensure_initialized()
#random.seed(1)
#np.random.seed(1)
torch.manual_seed(1)
gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
capacity=(cap, cap)).cuda()
x = torch.rand(batch_size, d_model).cuda()
rng = torch.cuda.get_rng_state() # save rng state
topk_idx, topk_val = gate(x)
counts = [0 for _ in range(n_expert * dist.get_world_size())]
for v in topk_idx.cpu().view(-1).numpy():
if v != -1:
counts[v] += 1
real_cap = math.ceil(cap * batch_size)
for i in counts:
assert(i <= real_cap)
score = gate.gate(x)
if gate.training:
# reset rng state to make sure noise is the same as in gate.forward()
torch.cuda.set_rng_state(rng)
# random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(score)
noise = noise * 2 * gate.switch_eps + 1.0 - gate.switch_eps
score += noise
# fp32 softmax for numerical stability
score = F.softmax(score.float(), dim=-1)
for i in range(batch_size):
v = topk_idx[i]
if v != -1:
assert round(topk_val[i].item(), 4) == round(score[i, topk_idx[i]].item(), 4)
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
_ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2)
test_switch_gate(8, 16, 4, .1)
# test_switch_gate(4096, 1024, 4, .2)
import sys
from collections import OrderedDict
from typing import List, Type, Union
import pytest
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from fmoe.functions import MOEGather, MOEScatter, count_by_gate
from test_numerical import _assert_numerical
@pytest.mark.parametrize("n_expert", [1, 4, 8])
@pytest.mark.parametrize("topk", [1, 2])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("d_model", [6])
@pytest.mark.parametrize("world_size", [1])
def test_scatter(n_expert, topk, batch_size, d_model, world_size):
gate_idx = torch.randint(n_expert + 1, (batch_size, topk)) - 1
gate_idx = gate_idx.long().cuda()
pos, lec, gec = count_by_gate(gate_idx, n_expert, world_size)
fbs = int(gec.sum().item())
inp = torch.rand(batch_size, d_model).cuda()
inp.requires_grad = True
out = MOEScatter.apply(inp, pos % batch_size, lec, gec, fbs, world_size)
out.sum().backward()
inp_raw = inp.data.clone()
out_raw = torch.empty(pos.shape[0], d_model,
device=inp.device, dtype=inp.dtype)
# out_raw.sum().backward()
for i, f in enumerate(pos.cpu()):
out_raw[i] = inp[f % batch_size]
_assert_numerical(['out'], [out], [out_raw], 0)
# TODO: check grad
if __name__ == '__main__':
test_scatter(4, 2, 8, 6, 1)
import sys
from collections import OrderedDict
from typing import List, Type, Union
import pytest
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
from fmoe.megatron.layers import _megatron_init_method
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
def _perform_forward(
moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
moe.zero_grad()
moe_raw.zero_grad()
inp = torch.rand(batch_size, d_model).type(data_type).cuda()
if mp_group is not None:
group_sender = rank // mp_group.size() * mp_group.size()
torch.distributed.broadcast(inp, group_sender, group=mp_group)
torch.distributed.broadcast(
moe.gate.gate.weight.data, group_sender, group=mp_group
)
torch.distributed.broadcast(
moe.gate.gate.bias.data, group_sender, group=mp_group
)
inp_raw = inp.clone()
inp.requires_grad = True
inp_raw.requires_grad = True
gate_idx, gate_score = moe.gate(inp_raw)
moe_out = moe(inp)
raw_out = moe_raw(inp_raw, gate_idx, gate_score)
raw_out.mean().backward()
moe_out.mean().backward()
return moe_out, raw_out, inp.grad, inp_raw.grad
def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
for name, mo, ro in zip(names, moe_out_list, raw_out_list):
err = (mo - ro).abs().max()
print("Rank {} {} abs err {}".format(rank, name, err))
if err > precision:
sys.stderr.write(f"=========== {name} moe out ==============\n")
sys.stderr.write("{}\n".format(mo))
sys.stderr.write(f"=========== {name} raw out ==============\n")
sys.stderr.write("{}\n".format(ro))
sys.stderr.write(f"=========== {name} diff ==============\n")
sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
assert False
class MyMoE(FMoE):
def __init__(
self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
top_k=top_k,
)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
rng = np.random.default_rng(1234)
_megatron_init_method(self.experts.htoh4, rng, 1.0)
_megatron_init_method(self.experts.h4toh, rng, 1.0)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear(
num_expert,
top_k,
batch_size,
d_model,
d_hidden,
rank,
world_size,
mp_group,
dp_group,
world_group,
data_type,
activation=torch.nn.functional.gelu,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
moe = MyMoE(
num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
).type(data_type).cuda()
moe_raw = BruteForceMoELinear(
activation=activation,
num_expert=num_expert,
d_model=d_model,
d_hidden=d_hidden,
world_size=world_size,
top_k=top_k,
).type(data_type).cuda()
if world_size == 1:
moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
else:
weight_htoh4_array = [
torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
]
bias_htoh4_array = [
torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
]
torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)
weight_h4toh_array = [
torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
]
bias_h4toh_array = [
torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
]
torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
)
moe_out_list = (
moe_out,
moe_grad_in,
moe.experts.htoh4.weight.grad,
moe.experts.h4toh.weight.grad,
moe.experts.htoh4.bias.grad,
moe.experts.h4toh.bias.grad,
)
raw_out_list = (
raw_out,
raw_grad_in,
moe_raw.weight_htoh4.grad,
moe_raw.weight_h4toh.grad,
moe_raw.bias_htoh4.grad,
moe_raw.bias_h4toh.grad,
)
if world_size > 1:
_, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
torch.distributed.all_reduce(htoh4_w_grad)
torch.distributed.all_reduce(h4toh_w_grad)
torch.distributed.all_reduce(htoh4_b_grad)
torch.distributed.all_reduce(h4toh_b_grad)
mp_size = mp_group.size() if mp_group else 1
htoh4_w_grad = (
htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
h4toh_w_grad = (
h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
htoh4_b_grad = (
htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
h4toh_b_grad = (
h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
)
raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
names = [
"output",
"input grad",
"htoh4 weight grad",
"h4toh weight grad",
"htoh4 bias grad",
"h4toh bias grad",
]
precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
_assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
batch_size,
num_expert,
d_model,
top_k,
expert: Union[Type[nn.Module], str],
rank,
world_size,
mp_group,
dp_group,
world_group,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if isinstance(expert, str):
expert = globals()[expert]
moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
expert=expert,
top_k=top_k,
).cuda()
moe_raw = BruteForceMoE(
expert=expert,
num_expert=num_expert,
d_model=d_model,
world_size=world_size,
top_k=top_k,
).cuda()
if world_size == 1:
for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
for para_moe, para_raw in zip(
expert_moe.parameters(), expert_raw.parameters()
):
para_raw.data = para_moe.data.clone()
else:
assert len(moe.experts) >= 1
for idx, para in enumerate(moe.experts[0].parameters()):
para_tensor = torch.cat(
[list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
)
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tensor_gathered = torch.cat(para_array, dim=0)
assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
for expertID in range(para_tensor_gathered.shape[0]):
list(moe_raw.experts[expertID].parameters())[
idx
].data = para_tensor_gathered[expertID]
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
def get_experts_grad(experts: List[nn.Module]):
return torch.stack(
[
torch.stack(
[
p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
for p in item.parameters()
]
).sum()
for item in experts
]
)
moe_grad, raw_grad = (
get_experts_grad(moe.experts),
get_experts_grad(moe_raw.experts),
)
if world_size > 1:
torch.distributed.all_reduce(raw_grad)
mp_size = mp_group.size() if mp_group else 1
raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
moe_out_list = [moe_out, moe_grad, moe_grad_in]
raw_out_list = [raw_out, raw_grad, raw_grad_in]
names = ["forward", "backward", "grad_in"]
_assert_numerical(names, moe_out_list, raw_out_list, rank)
class MyModule(nn.Module):
def __init__(self, dim=8):
super(MyModule, self).__init__()
self.model = nn.Sequential(
OrderedDict(
[
("linear1", nn.Linear(dim, dim)),
("relu1", nn.ReLU()),
("linear2", nn.Linear(dim, dim)),
("relu2", nn.ReLU()),
("linear3", nn.Linear(dim, dim)),
]
)
)
def set_comm(self):
for p in self.model._modules["linear1"].parameters():
setattr(p, "dp_comm", "mp")
for p in self.model._modules["linear2"].parameters():
setattr(p, "dp_comm", "dp")
for p in self.model._modules["linear3"].parameters():
setattr(p, "dp_comm", "world")
def forward(self, inp):
return self.model(inp)
def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
batch_size, dim = 4, 8
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
model = MyModule().cuda()
model_ddp = LocalDDP(deepcopy(model),
mp_group=mp_group, dp_group=dp_group, world_group=world_group)
model.set_comm()
model_ddp.module.set_comm()
inp = torch.randn(batch_size, dim).cuda()
raw_out = model(inp).mean()
ddp_out = model_ddp(inp).mean()
raw_out.backward()
ddp_out.backward()
torch.distributed.all_reduce(
model.model._modules["linear1"].weight.grad.data, group=mp_group
)
model.model._modules["linear1"].weight.grad /= mp_group.size()
torch.distributed.all_reduce(
model.model._modules["linear2"].weight.grad.data, group=dp_group
)
model.model._modules["linear2"].weight.grad /= dp_group.size()
torch.distributed.all_reduce(
model.model._modules["linear3"].weight.grad.data, group=world_group
)
model.model._modules["linear3"].weight.grad /= world_group.size()
model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)
raw_out_list = [
model.model._modules["linear1"].weight.grad,
model.model._modules["linear2"].weight.grad,
model.model._modules["linear3"].weight.grad,
]
ddp_out_list = [
model_ddp.module.model._modules["linear1"].weight.grad,
model_ddp.module.model._modules["linear2"].weight.grad,
model_ddp.module.model._modules["linear3"].weight.grad,
]
names = ["mp grad", "dp grad", "wp grad"]
_assert_numerical(names, ddp_out_list, raw_out_list, rank)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [ [NaiveExpert for _ in range(4)], [LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert] ])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_experts(
batch_size,
num_expert,
d_model,
top_k,
expert: Union[Type[nn.Module], str],
rank,
world_size,
mp_group,
dp_group,
world_group,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if isinstance(expert, str):
expert = globals()[expert]
moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
expert=expert,
top_k=top_k,
).cuda()
moe_raw = BruteForceMoE(
expert=expert,
num_expert=num_expert,
d_model=d_model,
world_size=world_size,
top_k=top_k,
).cuda()
if world_size == 1:
for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
for para_moe, para_raw in zip(
expert_moe.parameters(), expert_raw.parameters()
):
para_raw.data = para_moe.data.clone()
else:
assert len(moe.experts) >= 1
for idx, para in enumerate(moe.experts[0].parameters()):
para_tensor = torch.cat(
[list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
)
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tensor_gathered = torch.cat(para_array, dim=0)
assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
for expertID in range(para_tensor_gathered.shape[0]):
list(moe_raw.experts[expertID].parameters())[
idx
].data = para_tensor_gathered[expertID]
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
def get_experts_grad(experts: List[nn.Module]):
return torch.stack(
[
torch.stack(
[
p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
for p in item.parameters()
]
).sum()
for item in experts
]
)
moe_grad, raw_grad = (
get_experts_grad(moe.experts),
get_experts_grad(moe_raw.experts),
)
if world_size > 1:
torch.distributed.all_reduce(raw_grad)
mp_size = mp_group.size() if mp_group else 1
raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
moe_out_list = [moe_out, moe_grad, moe_grad_in]
raw_out_list = [raw_out, raw_grad, raw_grad_in]
names = ["forward", "backward", "grad_in"]
_assert_numerical(names, moe_out_list, raw_out_list, rank)
if __name__ == "__main__":
test_fmoe_linear(
batch_size=2,
num_expert=2,
d_model=2,
top_k=2,
d_hidden=16,
rank=0,
world_size=1,
mp_group=None,
dp_group=None,
world_group=None,
data_type='torch.FloatTensor',
)
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.functions import ensure_comm
from fmoe.gates.swipe_gate import SwipeGate
from test_ddp import _ensure_initialized, _run_distributed
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k):
if world_size * n_expert < 2:
pytest.skip("No enough experts")
_run_distributed('_test_swipe_gate',
world_size,
{
'd_model': d_model,
'batch_size': batch_size,
'n_expert': n_expert,
'top_k': top_k
},
script=__file__
)
def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
_ensure_initialized()
gate = SwipeGate(d_model, n_expert, dist.get_world_size()).cuda()
x = torch.rand(batch_size, d_model).cuda()
ensure_comm(x, None)
topk_idx, topk_val = gate(x)
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_once(world_size, batch_size, n_expert):
if world_size * n_expert < 2:
pytest.skip("No enough experts")
_run_distributed('_test_swipe_once',
world_size,
{
'batch_size': batch_size,
'n_expert': n_expert
},
script=__file__
)
def _test_swipe_once(batch_size, n_expert):
_ensure_initialized()
rank = dist.get_rank()
world_size = dist.get_world_size()
gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda()
idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
capacity = torch.scalar_tensor(batch_size * 2, dtype=torch.long)
ensure_comm(idx, None)
new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
new_idx, new_cap = gate.swipe_once(idx, new_cap, 0)
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
test_swipe_gate(8, 4, 8, 4, 2)
# test_swipe_once(8, 800, 4)
import os
import sys
import json
import torch
from fmoe.layers import _fmoe_general_global_forward
from fmoe import FMoETransformerMLP
from test_ddp import _run_distributed
class ConstantGate(torch.nn.Module):
def __init__(self, d_model, num_expert, world_size, top_k=1):
super().__init__()
self.top_k = top_k
def forward(self, inp):
idx = torch.zeros((inp.shape[0], self.top_k), dtype=torch.int64,
device=inp.device)
score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
return idx, score
def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
_run_distributed('_test_zero_fwd',
1,
{
'num_expert': num_expert,
'batch_size': batch_size,
'd_hidden': d_hidden
},
script=__file__
)
def _test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
inp = torch.rand(batch_size, d_hidden).cuda()
gate = torch.zeros(batch_size, dtype=torch.int64).cuda()
x = _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert,
world_size)
def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
_run_distributed('_test_zero_transformer',
1,
{
'num_expert': num_expert,
'batch_size': batch_size,
'd_hidden': d_hidden
},
script=__file__
)
def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
inp = torch.rand(batch_size, d_hidden).cuda()
mask = torch.zeros(inp.shape[0], dtype=torch.long)
mask[1] = 1
mask_dict = {
1: torch.zeros(d_hidden).cuda()
}
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4,
world_size=world_size, gate=ConstantGate, mask=mask,
mask_dict=mask_dict).cuda()
oup = model(inp)
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
torch.distributed.init_process_group(backend="nccl")
args['world_size'] = torch.distributed.get_world_size()
locals()[sys.argv[1]](**args)
else:
# test_zero_fwd(world_size=torch.distributed.get_world_size())
test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
world_size=1)
print('done')