Unverified Commit 85eff637 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

[{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479)

explicit view kernel size as number elements in flattened indices
parent e589bdb9
...@@ -300,7 +300,8 @@ class Downsample1d(nn.Module): ...@@ -300,7 +300,8 @@ class Downsample1d(nn.Module):
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight) kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2) return F.conv1d(hidden_states, weight, stride=2)
...@@ -316,7 +317,8 @@ class Upsample1d(nn.Module): ...@@ -316,7 +317,8 @@ class Upsample1d(nn.Module):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight) kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment