Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
61dc6574
Commit
61dc6574
authored
Jul 01, 2022
by
Patrick von Platen
Browse files
more fixes
parent
f1aade05
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
13 deletions
+98
-13
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+98
-13
No files found.
src/diffusers/models/resnet.py
View file @
61dc6574
from
abc
import
abstractmethod
from
abc
import
abstractmethod
import
functools
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -374,15 +375,20 @@ class ResnetBlock(nn.Module):
...
@@ -374,15 +375,20 @@ class ResnetBlock(nn.Module):
dropout
=
0.0
,
dropout
=
0.0
,
temb_channels
=
512
,
temb_channels
=
512
,
groups
=
32
,
groups
=
32
,
groups_out
=
None
,
pre_norm
=
True
,
pre_norm
=
True
,
eps
=
1e-6
,
eps
=
1e-6
,
non_linearity
=
"swish"
,
non_linearity
=
"swish"
,
time_embedding_norm
=
"default"
,
time_embedding_norm
=
"default"
,
fir_kernel
=
(
1
,
3
,
3
,
1
),
output_scale_factor
=
1.0
,
use_nin_shortcut
=
None
,
up
=
False
,
up
=
False
,
down
=
False
,
down
=
False
,
overwrite_for_grad_tts
=
False
,
overwrite_for_grad_tts
=
False
,
overwrite_for_ldm
=
False
,
overwrite_for_ldm
=
False
,
overwrite_for_glide
=
False
,
overwrite_for_glide
=
False
,
overwrite_for_score_vde
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
pre_norm
=
pre_norm
self
.
pre_norm
=
pre_norm
...
@@ -393,6 +399,13 @@ class ResnetBlock(nn.Module):
...
@@ -393,6 +399,13 @@ class ResnetBlock(nn.Module):
self
.
time_embedding_norm
=
time_embedding_norm
self
.
time_embedding_norm
=
time_embedding_norm
self
.
up
=
up
self
.
up
=
up
self
.
down
=
down
self
.
down
=
down
self
.
output_scale_factor
=
output_scale_factor
if
groups_out
is
None
:
groups_out
=
groups
if
use_nin_shortcut
is
None
:
use_nin_shortcut
=
self
.
in_channels
!=
self
.
out_channels
if
self
.
pre_norm
:
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
...
@@ -406,7 +419,7 @@ class ResnetBlock(nn.Module):
...
@@ -406,7 +419,7 @@ class ResnetBlock(nn.Module):
elif
time_embedding_norm
==
"scale_shift"
:
elif
time_embedding_norm
==
"scale_shift"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
_out
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
@@ -417,14 +430,17 @@ class ResnetBlock(nn.Module):
...
@@ -417,14 +430,17 @@ class ResnetBlock(nn.Module):
elif
non_linearity
==
"silu"
:
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
self
.
nonlinearity
=
nn
.
SiLU
()
if
up
:
# if up:
self
.
h_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
self
.
x_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
elif
down
:
# elif down:
self
.
h_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self
.
upsample
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
if
self
.
up
else
None
if
self
.
in_channels
!=
self
.
out_channels
:
self
.
downsample
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
down
else
None
self
.
nin_shortcut
=
None
if
use_nin_shortcut
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
...
@@ -432,6 +448,7 @@ class ResnetBlock(nn.Module):
...
@@ -432,6 +448,7 @@ class ResnetBlock(nn.Module):
self
.
overwrite_for_glide
=
overwrite_for_glide
self
.
overwrite_for_glide
=
overwrite_for_glide
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
self
.
overwrite_for_ldm
=
overwrite_for_ldm
or
overwrite_for_glide
self
.
overwrite_for_ldm
=
overwrite_for_ldm
or
overwrite_for_glide
self
.
overwrite_for_score_vde
=
overwrite_for_score_vde
if
self
.
overwrite_for_grad_tts
:
if
self
.
overwrite_for_grad_tts
:
dim
=
in_channels
dim
=
in_channels
dim_out
=
out_channels
dim_out
=
out_channels
...
@@ -450,6 +467,7 @@ class ResnetBlock(nn.Module):
...
@@ -450,6 +467,7 @@ class ResnetBlock(nn.Module):
channels
=
in_channels
channels
=
in_channels
emb_channels
=
temb_channels
emb_channels
=
temb_channels
use_scale_shift_norm
=
False
use_scale_shift_norm
=
False
non_linearity
=
"silu"
self
.
in_layers
=
nn
.
Sequential
(
self
.
in_layers
=
nn
.
Sequential
(
normalization
(
channels
,
swish
=
1.0
),
normalization
(
channels
,
swish
=
1.0
),
...
@@ -473,6 +491,45 @@ class ResnetBlock(nn.Module):
...
@@ -473,6 +491,45 @@ class ResnetBlock(nn.Module):
self
.
skip_connection
=
nn
.
Identity
()
self
.
skip_connection
=
nn
.
Identity
()
else
:
else
:
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
self
.
skip_connection
=
conv_nd
(
dims
,
channels
,
self
.
out_channels
,
1
)
elif
self
.
overwrite_for_score_vde
:
in_ch
=
in_channels
out_ch
=
out_channels
eps
=
1e-6
num_groups
=
min
(
in_ch
//
4
,
32
)
num_groups_out
=
min
(
out_ch
//
4
,
32
)
temb_dim
=
temb_channels
# output_scale_factor = np.sqrt(2.0)
# non_linearity = "silu"
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
self
.
GroupNorm_0
=
nn
.
GroupNorm
(
num_groups
=
num_groups
,
num_channels
=
in_ch
,
eps
=
eps
)
self
.
up
=
up
self
.
down
=
down
self
.
fir_kernel
=
fir_kernel
self
.
Conv_0
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
3
,
padding
=
1
)
if
temb_dim
is
not
None
:
self
.
Dense_0
=
nn
.
Linear
(
temb_dim
,
out_ch
)
self
.
Dense_0
.
weight
.
data
=
variance_scaling
()(
self
.
Dense_0
.
weight
.
shape
)
nn
.
init
.
zeros_
(
self
.
Dense_0
.
bias
)
self
.
GroupNorm_1
=
nn
.
GroupNorm
(
num_groups
=
num_groups_out
,
num_channels
=
out_ch
,
eps
=
eps
)
self
.
Dropout_0
=
nn
.
Dropout
(
dropout
)
self
.
Conv_1
=
conv2d
(
out_ch
,
out_ch
,
init_scale
=
0.0
,
kernel_size
=
3
,
padding
=
1
)
if
in_ch
!=
out_ch
or
up
or
down
:
# 1x1 convolution with DDPM initialization.
self
.
Conv_2
=
conv2d
(
in_ch
,
out_ch
,
kernel_size
=
1
,
padding
=
0
)
# self.skip_rescale = skip_rescale
self
.
in_ch
=
in_ch
self
.
out_ch
=
out_ch
# TODO(Patrick) - move to main init
self
.
upsample
=
functools
.
partial
(
upsample_2d
,
k
=
self
.
fir_kernel
)
self
.
downsample
=
functools
.
partial
(
downsample_2d
,
k
=
self
.
fir_kernel
)
self
.
is_overwritten
=
False
def
set_weights_grad_tts
(
self
):
def
set_weights_grad_tts
(
self
):
self
.
conv1
.
weight
.
data
=
self
.
block1
.
block
[
0
].
weight
.
data
self
.
conv1
.
weight
.
data
=
self
.
block1
.
block
[
0
].
weight
.
data
...
@@ -512,6 +569,24 @@ class ResnetBlock(nn.Module):
...
@@ -512,6 +569,24 @@ class ResnetBlock(nn.Module):
self
.
nin_shortcut
.
weight
.
data
=
self
.
skip_connection
.
weight
.
data
self
.
nin_shortcut
.
weight
.
data
=
self
.
skip_connection
.
weight
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
skip_connection
.
bias
.
data
self
.
nin_shortcut
.
bias
.
data
=
self
.
skip_connection
.
bias
.
data
def set_weights_score_vde(self):
    """Point the unified layers at the legacy score-VDE parameters.

    The ``weight`` and ``bias`` tensors of the legacy layers (``Conv_0``,
    ``GroupNorm_0``, ``Conv_1``, ``GroupNorm_1``, ``Dense_0``) are assigned
    to ``conv1``, ``norm1``, ``conv2``, ``norm2`` and ``temb_proj``.  The
    assignment rebinds ``.data``; no tensor is cloned.

    When the block changes the channel count or resamples (``up`` / ``down``),
    the legacy 1x1 shortcut ``Conv_2`` is additionally mapped onto
    ``nin_shortcut``.

    Raises:
        ValueError: if a shortcut has to be transferred but ``nin_shortcut``
            is ``None``.  This happens when ``up`` or ``down`` is set while
            ``in_channels == out_channels`` and the block was built without
            ``use_nin_shortcut=True``.
    """
    needs_shortcut = self.in_channels != self.out_channels or self.up or self.down

    # Validate before touching any layer so a misconfigured block is not left
    # half-converted, and so the failure names the actual cause instead of
    # surfacing as "'NoneType' object has no attribute 'weight'".
    if needs_shortcut and self.nin_shortcut is None:
        raise ValueError(
            "set_weights_score_vde needs `nin_shortcut` to receive the weights of `Conv_2`, "
            "but it is None. Construct the block with `use_nin_shortcut=True` when `up` or "
            "`down` is set and `in_channels == out_channels`."
        )

    # (destination in the unified layout, source in the legacy layout)
    transfers = [
        (self.conv1, self.Conv_0),
        (self.norm1, self.GroupNorm_0),
        (self.conv2, self.Conv_1),
        (self.norm2, self.GroupNorm_1),
        (self.temb_proj, self.Dense_0),
    ]
    if needs_shortcut:
        transfers.append((self.nin_shortcut, self.Conv_2))

    for target, legacy in transfers:
        target.weight.data = legacy.weight.data
        target.bias.data = legacy.bias.data
def
forward
(
self
,
x
,
temb
,
mask
=
1.0
):
def
forward
(
self
,
x
,
temb
,
mask
=
1.0
):
# TODO(Patrick) eventually this class should be split into multiple classes
# TODO(Patrick) eventually this class should be split into multiple classes
# too many if else statements
# too many if else statements
...
@@ -521,6 +596,9 @@ class ResnetBlock(nn.Module):
...
@@ -521,6 +596,9 @@ class ResnetBlock(nn.Module):
elif
self
.
overwrite_for_ldm
and
not
self
.
is_overwritten
:
elif
self
.
overwrite_for_ldm
and
not
self
.
is_overwritten
:
self
.
set_weights_ldm
()
self
.
set_weights_ldm
()
self
.
is_overwritten
=
True
self
.
is_overwritten
=
True
elif
self
.
overwrite_for_score_vde
and
not
self
.
is_overwritten
:
self
.
set_weights_score_vde
()
self
.
is_overwritten
=
True
h
=
x
h
=
x
h
=
h
*
mask
h
=
h
*
mask
...
@@ -528,10 +606,17 @@ class ResnetBlock(nn.Module):
...
@@ -528,10 +606,17 @@ class ResnetBlock(nn.Module):
h
=
self
.
norm1
(
h
)
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
up
or
self
.
down
:
if
self
.
upsample
is
not
None
:
x
=
self
.
x_upd
(
x
)
x
=
self
.
upsample
(
x
)
h
=
self
.
h_upd
(
h
)
h
=
self
.
upsample
(
h
)
elif
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
h
=
self
.
downsample
(
h
)
# if self.up: or self.down:
# x = self.x_upd(x)
# h = self.h_upd(h)
#
h
=
self
.
conv1
(
h
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
if
not
self
.
pre_norm
:
...
@@ -563,7 +648,7 @@ class ResnetBlock(nn.Module):
...
@@ -563,7 +648,7 @@ class ResnetBlock(nn.Module):
h
=
h
*
mask
h
=
h
*
mask
x
=
x
*
mask
x
=
x
*
mask
if
self
.
in_
channels
!=
self
.
out_channels
:
if
self
.
n
in_
shortcut
is
not
None
:
x
=
self
.
nin_shortcut
(
x
)
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
return
x
+
h
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment