"...text-generation-inference.git" did not exist on "7e2a7433d3584a5a68dbf3e71def4323079f2c26"
Unverified Commit 0bf6aeb8 authored by Saurav Maheshkar's avatar Saurav Maheshkar Committed by GitHub
Browse files

feat: rename single-letter vars in `resnet.py` (#3868)

feat: rename single-letter vars
parent 9a45d7fb
...@@ -95,9 +95,9 @@ class Downsample1D(nn.Module): ...@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
assert self.channels == self.out_channels assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x): def forward(self, inputs):
assert x.shape[1] == self.channels assert inputs.shape[1] == self.channels
return self.conv(x) return self.conv(inputs)
class Upsample2D(nn.Module): class Upsample2D(nn.Module):
...@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module): ...@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
self.pad = kernel_1d.shape[1] // 2 - 1 self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x): def forward(self, inputs):
x = F.pad(x, (self.pad,) * 4, self.pad_mode) inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device) indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel weight[indices, indices] = kernel
return F.conv2d(x, weight, stride=2) return F.conv2d(inputs, weight, stride=2)
class KUpsample2D(nn.Module): class KUpsample2D(nn.Module):
...@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module): ...@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
self.pad = kernel_1d.shape[1] // 2 - 1 self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x): def forward(self, inputs):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device) indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel weight[indices, indices] = kernel
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
class ResnetBlock2D(nn.Module): class ResnetBlock2D(nn.Module):
...@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module): ...@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
self.group_norm = nn.GroupNorm(n_groups, out_channels) self.group_norm = nn.GroupNorm(n_groups, out_channels)
self.mish = nn.Mish() self.mish = nn.Mish()
def forward(self, x): def forward(self, inputs):
x = self.conv1d(x) intermediate_repr = self.conv1d(inputs)
x = rearrange_dims(x) intermediate_repr = rearrange_dims(intermediate_repr)
x = self.group_norm(x) intermediate_repr = self.group_norm(intermediate_repr)
x = rearrange_dims(x) intermediate_repr = rearrange_dims(intermediate_repr)
x = self.mish(x) output = self.mish(intermediate_repr)
return x return output
# unet_rl.py # unet_rl.py
...@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module): ...@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
) )
def forward(self, x, t): def forward(self, inputs, t):
""" """
Args: Args:
x : [ batch_size x inp_channels x horizon ] inputs : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ] t : [ batch_size x embed_dim ]
returns: returns:
...@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module): ...@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
""" """
t = self.time_emb_act(t) t = self.time_emb_act(t)
t = self.time_emb(t) t = self.time_emb(t)
out = self.conv_in(x) + rearrange_dims(t) out = self.conv_in(inputs) + rearrange_dims(t)
out = self.conv_out(out) out = self.conv_out(out)
return out + self.residual_conv(x) return out + self.residual_conv(inputs)
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment