Commit 76327544 authored by Rick Ho

distributed test weight

parent 864a4522
@@ -67,15 +67,15 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
-    print('ooutput', torch.distributed.get_rank(), output)
+    # print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
 
 
 def test():
-    torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    torch.manual_seed(42 + torch.distributed.get_rank())
+    torch.cuda.manual_seed(42 + torch.distributed.get_rank())
     batch_size = 4
     num_expert = 2
     in_feat = 6
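The seed change in this hunk is what makes the test genuinely distributed: with a fixed seed of 42, every rank would draw identical random inputs, so a multi-process run would only replay the single-process test. Offsetting the seed by the process rank gives each worker its own batch. A minimal sketch of the pattern, assuming the test is launched with one process per GPU and torch.distributed.init_process_group has already been called; the helper name seed_per_rank is illustrative, not from the source:

    import torch
    import torch.distributed as dist

    def seed_per_rank(base_seed=42):
        # Offset the base seed by this process's rank so each worker
        # generates different random inputs rather than identical copies.
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        torch.cuda.manual_seed(base_seed + rank)

    # Usage mirroring the diff's constants (batch_size=4, in_feat=6):
    # seed_per_rank(42)
    # inp = torch.rand(4, 6).cuda()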
@@ -106,10 +106,14 @@ def test():
     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
-    if world_size == 1:
-        names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
-    else:
-        names = ['Out']
+    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    if world_size > 1:
+        rank = torch.distributed.get_rank()
+        ou, wg, lwg, lbg = raw_out
+        wg = wg.cpu()
+        torch.distributed.all_reduce(wg)
+        wg = wg[rank * num_expert:(rank + 1) * num_expert]
+        raw_out = ou, wg.cuda(), lwg, lbg
 
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
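The new world_size > 1 branch reconciles the reference gradients with the distributed MoE's sharded weights. Each rank of the distributed module holds only its own num_expert experts, while the reference module appears to hold the full expert set and sees only that rank's inputs, so the test sums the reference weight gradient across ranks with all_reduce and then keeps just this rank's slice before comparing. A minimal sketch of that pattern, assuming wg is laid out as [world_size * num_expert, ...] with rank r owning rows r*num_expert through (r+1)*num_expert; the helper name local_reference_grad is illustrative:

    import torch
    import torch.distributed as dist

    def local_reference_grad(wg, num_expert):
        # Sum the reference weight gradient over all ranks (all_reduce
        # defaults to SUM), then slice out the rows for the experts this
        # rank owns, so the result is directly comparable to the
        # distributed MoE's local weight gradient.
        rank = dist.get_rank()
        wg = wg.cpu()
        dist.all_reduce(wg)
        return wg[rank * num_expert:(rank + 1) * num_expert].cuda()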