Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9d2fc6b5
Commit
9d2fc6b5
authored
Jul 15, 2022
by
Patrick von Platen
Browse files
some fixes
parent
3f1e9592
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
21 deletions
+12
-21
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+5
-14
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+7
-7
No files found.
src/diffusers/models/resnet.py
View file @
9d2fc6b5
...
@@ -711,9 +711,9 @@ class ResnetBlock(nn.Module):
...
@@ -711,9 +711,9 @@ class ResnetBlock(nn.Module):
def
forward
(
self
,
x
,
temb
):
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
x
if
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
upsample
is
not
None
:
if
self
.
upsample
is
not
None
:
x
=
self
.
upsample
(
x
)
x
=
self
.
upsample
(
x
)
...
@@ -724,10 +724,6 @@ class ResnetBlock(nn.Module):
...
@@ -724,10 +724,6 @@ class ResnetBlock(nn.Module):
h
=
self
.
conv1
(
h
)
h
=
self
.
conv1
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm1
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
temb
is
not
None
:
if
temb
is
not
None
:
temb
=
self
.
time_emb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
temb
=
self
.
time_emb_proj
(
self
.
nonlinearity
(
temb
))[:,
:,
None
,
None
]
else
:
else
:
...
@@ -741,17 +737,12 @@ class ResnetBlock(nn.Module):
...
@@ -741,17 +737,12 @@ class ResnetBlock(nn.Module):
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
elif
self
.
time_embedding_norm
==
"default"
:
elif
self
.
time_embedding_norm
==
"default"
:
h
=
h
+
temb
h
=
h
+
temb
if
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
h
=
self
.
conv2
(
h
)
if
not
self
.
pre_norm
:
h
=
self
.
norm2
(
h
)
h
=
self
.
nonlinearity
(
h
)
if
self
.
conv_shortcut
is
not
None
:
if
self
.
conv_shortcut
is
not
None
:
x
=
self
.
conv_shortcut
(
x
)
x
=
self
.
conv_shortcut
(
x
)
...
...
tests/test_modeling_utils.py
View file @
9d2fc6b5
...
@@ -281,7 +281,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -281,7 +281,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
"fusing/ddpm_dummy"
,
output_loading_info
=
True
,
ddpm
=
True
"fusing/ddpm_dummy"
,
output_loading_info
=
True
,
ddpm
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)[
"sample"
]
image
=
model
(
**
self
.
dummy_input
)[
"sample"
]
...
@@ -370,7 +370,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -370,7 +370,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
...
@@ -462,7 +462,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -462,7 +462,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
...
@@ -538,7 +538,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -538,7 +538,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
,
ldm
=
True
"fusing/unet-ldm-dummy"
,
output_loading_info
=
True
,
ldm
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)[
"sample"
]
image
=
model
(
**
self
.
dummy_input
)[
"sample"
]
...
@@ -630,7 +630,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -630,7 +630,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
NCSNpp
.
from_pretrained
(
"fusing/cifar10-ncsnpp-ve"
,
output_loading_info
=
True
)
model
,
loading_info
=
NCSNpp
.
from_pretrained
(
"fusing/cifar10-ncsnpp-ve"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
...
@@ -765,7 +765,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -765,7 +765,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
...
@@ -836,7 +836,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
...
@@ -836,7 +836,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
,
output_loading_info
=
True
)
model
,
loading_info
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
#
self.assertEqual(len(loading_info["missing_keys"]), 0)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
image
=
model
(
**
self
.
dummy_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment