Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
12b10cbe
Commit
12b10cbe
authored
Jun 12, 2022
by
Patrick von Platen
Browse files
finish refactor
parent
2d97544d
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
31 deletions
+52
-31
src/diffusers/utils/logging.py
src/diffusers/utils/logging.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+4
-4
tests/test_scheduler.py
tests/test_scheduler.py
+47
-26
No files found.
src/diffusers/utils/logging.py
View file @
12b10cbe
...
@@ -270,7 +270,7 @@ def reset_format() -> None:
...
@@ -270,7 +270,7 @@ def reset_format() -> None:
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
"""
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
This method is identical to `logger.warning
ing
()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
warning will not be printed
"""
"""
no_advisory_warnings
=
os
.
getenv
(
"TRANSFORMERS_NO_ADVISORY_WARNINGS"
,
False
)
no_advisory_warnings
=
os
.
getenv
(
"TRANSFORMERS_NO_ADVISORY_WARNINGS"
,
False
)
...
...
tests/test_modeling_utils.py
View file @
12b10cbe
...
@@ -19,11 +19,10 @@ import unittest
...
@@ -19,11 +19,10 @@ import unittest
import
torch
import
torch
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
,
DDIMScheduler
from
diffusers
import
DDIM
,
DDPM
,
DDIMScheduler
,
GaussianDDPMScheduler
,
LatentDiffusion
,
UNetModel
from
diffusers
import
DDIM
,
DDPM
,
LatentDiffusion
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
torch_device
,
slow
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
...
@@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -149,6 +148,7 @@ class PipelineTesterMixin(unittest.TestCase):
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
GaussianDDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
GaussianDDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
ddpm
=
DDPM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
...
@@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -165,7 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDIMScheduler
()
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
ddim
=
DDIM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
...
...
tests/test_scheduler.py
View file @
12b10cbe
...
@@ -14,12 +14,13 @@
...
@@ -14,12 +14,13 @@
# limitations under the License.
# limitations under the License.
import
torch
import
numpy
as
np
import
unittest
import
tempfile
import
tempfile
import
unittest
from
diffusers
import
GaussianDDPMScheduler
,
DDIMScheduler
import
numpy
as
np
import
torch
from
diffusers
import
DDIMScheduler
,
GaussianDDPMScheduler
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
...
@@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -38,7 +39,7 @@ class SchedulerCommonTest(unittest.TestCase):
image
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
image
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
return
torch
.
tensor
(
image
)
return
image
@
property
@
property
def
dummy_image_deter
(
self
):
def
dummy_image_deter
(
self
):
...
@@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -53,7 +54,7 @@ class SchedulerCommonTest(unittest.TestCase):
image
=
image
/
num_elems
image
=
image
/
num_elems
image
=
image
.
transpose
(
3
,
0
,
1
,
2
)
image
=
image
.
transpose
(
3
,
0
,
1
,
2
)
return
torch
.
tensor
(
image
)
return
image
def
get_scheduler_config
(
self
):
def
get_scheduler_config
(
self
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -82,7 +83,7 @@ class SchedulerCommonTest(unittest.TestCase):
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
(
output
-
new_output
)
.
abs
().
sum
(
)
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
...
@@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -103,7 +104,7 @@ class SchedulerCommonTest(unittest.TestCase):
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
(
output
-
new_output
)
.
abs
().
sum
(
)
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
...
@@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -122,7 +123,7 @@ class SchedulerCommonTest(unittest.TestCase):
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
assert
(
output
-
new_output
)
.
abs
().
sum
(
)
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_step_shape
(
self
):
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
...
@@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -140,6 +141,26 @@ class SchedulerCommonTest(unittest.TestCase):
self
.
assertEqual
(
output_0
.
shape
,
image
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
image
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
image
=
self
.
dummy_image
residual
=
0.1
*
image
image_pt
=
torch
.
tensor
(
image
)
residual_pt
=
0.1
*
image_pt
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
image_pt
,
1
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-5
,
"Scheduler outputs are not identical"
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
GaussianDDPMScheduler
,)
scheduler_classes
=
(
GaussianDDPMScheduler
,)
...
@@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -151,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"variance_type"
:
"fixed_small"
,
"variance_type"
:
"fixed_small"
,
"clip_predicted_image"
:
True
"clip_predicted_image"
:
True
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
@@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -186,9 +207,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
(
scheduler
.
get_variance
(
0
)
-
0.0
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
)
-
0.0
))
<
1e-5
assert
(
scheduler
.
get_variance
(
487
)
-
0.00979
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
487
)
-
0.00979
))
<
1e-5
assert
(
scheduler
.
get_variance
(
999
)
-
0.02
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
999
)
-
0.02
))
<
1e-5
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -209,12 +230,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
if
t
>
0
:
if
t
>
0
:
noise
=
self
.
dummy_image_deter
noise
=
self
.
dummy_image_deter
variance
=
scheduler
.
get_variance
(
t
)
.
sqrt
(
)
*
noise
variance
=
scheduler
.
get_variance
(
t
)
**
(
0.5
)
*
noise
image
=
pred_prev_image
+
variance
image
=
pred_prev_image
+
variance
result_sum
=
image
.
abs
().
sum
(
)
result_sum
=
np
.
sum
(
np
.
abs
(
image
)
)
result_mean
=
image
.
abs
().
mean
(
)
result_mean
=
np
.
mean
(
np
.
abs
(
image
)
)
assert
result_sum
.
item
()
-
732.9947
<
1e-3
assert
result_sum
.
item
()
-
732.9947
<
1e-3
assert
result_mean
.
item
()
-
0.9544
<
1e-3
assert
result_mean
.
item
()
-
0.9544
<
1e-3
...
@@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -230,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"clip_predicted_image"
:
True
"clip_predicted_image"
:
True
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
@@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -269,12 +290,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
(
scheduler
.
get_variance
(
0
,
50
)
-
0.0
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
50
)
-
0.0
))
<
1e-5
assert
(
scheduler
.
get_variance
(
21
,
50
)
-
0.14771
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
21
,
50
)
-
0.14771
))
<
1e-5
assert
(
scheduler
.
get_variance
(
49
,
50
)
-
0.32460
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
49
,
50
)
-
0.32460
))
<
1e-5
assert
(
scheduler
.
get_variance
(
0
,
1000
)
-
0.0
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
1000
)
-
0.0
))
<
1e-5
assert
(
scheduler
.
get_variance
(
487
,
1000
)
-
0.00979
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
487
,
1000
)
-
0.00979
))
<
1e-5
assert
(
scheduler
.
get_variance
(
999
,
1000
)
-
0.02
)
.
abs
().
sum
(
)
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
999
,
1000
)
-
0.02
))
<
1e-5
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
...
@@ -297,12 +318,12 @@ class DDIMSchedulerTest(SchedulerCommonTest):
variance
=
0
variance
=
0
if
eta
>
0
:
if
eta
>
0
:
noise
=
self
.
dummy_image_deter
noise
=
self
.
dummy_image_deter
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
.
sqrt
(
)
*
eta
*
noise
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
**
(
0.5
)
*
eta
*
noise
image
=
pred_prev_image
+
variance
image
=
pred_prev_image
+
variance
result_sum
=
image
.
abs
().
sum
(
)
result_sum
=
np
.
sum
(
np
.
abs
(
image
)
)
result_mean
=
image
.
abs
().
mean
(
)
result_mean
=
np
.
mean
(
np
.
abs
(
image
)
)
assert
result_sum
.
item
()
-
270.6214
<
1e-3
assert
result_sum
.
item
()
-
270.6214
<
1e-3
assert
result_mean
.
item
()
-
0.3524
<
1e-3
assert
result_mean
.
item
()
-
0.3524
<
1e-3
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment