"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "89e8d27e4f749a6fe620798d45709bb050ed7abc"
Commit a7b0047e authored by Patrick von Platen's avatar Patrick von Platen
Browse files

some clean up

parent dcb9070b
......@@ -175,6 +175,7 @@ class Downsample(nn.Module):
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# => All 2D-Resnets are included here now!
class ResnetBlock(nn.Module):
def __init__(
self,
......@@ -317,9 +318,6 @@ class ResnetBlock(nn.Module):
num_groups = min(in_ch // 4, 32)
num_groups_out = min(out_ch // 4, 32)
temb_dim = temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
......@@ -337,13 +335,9 @@ class ResnetBlock(nn.Module):
# 1x1 convolution with DDPM initialization.
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
# self.skip_rescale = skip_rescale
self.in_ch = in_ch
self.out_ch = out_ch
# TODO(Patrick) - move to main init
self.is_overwritten = False
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment