"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "9a50cf800e511031602cb4b9523cef7b448b16af"
Commit 2ba58797 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

update

parent 560d3f1b
...@@ -80,6 +80,8 @@ def test(): ...@@ -80,6 +80,8 @@ def test():
in_feat = 2 in_feat = 2
out_feat = 3 out_feat = 3
linear = nn.Linear(in_feat, in_feat).cuda()
moe = MOELayer(num_expert, in_feat, out_feat).cuda() moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda() moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
moe_raw.weight.data = moe.weight.data.clone() moe_raw.weight.data = moe.weight.data.clone()
...@@ -87,21 +89,28 @@ def test(): ...@@ -87,21 +89,28 @@ def test():
inp = torch.rand(batch_size, in_feat).cuda() inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
output = moe(inp, gate) linear.zero_grad()
output_raw= moe_raw(inp.clone(), gate.clone()) moe.zero_grad()
x = linear(inp)
print(output) output = moe(x, gate)
print(output_raw) print("moe output", output)
y = output.mean() y = output.mean()
y.backward() y.backward()
print("moe.weight.grad", moe.weight.grad)
print("linear.weight.grad", linear.weight.grad)
print("linear.bias.grad", linear.bias.grad)
linear.zero_grad()
moe.zero_grad()
x = linear(inp.clone())
output_raw= moe_raw(x, gate.clone())
print("moe_raw output", output_raw)
y_raw = output_raw.mean() y_raw = output_raw.mean()
y_raw.backward() y_raw.backward()
print("moe_raw.weight.grad", moe_raw.weight.grad)
print(moe.weight.grad) print("linear_raw.weight.grad", linear.weight.grad)
print(moe_raw.weight.grad) print("linear_raw.bias.grad", linear.bias.grad)
if __name__ == '__main__': if __name__ == '__main__':
test() test()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment