Unverified Commit 28ba2d28 authored by Colin, committed by GitHub

mask and experts list

parent baae8fb9
......@@ -11,3 +11,5 @@ build
*swp
logs
dist
**/.DS_Store
.idea
......@@ -132,6 +132,8 @@ class FMoE(nn.Module):
gate=NaiveGate,
expert=None,
gate_hook=None,
mask=None,
mask_dict=None,
):
super().__init__()
self.num_expert = num_expert
......@@ -145,14 +147,20 @@ class FMoE(nn.Module):
self.mp_size = mp_group.size()
self.mp_rank = mp_group.rank()
self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k)
if expert is not None:
if isinstance(expert, list):
self.experts = nn.ModuleList([e(d_model) for e in expert])
self.experts_fused = False
self.num_expert = num_expert = len(expert)
elif expert is not None:
self.experts = nn.ModuleList([expert(d_model)
for _ in range(num_expert)])
self.experts_fused = False
else:
self.experts_fused = True
self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate_hook = gate_hook
self.mask = mask
self.mask_dict = mask_dict
def expert_fn(self, inp, fwd_expert_count):
r"""
......@@ -196,14 +204,33 @@ class FMoE(nn.Module):
gate_top_k_idx, gate_score = self.gate(inp)
x = _fmoe_general_global_forward(
inp,
# delete masked tokens before dispatching to the experts
if self.mask is not None and self.mask_dict is not None:
    mask = self.mask.view(-1)
    # to: (BxL') x d_model
    inp = inp[mask == 0, :]
    gate_top_k_idx = gate_top_k_idx[mask == 0, :]
fwd = _fmoe_general_global_forward(
inp,
gate_top_k_idx,
self.expert_fn, self.num_expert, self.world_size
)
x = x.view(inp.shape[0], self.top_k, self.d_model)
gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
# recover deleted tokens in the expert output
if self.mask is not None and self.mask_dict is not None:
    # to: (BxL') x top_k x d_model
    fwd = fwd.view(-1, self.top_k, self.d_model)
    # to: (BxL) x top_k x d_model, on the same device/dtype as fwd
    x = torch.zeros(mask.shape[0], self.top_k, self.d_model,
                    device=fwd.device, dtype=fwd.dtype)
    # scatter the computed tokens back to their original positions
    x[mask == 0] = fwd
    # fill masked positions with the user-supplied constants
    for k, v in self.mask_dict.items():
        x[mask == k] = v
else:
x = fwd.view(-1, self.top_k, self.d_model)
gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
......
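The forward path above now drops masked tokens before the global dispatch and rebuilds the full output afterwards, so masked positions never consume expert compute. A minimal standalone sketch of that recover logic, with made-up shapes and local names (not part of the FastMoE API):

    import torch

    top_k, d_model = 2, 4
    mask = torch.tensor([0, 1, 0, 0, 1, 0])    # nonzero entries are skipped
    mask_dict = {1: torch.zeros(d_model)}      # filler for mask value 1

    # pretend expert outputs for the kept tokens: (BxL') x top_k x d_model
    fwd = torch.randn(int((mask == 0).sum()), top_k, d_model)

    # recover to the full (BxL) x top_k x d_model shape
    x = torch.zeros(mask.shape[0], top_k, d_model,
                    device=fwd.device, dtype=fwd.dtype)
    x[mask == 0] = fwd                         # scatter kept tokens back
    for k, v in mask_dict.items():
        x[mask == k] = v                       # broadcast filler over masked rows

    assert torch.equal(x[1], torch.zeros(top_k, d_model))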
......@@ -49,6 +49,8 @@ class FMoETransformerMLP(FMoE):
top_k=2,
expert_dp_comm="none",
gate_hook=None,
mask=None,
mask_dict=None,
):
super().__init__(
num_expert=num_expert,
......@@ -58,6 +60,8 @@ class FMoETransformerMLP(FMoE):
world_size=world_size,
mp_group=mp_group,
gate_hook=gate_hook,
mask=mask,
mask_dict=mask_dict,
)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
......
......@@ -55,7 +55,11 @@ class BruteForceMoE(nn.Module):
self.num_expert = num_expert
self.d_model = d_model
self.top_k = top_k
self.experts = [expert(d_model) for _ in range(num_expert * world_size)]
if isinstance(expert, list):
self.experts = [e(d_model) for e in expert]
self.num_expert = num_expert = len(expert)
else:
self.experts = [expert(d_model) for _ in range(num_expert * world_size)]
def forward(self, inp, gate_idx, gate_score):
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
......
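Both FMoE and the BruteForceMoE reference now accept a list of expert constructors and infer num_expert from len(expert), so heterogeneous experts can be compared module-for-module. A sketch of the list form, mirroring the keyword arguments used by the test below (NaiveExpert and LinearExpert are the test suite's helper experts):

    experts = [NaiveExpert, LinearExpert, NaiveExpert]

    moe = FMoE(num_expert=None, d_model=16, gate=NaiveGate,
               world_size=1, mp_group=None, expert=experts, top_k=2).cuda()
    ref = BruteForceMoE(expert=experts, num_expert=None, d_model=16,
                        world_size=1, top_k=2).cuda()

    # num_expert is overridden by the length of the expert list
    assert moe.num_expert == ref.num_expert == len(experts)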
......@@ -384,6 +384,107 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
_assert_numerical(names, ddp_out_list, raw_out_list, rank)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [ [NaiveExpert for _ in range(4)], [LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert] ])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_experts(
batch_size,
num_expert,
d_model,
top_k,
expert: Union[List[Type[nn.Module]], str],
rank,
world_size,
mp_group,
dp_group,
world_group,
):
torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank)
if isinstance(expert, str):
expert = globals()[expert]
moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
expert=expert,
top_k=top_k,
).cuda()
moe_raw = BruteForceMoE(
expert=expert,
num_expert=num_expert,
d_model=d_model,
world_size=world_size,
top_k=top_k,
).cuda()
if world_size == 1:
for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
for para_moe, para_raw in zip(
expert_moe.parameters(), expert_raw.parameters()
):
para_raw.data = para_moe.data.clone()
else:
assert len(moe.experts) >= 1
for idx, para in enumerate(moe.experts[0].parameters()):
para_tensor = torch.cat(
[list(e.parameters())[idx].unsqueeze(0) for e in moe.experts]
)
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tensor_gathered = torch.cat(para_array, dim=0)
assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
for expertID in range(para_tensor_gathered.shape[0]):
list(moe_raw.experts[expertID].parameters())[
idx
].data = para_tensor_gathered[expertID]
moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
)
def get_experts_grad(experts: List[nn.Module]):
return torch.stack(
[
torch.stack(
[
p.grad.sum() if p.grad is not None else torch.zeros(()).cuda()
for p in item.parameters()
]
).sum()
for item in experts
]
)
moe_grad, raw_grad = (
get_experts_grad(moe.experts),
get_experts_grad(moe_raw.experts),
)
if world_size > 1:
torch.distributed.all_reduce(raw_grad)
mp_size = mp_group.size() if mp_group else 1
raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
moe_out_list = [moe_out, moe_grad, moe_grad_in]
raw_out_list = [raw_out, raw_grad, raw_grad_in]
names = ["forward", "backward", "grad_in"]
_assert_numerical(names, moe_out_list, raw_out_list, rank)
if __name__ == "__main__":
test_fmoe_linear(
batch_size=2,
......@@ -396,4 +497,5 @@ if __name__ == "__main__":
mp_group=None,
dp_group=None,
world_group=None,
data_type=torch.float32,
)
......@@ -51,8 +51,13 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
inp = torch.rand(batch_size, d_hidden).cuda()
mask = torch.zeros(inp.shape[0], dtype=torch.long)
mask[1] = 1
mask_dict = {
1: torch.zeros(d_hidden).cuda()
}
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
gate=ConstantGate).cuda()
gate=ConstantGate, mask=mask, mask_dict=mask_dict).cuda()
oup = model(inp)
......
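Because the masked token at position 1 never reaches the experts and its output slot is filled from mask_dict before the gate-score combination, the corresponding output row should be exactly the zero filler. A follow-up check one could add to the test (assuming, as in this version of the layer, no residual is added inside forward):

    # position 1 was masked with a zero filler, so its output row is zero
    assert torch.allclose(oup[1], torch.zeros(d_hidden).cuda())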