"...text-generation-inference.git" did not exist on "6db3bcb700e62134b35d87169e88907543583a16"
Unverified commit 87dad9d5, authored by Rick Ho, committed by GitHub

Merge pull request #5 from laekov/bias

Merge branch `bias`: add an optional per-expert bias to `FMoELinear`, fold the `expert_fn` argument into `FMoE` as an overridable method, and extend the tests to cover the bias parameters and their gradients.
parents ed9277f9 01464726
@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()

     def reset_parameters(self):
@@ -41,17 +46,32 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        for i in range(self.num_expert):
-            weight = rng.uniform(-bound, bound,
-                    size=tuple(self.weight[i].size()))
-            self.weight.data[i] = torch.tensor(weight,
-                    dtype=dtype, device=device)
+        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
+        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            x = x + bias
+        return x
+
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+            self.num_expert, self.in_feat,
+            self.out_feat, self.bias is not None, self.rank
+        )


 def mark_module_parallel_comm(module, comm):
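`forward` now adds the bias after the fused expert matmul. Because the rows of `inp` arrive grouped by expert, row `i` of `self.bias` only needs to be repeated `fwd_expert_count[i]` times to line up with them, which is exactly what `torch.repeat_interleave` does. A minimal standalone sketch of that expansion (sizes invented for illustration):

    import torch

    num_expert, out_feat = 3, 4
    bias = torch.arange(num_expert * out_feat, dtype=torch.float32)
    bias = bias.view(num_expert, out_feat)
    # Tokens grouped by expert: 2 rows for expert 0, none for expert 1,
    # 3 rows for expert 2.
    fwd_expert_count = torch.tensor([2, 0, 3])

    expanded = torch.repeat_interleave(bias, fwd_expert_count, dim=0)
    print(expanded.shape)  # torch.Size([5, 4]) -- one bias row per token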
@@ -92,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
     r'''
-    A general moe implementation that supports an arbitrary module as the expert
-    Either `expert` or `expert_fn` is required.
+    A general moe implementation that supports an arbitrary module as the
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -106,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -125,19 +142,20 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
+        if expert is not None:
             self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
-                    batch_size = fwd_expert_count[i].item()
-                    inp_slice = inp[base_idx:base_idx + batch_size]
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-        self.expert_fn = expert_fn
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
+            batch_size = fwd_expert_count[i].item()
+            inp_slice = inp[base_idx:base_idx + batch_size]
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)

     def mark_parallel_comm(self):
         r'''
...
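The `expert_fn` constructor argument is gone: `expert_fn` is now a method that either calls `self.experts` wholesale (when it is a single `nn.Module`) or slices the expert-sorted input across a list of per-expert modules. Customization is therefore done by subclassing, as the updated tests below do. A condensed sketch of the pattern (the `fmoe.layers` and `fmoe.gates` import paths are assumptions, and running it requires the compiled fastmoe CUDA extension):

    import torch
    from fmoe.layers import FMoE          # import paths assumed
    from fmoe.gates import NaiveGate
    from fmoe.transformer import _Expert

    class MyMoE(FMoE):
        def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
            super().__init__(num_expert=num_expert, d_model=d_model,
                             gate=NaiveGate, world_size=1, mp_group=None,
                             top_k=top_k)
            # self.experts is an nn.Module, so the base expert_fn calls it
            # directly instead of looping over a list of expert modules.
            self.experts = _Expert(num_expert, d_model, d_hidden, activation)

    moe = MyMoE(4, 16, 32, 2, torch.nn.functional.gelu).cuda()
    out = moe(torch.rand(8, 16).cuda())   # one output row per input token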
@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
+                bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
+                bias=True, rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
@@ -47,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
...
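With this change `FMoETransformerMLP` no longer builds a closure; it simply assigns its `_Expert` to `self.experts` and inherits `expert_fn` from `FMoE`. A hedged construction example (keyword names are taken from the hunk above; the import path and 2-D input handling are assumptions):

    import torch
    from fmoe.transformer import FMoETransformerMLP  # import path assumed

    # d_model in, d_model out; d_hidden is the per-expert hidden size.
    mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256,
                             activation=torch.nn.functional.gelu,
                             world_size=1, mp_group=None, top_k=2).cuda()
    y = mlp(torch.rand(10, 64).cuda())  # expected shape: (10, 64)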
@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k

     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
             self.d_model)).reshape(-1, self.d_model)
...
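After the per-expert loop, the reference model combines each token's `top_k` expert outputs with its gate scores through a batched matmul. A shape-only sketch of that combine step, with invented sizes:

    import torch

    batch, top_k, d_model = 5, 2, 8
    o = torch.rand(batch * top_k, d_model)       # top_k expert outputs per token
    gate_score = torch.rand(batch, 1, top_k)     # one score row per token
    combined = torch.bmm(gate_score, o.view(batch, top_k, d_model))
    combined = combined.reshape(batch, d_model)  # weighted sum over top_k outputs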
@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
 n_devices = int(os.environ.get("N_GPUS", "2"))


+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=1,
+            mp_group=None,
+            top_k=top_k
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -26,22 +39,12 @@ def test_fmoe_dp(
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=1,
-        mp_group=None,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
     moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))

     for i in range(5):
         output = moe_dp(torch.rand(batch_size, d_model).cuda())
+
+
+if __name__ == '__main__':
+    test_fmoe_dp(4, 2, 4, 16, 32)
@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
             assert False


+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=world_size,
+            mp_group=mp_group,
+            top_k=top_k
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -74,20 +88,8 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=world_size,
-        mp_group=mp_group,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
+            activation).cuda()

     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -99,38 +101,54 @@ def test_fmoe_linear(
     ).cuda()

     if world_size == 1:
-        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
-        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
+        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
     else:
         weight_htoh4_array = [
-            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
         ]
+        bias_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
+        ]
-        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

         weight_h4toh_array = [
-            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
         ]
+        bias_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
+        ]
-        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out = _perform_forward(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )

-    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
+    moe_out_list = (moe_out, moe.experts.htoh4.weight.grad,
+            moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad,
+            moe.experts.h4toh.bias.grad)
+    raw_out_list = (raw_out, moe_raw.weight_htoh4.grad,
+            moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad,
+            moe_raw.bias_h4toh.grad)

     if world_size > 1:
-        _, htoh4_grad, h4toh_grad = raw_out_list
-        torch.distributed.all_reduce(htoh4_grad)
-        torch.distributed.all_reduce(h4toh_grad)
+        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
+        torch.distributed.all_reduce(htoh4_w_grad)
+        torch.distributed.all_reduce(h4toh_w_grad)
+        torch.distributed.all_reduce(htoh4_b_grad)
+        torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
-        raw_out_list = _, htoh4_grad, h4toh_grad
+        htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
+        raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

-    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
+    names = ["output", "htoh4 weight grad", "h4toh weight grad",
+            "htoh4 bias grad", "h4toh bias grad"]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)
...
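The gathered reference parameters stack every worker's experts along dim 0, so rank `r`'s local experts occupy rows `r * num_expert` through `(r + 1) * num_expert - 1` of each all-reduced gradient, and the division by `mp_size` cancels the duplicate contributions from the model-parallel replicas. A tiny worked example of the slice (numbers invented):

    import torch

    world_size, num_expert, rank, mp_size = 2, 4, 1, 1
    grad_global = torch.arange(world_size * num_expert)   # stand-in gathered grad
    grad_local = grad_global[rank * num_expert:(rank + 1) * num_expert] / mp_size
    print(grad_local)   # tensor([4., 5., 6., 7.]) -- rank 1's expert rows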