Unverified Commit 87dad9d5 authored by Rick Ho, committed by GitHub

Merge pull request #5 from laekov/bias

merge Bias
parents ed9277f9 01464726
@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()
 
     def reset_parameters(self):
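For illustration, a minimal shape check of the new per-expert bias (a sketch, not part of the commit; assumes the fmoe package is built and that FMoELinear is importable as below):

    import torch
    from fmoe.layers import FMoELinear  # import path assumed from this repo's layout

    layer = FMoELinear(num_expert=4, in_feat=8, out_feat=16, bias=True)
    assert layer.weight.shape == (4, 16, 8)  # one weight matrix per expert
    assert layer.bias.shape == (4, 16)       # one bias vector per expert

    no_bias = FMoELinear(num_expert=4, in_feat=8, out_feat=16, bias=False)
    assert no_bias.bias is None  # registered as None, mirroring nn.Linear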
@@ -41,17 +46,32 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        for i in range(self.num_expert):
-            weight = rng.uniform(-bound, bound,
-                    size=tuple(self.weight[i].size()))
-            self.weight.data[i] = torch.tensor(weight,
-                    dtype=dtype, device=device)
+        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
+        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
 
     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            x = x + bias
+        return x
+
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+                self.num_expert, self.in_feat,
+                self.out_feat, self.bias is not None, self.rank
+        )
 
 
 def mark_module_parallel_comm(module, comm):
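The bias add in forward() relies on rows of the input being grouped by expert, so repeat_interleave can expand each expert's bias vector into one row per routed token. A standalone sketch of just that expansion:

    import torch

    bias = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])  # (num_expert=3, out_feat=2)
    fwd_expert_count = torch.tensor([2, 0, 1])           # tokens routed to each expert

    print(torch.repeat_interleave(bias, fwd_expert_count, dim=0))
    # tensor([[1., 1.],
    #         [1., 1.],
    #         [3., 3.]])  expert 1 received no tokens, so its bias row is skipped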
@@ -92,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 
 class FMoE(nn.Module):
     r'''
-    A general moe implementation that supports an arbitrary module as the expert
-    Either `expert` or `expert_fn` is required.
+    A general moe implementation that supports an arbitrary module as the
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contain
     different experts.
@@ -106,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can be found in `fmoe.gates`.
     * `expert` can be specified as a module class; it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -125,19 +142,20 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
+        if expert is not None:
             self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
-                    batch_size = fwd_expert_count[i].item()
-                    inp_slice = inp[base_idx:base_idx + batch_size]
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-        self.expert_fn = expert_fn
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
+            batch_size = fwd_expert_count[i].item()
+            inp_slice = inp[base_idx:base_idx + batch_size]
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)
 
     def mark_parallel_comm(self):
         r'''
......
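The new expert_fn method has two paths: if self.experts is a single nn.Module (such as _Expert), it is called directly with fwd_expert_count; if it is a list of per-expert modules, contiguous row blocks are sliced off for each expert. A self-contained sketch of the list path (nn.Linear stands in for a real expert):

    import torch
    import torch.nn as nn

    experts = [nn.Linear(8, 8) for _ in range(3)]  # stand-in expert list
    inp = torch.randn(5, 8)                        # rows already grouped by expert
    fwd_expert_count = torch.tensor([2, 2, 1])     # rows per expert, in order

    outputs, base_idx = [], 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        outputs.append(experts[i](inp[base_idx:base_idx + count]))
        base_idx += count
    out = torch.cat(outputs, dim=0)  # same expert-grouped row order as inp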
@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
+                bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
+                bias=True, rank=rank)
         self.activation = activation
 
     def forward(self, inp, fwd_expert_count):
@@ -47,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
-
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
......
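_Expert wires the two FMoELinear layers into the usual transformer FFN. Its forward body is collapsed out of this diff, but it amounts to the following sketch (names taken from the diff; expert_ffn itself is hypothetical):

    # d_model -> d_hidden -> d_model, with biases now enabled in both layers
    def expert_ffn(x, htoh4, h4toh, activation, fwd_expert_count):
        x = htoh4(x, fwd_expert_count)     # (tokens, d_model) -> (tokens, d_hidden)
        x = activation(x)
        return h4toh(x, fwd_expert_count)  # back to (tokens, d_model)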
@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k
 
     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
             self.d_model)).reshape(-1, self.d_model)
......
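The final bmm treats gate_score as a (batch, 1, top_k) mixing row and the top_k expert outputs as a (batch, top_k, d_model) stack. A standalone sketch of that combine with made-up sizes:

    import torch

    batch, top_k, d_model = 4, 2, 8
    gate_score = torch.softmax(torch.randn(batch, 1, top_k), dim=-1)
    o = torch.randn(batch * top_k, d_model)  # top_k expert outputs per token
    x = torch.bmm(gate_score, o.view(batch, top_k, d_model)).reshape(batch, d_model)
    print(x.shape)  # torch.Size([4, 8])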
@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
 
 n_devices = int(os.environ.get("N_GPUS", "2"))
 
+
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=1,
+            mp_group=None,
+            top_k=top_k
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -26,22 +39,12 @@ def test_fmoe_dp(
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
 
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=1,
-        mp_group=None,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
 
     moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))
 
     for i in range(5):
         output = moe_dp(torch.rand(batch_size, d_model).cuda())
 
 
 if __name__ == '__main__':
     test_fmoe_dp(4, 2, 4, 16, 32)
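A sketch of how the MyMoE subclass replaces the removed expert_fn argument (the activation shown is an assumption; the real value is a pytest parameter collapsed out of this diff):

    import torch

    moe = MyMoE(num_expert=4, d_model=16, d_hidden=32, top_k=2,
                activation=torch.nn.functional.gelu).cuda()
    out = moe(torch.rand(8, 16).cuda())  # (batch, d_model) in and out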
@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
             assert False
 
 
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=world_size,
+            mp_group=mp_group,
+            top_k=top_k
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -74,20 +88,8 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
 
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=world_size,
-        mp_group=mp_group,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
+            activation).cuda()
 
     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -99,38 +101,54 @@ def test_fmoe_linear(
     ).cuda()
 
     if world_size == 1:
-        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
-        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
+        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
     else:
         weight_htoh4_array = [
-            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
         ]
+        bias_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
+        ]
-        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)
 
         weight_h4toh_array = [
-            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
         ]
+        bias_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
+        ]
-        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
 
     moe_out, raw_out = _perform_forward(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )
 
-    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
+    moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
+    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
 
     if world_size > 1:
-        _, htoh4_grad, h4toh_grad = raw_out_list
-        torch.distributed.all_reduce(htoh4_grad)
-        torch.distributed.all_reduce(h4toh_grad)
+        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
+        torch.distributed.all_reduce(htoh4_w_grad)
+        torch.distributed.all_reduce(h4toh_w_grad)
+        torch.distributed.all_reduce(htoh4_b_grad)
+        torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        raw_out_list = _, htoh4_grad, h4toh_grad
+        htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
 
-    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
+    names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)
......
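The gather-and-concatenate pattern above rebuilds the full expert parameter tensors on every rank so the single-process reference model can be compared against the distributed one. A minimal sketch, assuming torch.distributed is already initialized as in this test:

    import torch
    import torch.distributed as dist

    local = torch.randn(4, 32, 16).cuda()  # this rank's (num_expert, d_hidden, d_model) shard
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    full = torch.cat(gathered, dim=0)      # (num_expert * world_size, d_hidden, d_model)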