Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
db934c67
Commit
db934c67
authored
Jun 30, 2022
by
Patrick von Platen
Browse files
fix more tests
parent
185347e4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
21 deletions
+69
-21
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+68
-20
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+1
-1
No files found.
src/diffusers/models/resnet.py
View file @
db934c67
import
string
from
abc
import
abstractmethod
from
abc
import
abstractmethod
import
numpy
as
np
import
numpy
as
np
...
@@ -188,7 +187,7 @@ class ResBlock(TimestepBlock):
...
@@ -188,7 +187,7 @@ class ResBlock(TimestepBlock):
use_checkpoint
=
False
,
use_checkpoint
=
False
,
up
=
False
,
up
=
False
,
down
=
False
,
down
=
False
,
overwrite
=
Fals
e
,
# TODO(Patrick) - use for glide at later stage
overwrite
=
Tru
e
,
# TODO(Patrick) - use for glide at later stage
):
):
super
().
__init__
()
super
().
__init__
()
self
.
channels
=
channels
self
.
channels
=
channels
...
@@ -220,12 +219,10 @@ class ResBlock(TimestepBlock):
...
@@ -220,12 +219,10 @@ class ResBlock(TimestepBlock):
nn
.
SiLU
(),
nn
.
SiLU
(),
linear
(
linear
(
emb_channels
,
emb_channels
,
2
*
self
.
out_channels
if
use_scale_shift_norm
else
self
.
out_channels
,
2
*
self
.
out_channels
,
),
),
)
)
self
.
out_layers
=
nn
.
Sequential
(
self
.
out_layers
=
nn
.
Sequential
(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
normalization
(
self
.
out_channels
,
swish
=
0.0
),
normalization
(
self
.
out_channels
,
swish
=
0.0
),
nn
.
SiLU
(),
nn
.
SiLU
(),
nn
.
Dropout
(
p
=
dropout
),
nn
.
Dropout
(
p
=
dropout
),
...
@@ -257,13 +254,16 @@ class ResBlock(TimestepBlock):
...
@@ -257,13 +254,16 @@ class ResBlock(TimestepBlock):
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
use_conv_shortcut
=
conv_shortcut
# Add to init
self
.
time_embedding_norm
=
"scale_shift"
if
self
.
pre_norm
:
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
else
:
else
:
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
@@ -277,6 +277,14 @@ class ResBlock(TimestepBlock):
...
@@ -277,6 +277,14 @@ class ResBlock(TimestepBlock):
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
in_channels
!=
self
.
out_channels
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
up
,
self
.
down
=
up
,
down
# if self.up:
# self.h_upd = Upsample(in_channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(in_channels, use_conv=False, dims=dims)
# elif self.down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
def
set_weights
(
self
):
def
set_weights
(
self
):
# TODO(Patrick): use for glide at later stage
# TODO(Patrick): use for glide at later stage
self
.
norm1
.
weight
.
data
=
self
.
in_layers
[
0
].
weight
.
data
self
.
norm1
.
weight
.
data
=
self
.
in_layers
[
0
].
weight
.
data
...
@@ -309,6 +317,7 @@ class ResBlock(TimestepBlock):
...
@@ -309,6 +317,7 @@ class ResBlock(TimestepBlock):
# TODO(Patrick): use for glide at later stage
# TODO(Patrick): use for glide at later stage
self
.
set_weights
()
self
.
set_weights
()
orig_x
=
x
if
self
.
updown
:
if
self
.
updown
:
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
in_rest
,
in_conv
=
self
.
in_layers
[:
-
1
],
self
.
in_layers
[
-
1
]
h
=
in_rest
(
x
)
h
=
in_rest
(
x
)
...
@@ -334,8 +343,7 @@ class ResBlock(TimestepBlock):
...
@@ -334,8 +343,7 @@ class ResBlock(TimestepBlock):
result
=
self
.
skip_connection
(
x
)
+
h
result
=
self
.
skip_connection
(
x
)
+
h
# TODO(Patrick) Use for glide at later stage
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
result
=
self
.
forward_2
(
orig_x
,
emb
)
return
result
return
result
def
forward_2
(
self
,
x
,
temb
):
def
forward_2
(
self
,
x
,
temb
):
...
@@ -347,18 +355,24 @@ class ResBlock(TimestepBlock):
...
@@ -347,18 +355,24 @@ class ResBlock(TimestepBlock):
h
=
self
.
norm1
(
h
)
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
up
or
self
.
down
:
x
=
self
.
x_upd
(
x
)
h
=
self
.
h_upd
(
h
)
h
=
self
.
conv1
(
h
)
h
=
self
.
conv1
(
h
)
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
self
.
norm2
(
h
)
h
=
h
*
scale
+
shift
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
norm2
(
h
)
h
=
h
+
h
*
scale
+
shift
h
=
self
.
nonlinearity
(
h
)
else
:
h
=
h
+
temb
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
h
=
self
.
conv2
(
h
)
...
@@ -386,8 +400,12 @@ class ResnetBlock(nn.Module):
...
@@ -386,8 +400,12 @@ class ResnetBlock(nn.Module):
pre_norm
=
True
,
pre_norm
=
True
,
eps
=
1e-6
,
eps
=
1e-6
,
non_linearity
=
"swish"
,
non_linearity
=
"swish"
,
time_embedding_norm
=
"default"
,
up
=
False
,
down
=
False
,
overwrite_for_grad_tts
=
False
,
overwrite_for_grad_tts
=
False
,
overwrite_for_ldm
=
False
,
overwrite_for_ldm
=
False
,
overwrite_for_glide
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
pre_norm
=
pre_norm
self
.
pre_norm
=
pre_norm
...
@@ -395,6 +413,9 @@ class ResnetBlock(nn.Module):
...
@@ -395,6 +413,9 @@ class ResnetBlock(nn.Module):
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
use_conv_shortcut
=
conv_shortcut
self
.
time_embedding_norm
=
time_embedding_norm
self
.
up
=
up
self
.
down
=
down
if
self
.
pre_norm
:
if
self
.
pre_norm
:
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
in_channels
,
num_groups
=
groups
,
eps
=
eps
)
...
@@ -402,7 +423,12 @@ class ResnetBlock(nn.Module):
...
@@ -402,7 +423,12 @@ class ResnetBlock(nn.Module):
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm1
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
if
time_embedding_norm
==
"default"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
if
time_embedding_norm
==
"scale_shift"
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
2
*
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
norm2
=
Normalize
(
out_channels
,
num_groups
=
groups
,
eps
=
eps
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
@@ -414,6 +440,13 @@ class ResnetBlock(nn.Module):
...
@@ -414,6 +440,13 @@ class ResnetBlock(nn.Module):
elif
non_linearity
==
"silu"
:
elif
non_linearity
==
"silu"
:
self
.
nonlinearity
=
nn
.
SiLU
()
self
.
nonlinearity
=
nn
.
SiLU
()
if
up
:
self
.
h_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
self
.
x_upd
=
Upsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
)
elif
down
:
self
.
h_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
self
.
x_upd
=
Downsample
(
in_channels
,
use_conv
=
False
,
dims
=
2
,
padding
=
1
,
name
=
"op"
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
if
self
.
use_conv_shortcut
:
# TODO(Patrick) - this branch is never used I think => can be deleted!
# TODO(Patrick) - this branch is never used I think => can be deleted!
...
@@ -422,8 +455,9 @@ class ResnetBlock(nn.Module):
...
@@ -422,8 +455,9 @@ class ResnetBlock(nn.Module):
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
is_overwritten
=
False
self
.
is_overwritten
=
False
self
.
overwrite_for_glide
=
overwrite_for_glide
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
self
.
overwrite_for_grad_tts
=
overwrite_for_grad_tts
self
.
overwrite_for_ldm
=
overwrite_for_ldm
self
.
overwrite_for_ldm
=
overwrite_for_ldm
or
overwrite_for_glide
if
self
.
overwrite_for_grad_tts
:
if
self
.
overwrite_for_grad_tts
:
dim
=
in_channels
dim
=
in_channels
dim_out
=
out_channels
dim_out
=
out_channels
...
@@ -517,12 +551,18 @@ class ResnetBlock(nn.Module):
...
@@ -517,12 +551,18 @@ class ResnetBlock(nn.Module):
self
.
set_weights_ldm
()
self
.
set_weights_ldm
()
self
.
is_overwritten
=
True
self
.
is_overwritten
=
True
if
self
.
up
or
self
.
down
:
x
=
self
.
x_upd
(
x
)
h
=
x
h
=
x
h
=
h
*
mask
h
=
h
*
mask
if
self
.
pre_norm
:
if
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
up
or
self
.
down
:
h
=
self
.
h_upd
(
h
)
h
=
self
.
conv1
(
h
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
if
not
self
.
pre_norm
:
...
@@ -530,12 +570,20 @@ class ResnetBlock(nn.Module):
...
@@ -530,12 +570,20 @@ class ResnetBlock(nn.Module):
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
h
*
mask
h
=
h
*
mask
h
=
h
+
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
temb
=
self
.
temb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
if
self
.
time_embedding_norm
==
"scale_shift"
:
scale
,
shift
=
torch
.
chunk
(
temb
,
2
,
dim
=
1
)
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
norm2
(
h
)
h
=
h
+
h
*
scale
+
shift
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
elif
self
.
time_embedding_norm
==
"default"
:
h
=
h
+
temb
h
=
h
*
mask
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
h
=
self
.
conv2
(
h
)
...
...
tests/test_modeling_utils.py
View file @
db934c67
...
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
2
))
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
GlideSuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment