"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "fedba24a63b047239d3d9616052b912661894afb"
Unverified Commit 7265dd8c authored by NIKHIL A V, committed by GitHub

renamed x to meaningful variable in resnet.py (#677)



* renamed single letter variables

* renamed x to meaningful variable in resnet.py

Hello @patil-suraj, can you verify it?
Thanks

* Reformatted using black

* reformatted the files

* fixed an UnboundLocalError at line 374 (see the sketch after this message)

* removed a variable-referenced-before-assignment error

* renamed single-letter variables: x -> hidden_states, p -> pad_value
Co-authored-by: Nikhil A V <nikhilav@Nikhils-MacBook-Pro.local>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
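
For context on the referenced-before-assignment fix mentioned above: it is the usual Python pitfall where a function returns a local name that is only bound on one branch. A minimal sketch of the failure mode and of the pattern this diff adopts (binding the result on every branch before the single return); the function names here are hypothetical and are not part of diffusers:

# Hypothetical illustration of the UnboundLocalError pattern; not the diffusers code.
def broken(use_conv: bool):
    if use_conv:
        output = "conv path"
    # If use_conv is False, `output` was never assigned,
    # so this return raises UnboundLocalError.
    return output


def fixed(use_conv: bool):
    if use_conv:
        output = "conv path"
    else:
        output = "plain path"  # the result is bound on every branch
    return output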
parent 14b97549
@@ -112,7 +112,7 @@ class FirUpsample2D(nn.Module):
         self.fir_kernel = fir_kernel
         self.out_channels = out_channels

-    def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `upsample_2d()` followed by `Conv2d()`.

         Args:
@@ -151,34 +151,46 @@ class FirUpsample2D(nn.Module):
             convW = weight.shape[3]
             inC = weight.shape[1]

-            p = (kernel.shape[0] - factor) - (convW - 1)
+            pad_value = (kernel.shape[0] - factor) - (convW - 1)
             stride = (factor, factor)

             # Determine data dimensions.
-            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+            output_shape = (
+                (hidden_states.shape[2] - 1) * factor + convH,
+                (hidden_states.shape[3] - 1) * factor + convW,
+            )
             output_padding = (
-                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
-                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
             )
             assert output_padding[0] >= 0 and output_padding[1] >= 0
             inC = weight.shape[1]
-            num_groups = x.shape[1] // inC
+            num_groups = hidden_states.shape[1] // inC

             # Transpose weights.
             weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
             weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

-            x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+            inverse_conv = F.conv_transpose2d(
+                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+            )

-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+            output = upfirdn2d_native(
+                inverse_conv,
+                torch.tensor(kernel, device=inverse_conv.device),
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+            )
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(
-                x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+            pad_value = kernel.shape[0] - factor
+            output = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                up=factor,
+                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
             )

-        return x
+        return output

     def forward(self, hidden_states):
         if self.use_conv:
@@ -200,7 +212,7 @@ class FirDownsample2D(nn.Module):
         self.use_conv = use_conv
         self.out_channels = out_channels

-    def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
         """Fused `Conv2d()` followed by `downsample_2d()`.

         Args:
@@ -232,20 +244,29 @@ class FirDownsample2D(nn.Module):
         if self.use_conv:
             _, _, convH, convW = weight.shape
-            p = (kernel.shape[0] - factor) + (convW - 1)
-            s = [factor, factor]
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
-            x = F.conv2d(x, weight, stride=s, padding=0)
+            pad_value = (kernel.shape[0] - factor) + (convW - 1)
+            stride_value = [factor, factor]
+            upfirdn_input = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )
+            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
         else:
-            p = kernel.shape[0] - factor
-            x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+            pad_value = kernel.shape[0] - factor
+            hidden_states = upfirdn2d_native(
+                hidden_states,
+                torch.tensor(kernel, device=hidden_states.device),
+                down=factor,
+                pad=((pad_value + 1) // 2, pad_value // 2),
+            )

-        return x
+        return hidden_states

     def forward(self, hidden_states):
         if self.use_conv:
-            hidden_states = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
-            hidden_states = hidden_states + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
@@ -332,17 +353,17 @@ class ResnetBlock2D(nn.Module):
         if self.use_in_shortcut:
             self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

-    def forward(self, x, temb):
-        hidden_states = x
+    def forward(self, input_tensor, temb):
+        hidden_states = input_tensor

         hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)

         if self.upsample is not None:
-            x = self.upsample(x)
+            input_tensor = self.upsample(input_tensor)
             hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
-            x = self.downsample(x)
+            input_tensor = self.downsample(input_tensor)
             hidden_states = self.downsample(hidden_states)

         hidden_states = self.conv1(hidden_states)
@@ -358,19 +379,19 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv2(hidden_states)

         if self.conv_shortcut is not None:
-            x = self.conv_shortcut(x)
+            input_tensor = self.conv_shortcut(input_tensor)

-        out = (x + hidden_states) / self.output_scale_factor
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

-        return out
+        return output_tensor


 class Mish(torch.nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(torch.nn.functional.softplus(x))
+    def forward(self, hidden_states):
+        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


-def upsample_2d(x, kernel=None, factor=2, gain=1):
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.

     Args:
@@ -397,11 +418,16 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
         kernel /= torch.sum(kernel)

     kernel = kernel * (gain * (factor**2))
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states,
+        kernel.to(device=hidden_states.device),
+        up=factor,
+        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+    )


-def downsample_2d(x, kernel=None, factor=2, gain=1):
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
     r"""Downsample2D a batch of 2D images with the given filter.

     Args:
@@ -429,8 +455,10 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
         kernel /= torch.sum(kernel)

     kernel = kernel * gain
-    p = kernel.shape[0] - factor
-    return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+    pad_value = kernel.shape[0] - factor
+    return upfirdn2d_native(
+        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+    )


 def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -441,6 +469,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
     _, channel, in_h, in_w = input.shape
     input = input.reshape(-1, in_h, in_w, 1)

+    # Rename this variable (input); it shadows a builtin. sonarlint(python:S5806)
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
...
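
For reference, a minimal usage sketch of the renamed module-level helpers, assuming the signatures shown in the hunks above; the import path and the expected output shapes are assumptions inferred from this file and may differ between diffusers versions:

# Sketch only; import path and expected shapes are assumptions based on this diff.
import torch
from diffusers.models.resnet import upsample_2d, downsample_2d

hidden_states = torch.randn(1, 3, 64, 64)             # NCHW batch of images
upsampled = upsample_2d(hidden_states, factor=2)       # expected shape: (1, 3, 128, 128)
downsampled = downsample_2d(hidden_states, factor=2)   # expected shape: (1, 3, 32, 32)
print(upsampled.shape, downsampled.shape)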