Commit 07a5d8ac authored by Rick Ho's avatar Rick Ho
Browse files

fix zero test

parent 5f8ba136
......@@ -56,8 +56,9 @@ def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1)
mask_dict = {
1: torch.zeros(d_hidden).cuda()
}
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
gate=ConstantGate, mask=mask, mask_dict=mask_dict).cuda()
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4,
world_size=world_size, gate=ConstantGate, mask=mask,
mask_dict=mask_dict).cuda()
oup = model(inp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment