Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d37f08da
Unverified
Commit
d37f08da
authored
Oct 28, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 28, 2022
Browse files
[Tests] no random latents anymore (#1045)
parent
c4ef1efe
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
14 deletions
+31
-14
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+1
-0
src/diffusers/utils/testing_utils.py
src/diffusers/utils/testing_utils.py
+17
-1
tests/models/test_models_unet_2d.py
tests/models/test_models_unet_2d.py
+7
-7
tests/models/test_models_vae.py
tests/models/test_models_vae.py
+6
-6
No files found.
src/diffusers/utils/__init__.py
View file @
d37f08da
...
...
@@ -43,6 +43,7 @@ if is_torch_available():
from
.testing_utils
import
(
floats_tensor
,
load_image
,
load_numpy
,
parse_flag_from_env
,
require_torch_gpu
,
slow
,
...
...
src/diffusers/utils/testing_utils.py
View file @
d37f08da
...
...
@@ -4,11 +4,14 @@ import os
import
random
import
re
import
unittest
import
urllib.parse
from
distutils.util
import
strtobool
from
io
import
StringIO
from
io
import
BytesIO
,
StringIO
from
pathlib
import
Path
from
typing
import
Union
import
numpy
as
np
import
PIL.Image
import
PIL.ImageOps
import
requests
...
...
@@ -165,6 +168,19 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
return
image
def load_numpy(path) -> np.ndarray:
    """Load a ``.npy`` array from a URL or a bare fixture name.

    A bare (non-URL) ``path`` is resolved against the ``fusing/diffusers-testing``
    dataset on the Hugging Face Hub before being fetched over HTTP.

    Args:
        path: An ``http(s)://`` URL, or a fixture filename relative to the
            diffusers-testing dataset.

    Returns:
        The deserialized NumPy array.

    Raises:
        requests.HTTPError: If the HTTP download fails.
    """
    # BUG FIX: the original condition was
    #     not path.startswith("http://") or path.startswith("https://")
    # which parses as `(not A) or B`, so every https:// URL was *also*
    # rewritten onto the Hub base URL. Use the tuple form of startswith so
    # only bare relative paths are rewritten.
    if not path.startswith(("http://", "https://")):
        path = os.path.join(
            "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main",
            urllib.parse.quote(path),
        )

    response = requests.get(path)
    response.raise_for_status()
    # np.load wants a file-like object; wrap the downloaded bytes.
    array = np.load(BytesIO(response.content))

    return array
# --- pytest conf functions --- #
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
...
...
tests/models/test_models_unet_2d.py
View file @
d37f08da
...
...
@@ -21,7 +21,7 @@ import unittest
import
torch
from
diffusers
import
UNet2DConditionModel
,
UNet2DModel
from
diffusers.utils
import
floats_tensor
,
require_torch_gpu
,
slow
,
torch_all_close
,
torch_device
from
diffusers.utils
import
floats_tensor
,
load_numpy
,
require_torch_gpu
,
slow
,
torch_all_close
,
torch_device
from
parameterized
import
parameterized
from
..test_modeling_common
import
ModelTesterMixin
...
...
@@ -411,6 +411,9 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@
slow
class
UNet2DConditionModelIntegrationTests
(
unittest
.
TestCase
):
def get_file_format(self, seed, shape):
    """Return the filename of the pre-generated Gaussian-noise fixture
    for the given ``seed`` and tensor ``shape``."""
    dims = "_".join(str(dim) for dim in shape)
    return f"gaussian_noise_s={seed}_shape={dims}.npy"
def
tearDown
(
self
):
# clean up the VRAM after each test
super
().
tearDown
()
...
...
@@ -418,11 +421,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
torch
.
cuda
.
empty_cache
()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
    """Load deterministic latents for an integration test.

    Args:
        seed: Seed used to select the pre-generated noise fixture.
        shape: Tensor shape, also part of the fixture filename.
        fp16: If True, cast the loaded tensor to ``torch.float16``.

    Returns:
        A tensor on ``torch_device`` loaded from a pre-generated ``.npy``
        fixture, so results are reproducible across devices.
    """
    # BUG FIX: the block contained dead code — a `torch.Generator` +
    # `torch.randn` whose result was immediately overwritten by the
    # load_numpy assignment (diff residue). Sampling randn on-device is
    # not reproducible across devices/driver versions, so only the
    # deterministic fixture path is kept.
    dtype = torch.float16 if fp16 else torch.float32
    image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
    return image
def
get_unet_model
(
self
,
fp16
=
False
,
model_id
=
"CompVis/stable-diffusion-v1-4"
):
...
...
@@ -437,9 +437,9 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
return
model
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
    """Load deterministic encoder hidden states for an integration test.

    Args:
        seed: Seed used to select the pre-generated noise fixture.
        shape: Tensor shape, also part of the fixture filename.
        fp16: If True, cast the loaded tensor to ``torch.float16``.

    Returns:
        A tensor on ``torch_device`` loaded from a pre-generated ``.npy``
        fixture, so results are reproducible across devices.
    """
    # BUG FIX: the block had unreachable statements after
    # `return torch.randn(...)` (diff residue). Per the fixture mechanism
    # established by get_file_format/load_numpy — and matching the sibling
    # get_latents — the deterministic load path is the live one; the
    # generator/randn code is removed.
    dtype = torch.float16 if fp16 else torch.float32
    hidden_states = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
    return hidden_states
@
parameterized
.
expand
(
[
...
...
tests/models/test_models_vae.py
View file @
d37f08da
...
...
@@ -20,7 +20,7 @@ import torch
from
diffusers
import
AutoencoderKL
from
diffusers.modeling_utils
import
ModelMixin
from
diffusers.utils
import
floats_tensor
,
require_torch_gpu
,
slow
,
torch_all_close
,
torch_device
from
diffusers.utils
import
floats_tensor
,
load_numpy
,
require_torch_gpu
,
slow
,
torch_all_close
,
torch_device
from
parameterized
import
parameterized
from
..test_modeling_common
import
ModelTesterMixin
...
...
@@ -136,18 +136,18 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
@
slow
class
AutoencoderKLIntegrationTests
(
unittest
.
TestCase
):
def get_file_format(self, seed, shape):
    """Name of the pre-generated Gaussian-noise ``.npy`` fixture for
    a given ``seed`` and tensor ``shape``."""
    shape_tag = "_".join(map(str, shape))
    return f"gaussian_noise_s={seed}_shape={shape_tag}.npy"
def tearDown(self):
    # clean up the VRAM after each test
    super().tearDown()
    # Drop Python-side references first so the subsequent CUDA cache flush
    # can actually release the memory those tensors held.
    gc.collect()
    torch.cuda.empty_cache()
def get_sd_image(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
    """Load a deterministic Stable-Diffusion test image/latent tensor.

    Args:
        seed: Seed used to select the pre-generated noise fixture.
        shape: Tensor shape, also part of the fixture filename.
        fp16: If True, cast the loaded tensor to ``torch.float16``.

    Returns:
        A tensor on ``torch_device`` loaded from a pre-generated ``.npy``
        fixture, so results are reproducible across devices.
    """
    # BUG FIX: the block contained two consecutive `def get_sd_image`
    # definitions (the second shadowing the first) plus dead
    # generator/randn code whose result was immediately overwritten —
    # diff residue. Only the single deterministic fixture-loading
    # definition is kept; the shadowed randn path is removed.
    dtype = torch.float16 if fp16 else torch.float32
    image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
    return image
def
get_sd_vae_model
(
self
,
model_id
=
"CompVis/stable-diffusion-v1-4"
,
fp16
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment