"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "42067ef2628320aa28cc79eb7d8bca97088f934e"
Commit f5cc759c authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

test data parallel (FAILED!)

parent 7945bce6
...@@ -112,5 +112,30 @@ def test(): ...@@ -112,5 +112,30 @@ def test():
print("linear_raw.weight.grad", linear.weight.grad) print("linear_raw.weight.grad", linear.weight.grad)
print("linear_raw.bias.grad", linear.bias.grad) print("linear_raw.bias.grad", linear.bias.grad)
def test_dp():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
batch_size = 4
num_expert = 4
in_feat = 2
out_feat = 3
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
print("data parallel of a nn.Linear model")
linear = nn.Linear(in_feat, in_feat).cuda()
moe_linear = torch.nn.DataParallel(linear, device_ids=[0, 1])
output = moe_linear(inp)
print("successful!")
print("data parallel of our MoE model")
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1])
output = moe_dp(inp, gate)
if __name__ == '__main__': if __name__ == '__main__':
test() # test()
test_dp()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment