Commit 76327544 authored by Rick Ho

distributed test weight

parent 864a4522
@@ -67,15 +67,15 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
-    print('ooutput', torch.distributed.get_rank(), output)
+    # print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


 def test():
-    torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    torch.manual_seed(42 + torch.distributed.get_rank())
+    torch.cuda.manual_seed(42 + torch.distributed.get_rank())
     batch_size = 4
     num_expert = 2
     in_feat = 6
@@ -106,10 +106,14 @@ def test():
     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

-    if world_size == 1:
-        names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
-    else:
-        names = ['Out']
+    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    if world_size > 1:
+        rank = torch.distributed.get_rank()
+        ou, wg, lwg, lbg = raw_out
+        wg = wg.cpu()
+        torch.distributed.all_reduce(wg)
+        wg = wg[rank * num_expert:(rank + 1) * num_expert]
+        raw_out = ou, wg.cuda(), lwg, lbg
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
...
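The new branch reconciles the distributed MoE's sharded expert gradients with the replicated reference model: each rank back-propagates only through its own (differently seeded) local batch, so the reference model's expert-weight gradient is partial on every rank. Summing it across ranks with all_reduce reconstructs the full-batch gradient, and slicing out this rank's num_expert experts makes it comparable to the local shard. A minimal sketch of that check, assuming an initialized torch.distributed process group with a CPU-capable backend such as gloo; the helper name and arguments below are illustrative, not from the repository:

import torch
import torch.distributed as dist

def check_expert_grad(local_grad, ref_grad, num_expert):
    """Compare one rank's local expert-weight gradient against a replicated
    reference model whose gradient must first be summed across ranks.

    local_grad: gradient of the num_expert experts this rank owns.
    ref_grad:   gradient of all world_size * num_expert experts of the
                replicated reference, computed from this rank's batch only.
    """
    rank = dist.get_rank()
    # Move to CPU before the collective; a CPU all_reduce assumes a
    # gloo-style backend (NCCL only reduces CUDA tensors).
    ref_grad = ref_grad.cpu()
    # Each rank saw a different batch, so the reference gradients are
    # partial; summing them yields the gradient of the global batch.
    dist.all_reduce(ref_grad)  # defaults to SUM
    # Keep only the slice of experts that this rank owns.
    ref_slice = ref_grad[rank * num_expert:(rank + 1) * num_expert]
    return (local_grad.cpu() - ref_slice).abs().sum()

Seeding with 42 + rank, the other change in this commit, is what makes each rank's batch distinct; without it every rank would generate identical inputs and the summed reference gradient would not genuinely exercise the cross-rank routing path.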