Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b76eea04
Commit
b76eea04
authored
Jun 07, 2022
by
Patrick von Platen
Browse files
check with other device
parent
5da71f8f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
23 deletions
+5
-23
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+5
-23
No files found.
tests/test_modeling_utils.py
View file @
b76eea04
...
...
@@ -14,7 +14,6 @@
# limitations under the License.
import
random
import
tempfile
import
unittest
...
...
@@ -22,7 +21,6 @@ import os
from
distutils.util
import
strtobool
import
torch
import
numpy
as
np
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
...
@@ -31,22 +29,7 @@ from models.vision.ddpm.modeling_ddpm import DDPM
global_rng
=
random
.
Random
()
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
def
get_random_generator
(
seed
):
seed
=
1234
random
.
seed
(
seed
)
os
.
environ
[
"PYTHONHASHSEED"
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
enabled
=
False
generator
=
torch
.
Generator
()
return
generator
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
def
parse_flag_from_env
(
key
,
default
=
False
):
...
...
@@ -132,7 +115,7 @@ class SamplerTesterMixin(unittest.TestCase):
@
slow
def
test_sample
(
self
):
generator
=
get_random_generator
(
0
)
generator
=
torch
.
manual_seed
(
0
)
# 1. Load models
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/ddpm-lsun-church"
)
...
...
@@ -182,13 +165,12 @@ class SamplerTesterMixin(unittest.TestCase):
def
test_sample_fast
(
self
):
# 1. Load models
generator
=
get_random_generator
(
0
)
generator
=
torch
.
manual_seed
(
0
)
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/ddpm-lsun-church"
,
timesteps
=
10
)
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm-lsun-church"
).
to
(
torch_device
)
# 2. Sample gaussian noise
torch
.
manual_seed
(
0
)
image
=
scheduler
.
sample_noise
((
1
,
model
.
in_channels
,
model
.
resolution
,
model
.
resolution
),
device
=
torch_device
,
generator
=
generator
)
# 3. Denoise
...
...
@@ -218,8 +200,8 @@ class SamplerTesterMixin(unittest.TestCase):
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
import
ipdb
;
ipdb
.
set_trace
(
)
assert
(
image_slice
-
torch
.
tensor
([[
0.1746
,
0.5125
,
-
0.7920
],
[
-
0.5734
,
-
0.2910
,
-
0.1984
],
[
0.4090
,
-
0.7740
,
-
0.3941
]])
).
abs
().
sum
()
<
1e-3
expected_slice
=
torch
.
tensor
([
-
0.0304
,
-
0.1895
,
-
0.2436
,
-
0.9837
,
-
0.5422
,
0.1931
,
-
0.8175
,
0.0862
,
-
0.7783
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
sum
()
<
1e-3
class
PipelineTesterMixin
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment