Commit e86dea53 authored by Rick Ho's avatar Rick Ho
Browse files

fix tensor initialization bug

parent 9e67148c
...@@ -47,7 +47,7 @@ class FMoELinear(nn.Module): ...@@ -47,7 +47,7 @@ class FMoELinear(nn.Module):
device = self.weight.device device = self.weight.device
dtype = self.weight.dtype dtype = self.weight.dtype
weight = rng.uniform(-bound, bound, size=tuple(self.weight.size())) weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
self.weight.data = torch.Tensor(weight, dtype=dtype, device=device) self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
if self.bias is not None: if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0]) fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment