Commit 57191b14 authored by Rick Ho
Browse files

fix python global variable

parent 42b825fe
...@@ -40,7 +40,7 @@ class MOELayer(nn.Module): ...@@ -40,7 +40,7 @@ class MOELayer(nn.Module):
def reset_parameters(self): def reset_parameters(self):
for i in range(self.num_expert): for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat, out_features=out_feat) linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
self.weight.data[i] = linear.weight.data self.weight.data[i] = linear.weight.data
def forward(self, inp, gate): def forward(self, inp, gate):
...@@ -59,7 +59,7 @@ class MOELayer_einsum(nn.Module): ...@@ -59,7 +59,7 @@ class MOELayer_einsum(nn.Module):
def reset_parameters(self): def reset_parameters(self):
for i in range(self.num_expert): for i in range(self.num_expert):
linear = nn.Linear(in_features=self.in_feat, out_features=out_feat) linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
self.weight.data[i] = linear.weight.data self.weight.data[i] = linear.weight.data
def forward(self, inp, gate): def forward(self, inp, gate):
...@@ -70,24 +70,30 @@ class MOELayer_einsum(nn.Module): ...@@ -70,24 +70,30 @@ class MOELayer_einsum(nn.Module):
x[i] = self.weight[gate_long[i]] @ inp[i] x[i] = self.weight[gate_long[i]] @ inp[i]
return x return x
batch_size = 4
num_expert = 4
in_feat = 2
out_feat = 3
moe = MOELayer(num_expert, in_feat, out_feat).cuda() def test():
moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda() batch_size = 4
moe_einsum.weight.data = moe.weight.data.clone() num_expert = 4
in_feat = 2
out_feat = 3
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
moe_einsum.weight.data = moe.weight.data.clone()
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
output = moe(inp, gate) inp = torch.rand(batch_size, in_feat).cuda()
output_einsum = moe_einsum(inp.clone(), gate.clone()) gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
print(output) output = moe(inp, gate)
print(output_einsum) output_einsum = moe_einsum(inp.clone(), gate.clone())
#y = output.mean() print(output)
#y.backward() print(output_einsum)
\ No newline at end of file
#y = output.mean()
#y.backward()
if __name__ == '__main__':
test()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment