renzhc / diffusers_dcu

Commit a7b0047e
Authored Jul 01, 2022 by Patrick von Platen

    some clean up

Parent: dcb9070b
Changes: 1
Showing 1 changed file with 1 addition and 7 deletions: src/diffusers/models/resnet.py (+1, -7)
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -175,6 +175,7 @@ class Downsample(nn.Module):
 
 
 # unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
+# => All 2D-Resnets are included here now!
 class ResnetBlock(nn.Module):
     def __init__(
         self,
@@ -317,9 +318,6 @@ class ResnetBlock(nn.Module):
         num_groups = min(in_ch // 4, 32)
         num_groups_out = min(out_ch // 4, 32)
         temb_dim = temb_channels
-        # output_scale_factor = np.sqrt(2.0)
-        # non_linearity = "silu"
-        # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
 
         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
         self.up = up
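
Note on the `num_groups` context lines above: `min(in_ch // 4, 32)` caps GroupNorm at 32 groups (the usual DDPM/score-model default) and falls back to `in_ch // 4` for narrow layers, so every group keeps at least four channels. A minimal, self-contained sketch of the same heuristic, assuming a channel count divisible by 4; the helper name `make_group_norm` and the eps value are illustrative, not from this commit:

    import torch
    from torch import nn

    def make_group_norm(num_channels: int, eps: float = 1e-6) -> nn.GroupNorm:
        # Cap at 32 groups, but use fewer for narrow layers so that
        # each group spans at least 4 channels.
        num_groups = min(num_channels // 4, 32)
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)

    norm = make_group_norm(64)   # -> 16 groups of 4 channels
    x = torch.randn(1, 64, 8, 8)
    print(norm(x).shape)         # torch.Size([1, 64, 8, 8])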
@@ -337,13 +335,9 @@ class ResnetBlock(nn.Module):
         # 1x1 convolution with DDPM initialization.
         self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
 
-        # self.skip_rescale = skip_rescale
         self.in_ch = in_ch
         self.out_ch = out_ch
 
-        # TODO(Patrick) - move to main init
-        self.is_overwritten = False
-
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
         self.conv1.bias.data = self.block1.block[0].bias.data
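
The `set_weights_grad_tts` context lines show the weight-porting pattern used in this file: parameters are overwritten in place through `.data`, presumably so that checkpoints trained with the original grad-tts layers can be loaded into the unified ResnetBlock. A minimal sketch of that pattern, assuming two identically shaped Conv2d layers; the `src`/`dst` names are illustrative, not from the commit:

    import torch
    from torch import nn

    src = nn.Conv2d(32, 32, kernel_size=3, padding=1)
    dst = nn.Conv2d(32, 32, kernel_size=3, padding=1)

    # Overwrite dst's parameters in place, mirroring
    # `self.conv1.weight.data = self.block1.block[0].weight.data`.
    dst.weight.data = src.weight.data
    dst.bias.data = src.bias.data

    x = torch.randn(1, 32, 16, 16)
    assert torch.allclose(dst(x), src(x))  # both layers now compute the same map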