"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "843355f89fd043e82b3344d9259e6faa640da6f9"
Commit 13ac40ed authored by patil-suraj

style

parent ebe68343
@@ -603,7 +603,7 @@ class ResnetBlockBigGANpp(nn.Module):
         self.Dropout_0 = nn.Dropout(dropout)
         self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            #1x1 convolution with DDPM initialization.
+            # 1x1 convolution with DDPM initialization.
             self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
         self.skip_rescale = skip_rescale
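For context: the 1x1 Conv_2 created here projects the skip branch whenever the channel count changes or the block resamples, so the skip and residual paths can be summed. A minimal sketch of the combine step this sets up, assuming the BigGAN++ convention of rescaling the sum by 1/sqrt(2) when skip_rescale is set (illustrative only, not the exact forward of this class):

import numpy as np
import torch


def combine_skip(x: torch.Tensor, h: torch.Tensor, skip_rescale: bool = True) -> torch.Tensor:
    # x: the (possibly 1x1-projected) skip branch; h: the residual branch output of Conv_1.
    # Dividing by sqrt(2) keeps the variance of the sum roughly unit.
    out = x + h
    return out / np.sqrt(2.0) if skip_rescale else out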
@@ -757,9 +757,7 @@ class RearrangeDim(nn.Module):
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
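The variance_scaling initializer called here is not shown in this diff. A minimal sketch consistent with the call site variance_scaling(init_scale)(conv.weight.data.shape), modeled on the score_sde-style initializer (the fan_avg mode, uniform distribution, and zero-scale guard are assumptions, not taken from this commit):

import numpy as np
import torch


def variance_scaling(scale, mode="fan_avg", distribution="uniform"):
    """Return an init function producing weights with variance scale / fan."""
    # Guard against scale == 0 so the variance never collapses to exactly zero.
    scale = 1e-10 if scale == 0 else scale

    def init(shape):
        # For a conv weight (out_ch, in_ch, kH, kW), the receptive field is kH * kW.
        receptive_field_size = np.prod(shape) / (shape[0] * shape[1])
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        else:  # "fan_avg"
            denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape) * np.sqrt(variance)
        # Uniform on [-a, a] has variance a^2 / 3, hence the sqrt(3 * variance) bound.
        return (torch.rand(*shape) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init

With this in scope, conv2d(64, 64, init_scale=0.0) reproduces the near-zero initialization DDPM uses for the final convolution of a residual branch.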
@@ -289,9 +289,7 @@ def downsample_2d(x, k=None, factor=2, gain=1):
 def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
     """nXn convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-    )
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
     conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
@@ -336,7 +334,7 @@ class Combine(nn.Module):
     def __init__(self, dim1, dim2, method="cat"):
         super().__init__()
-        #1x1 convolution with DDPM initialization.
+        # 1x1 convolution with DDPM initialization.
         self.Conv_0 = conv2d(dim1, dim2, kernel_size=1, padding=0)
         self.method = method
@@ -602,7 +600,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
             else:
                 if progressive == "output_skip":
                     modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
-                    modules.append(conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1))
+                    modules.append(
+                        conv2d(in_ch, channels, bias=True, init_scale=init_scale, kernel_size=3, padding=1)
+                    )
                     pyramid_ch = channels
                 elif progressive == "residual":
                     modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
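The num_groups=min(in_ch // 4, 32) heuristic in this hunk keeps at least four channels per group while capping the group count at 32, and divides num_channels evenly for the power-of-two channel widths typical here. A quick check with hypothetical channel counts:

import torch.nn as nn

for in_ch in (8, 64, 256):
    norm = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
    print(in_ch, norm.num_groups)  # 8 -> 2, 64 -> 16, 256 -> 32 (capped)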