Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
46d20d2d
Commit
46d20d2d
authored
Jun 07, 2022
by
Patrick von Platen
Browse files
fix random seed
parent
9c4cd06d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
13 deletions
+25
-13
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+25
-13
No files found.
tests/test_modeling_utils.py
View file @
46d20d2d
...
...
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
tempfile
import
unittest
...
...
@@ -30,6 +32,22 @@ global_rng = random.Random()
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
def
get_random_generator
(
seed
):
seed
=
1234
random
.
seed
(
seed
)
os
.
environ
[
‘
PYTHONHASHSEED
’
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
enabled
=
False
generator
=
torch
.
Generator
()
return
generator
def
parse_flag_from_env
(
key
,
default
=
False
):
try
:
value
=
os
.
environ
[
key
]
...
...
@@ -113,8 +131,7 @@ class SamplerTesterMixin(unittest.TestCase):
@
slow
def
test_sample
(
self
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
6694729458485568
)
generator
=
get_random_generator
(
0
)
# 1. Load models
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/ddpm-lsun-church"
)
...
...
@@ -163,8 +180,7 @@ class SamplerTesterMixin(unittest.TestCase):
def
test_sample_fast
(
self
):
# 1. Load models
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
6694729458485568
)
generator
=
get_random_generator
(
0
)
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/ddpm-lsun-church"
,
timesteps
=
10
)
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm-lsun-church"
).
to
(
torch_device
)
...
...
@@ -215,16 +231,14 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPM
.
from_pretrained
(
tmpdirname
)
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
669472945848556
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
669472945848556
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
new_ddpm
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_from_pretrained_hub
(
self
):
model_path
=
"fusing/ddpm-cifar10"
...
...
@@ -235,12 +249,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
.
noise_scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
noise_scheduler
.
num_timesteps
=
10
generator
=
torch
.
Generator
(
device
=
torch_device
)
generator
=
generator
.
manual_seed
(
669472945848556
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
generator
=
generator
.
manual_seed
(
669472945848556
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment