Unverified Commit db56f8a4 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

explicit broadcasts for assignments (#3535)

parent c13dbd5c
...@@ -433,7 +433,8 @@ class KDownsample2D(nn.Module): ...@@ -433,7 +433,8 @@ class KDownsample2D(nn.Module):
x = F.pad(x, (self.pad,) * 4, self.pad_mode) x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device) indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight) kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv2d(x, weight, stride=2) return F.conv2d(x, weight, stride=2)
...@@ -449,7 +450,8 @@ class KUpsample2D(nn.Module): ...@@ -449,7 +450,8 @@ class KUpsample2D(nn.Module):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device) indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight) kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment