Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4553c29d
Unverified
Commit
4553c29d
authored
Mar 15, 2023
by
Sayak Paul
Committed by
GitHub
Mar 15, 2023
Browse files
[Tests] fix: slow serialization test (#2678)
fix: slow serialization tests
parent
c9477bf8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
5 deletions
+4
-5
tests/test_ema.py
tests/test_ema.py
+4
-5
No files found.
tests/test_ema.py
View file @
4553c29d
...
@@ -33,11 +33,9 @@ class EMAModelTests(unittest.TestCase):
...
@@ -33,11 +33,9 @@ class EMAModelTests(unittest.TestCase):
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
def
get_models
(
self
,
decay
=
0.9999
):
def
get_models
(
self
,
decay
=
0.9999
):
unet
=
UNet2DConditionModel
.
from_pretrained
(
self
.
model_id
,
subfolder
=
"unet"
,
device
=
torch_device
)
unet
=
UNet2DConditionModel
.
from_pretrained
(
self
.
model_id
,
subfolder
=
"unet"
)
ema_unet
=
UNet2DConditionModel
.
from_pretrained
(
self
.
model_id
,
subfolder
=
"unet"
)
unet
=
unet
.
to
(
torch_device
)
ema_unet
=
EMAModel
(
ema_unet
=
EMAModel
(
unet
.
parameters
(),
decay
=
decay
,
model_cls
=
UNet2DConditionModel
,
model_config
=
unet
.
config
)
ema_unet
.
parameters
(),
decay
=
decay
,
model_cls
=
UNet2DConditionModel
,
model_config
=
ema_unet
.
config
)
return
unet
,
ema_unet
return
unet
,
ema_unet
def
get_dummy_inputs
(
self
):
def
get_dummy_inputs
(
self
):
...
@@ -149,6 +147,7 @@ class EMAModelTests(unittest.TestCase):
...
@@ -149,6 +147,7 @@ class EMAModelTests(unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
ema_unet
.
save_pretrained
(
tmpdir
)
ema_unet
.
save_pretrained
(
tmpdir
)
loaded_unet
=
UNet2DConditionModel
.
from_pretrained
(
tmpdir
,
model_cls
=
UNet2DConditionModel
)
loaded_unet
=
UNet2DConditionModel
.
from_pretrained
(
tmpdir
,
model_cls
=
UNet2DConditionModel
)
loaded_unet
=
loaded_unet
.
to
(
unet
.
device
)
# Since no EMA step has been performed the outputs should match.
# Since no EMA step has been performed the outputs should match.
output
=
unet
(
noisy_latents
,
timesteps
,
encoder_hidden_states
).
sample
output
=
unet
(
noisy_latents
,
timesteps
,
encoder_hidden_states
).
sample
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment