Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
988369a0
Unverified
Commit
988369a0
authored
Jun 16, 2022
by
Suraj Patil
Committed by
GitHub
Jun 16, 2022
Browse files
Merge branch 'main' into grad-tts
parents
5a3467e6
bed32182
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
33 deletions
+50
-33
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+1
-1
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+7
-2
src/diffusers/utils/logging.py
src/diffusers/utils/logging.py
+15
-15
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+27
-15
No files found.
src/diffusers/schedulers/scheduling_ddpm.py
View file @
988369a0
...
@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_predicted_image
=
clip_predicted_image
,
clip_predicted_image
=
clip_predicted_image
,
)
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
self
.
timestep_values
=
timestep_values
# save the fixed timestep values for BDDM
self
.
timestep_values
=
timestep_values
# save the fixed timestep values for BDDM
self
.
clip_image
=
clip_predicted_image
self
.
clip_image
=
clip_predicted_image
self
.
variance_type
=
variance_type
self
.
variance_type
=
variance_type
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
988369a0
...
@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
warmup_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
warmup_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
warmup_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
warmup_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
return
self
.
warmup_time_steps
[
num_inference_steps
]
return
self
.
warmup_time_steps
[
num_inference_steps
]
...
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_delta
=
(
at_next
-
at
)
*
(
(
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_next
=
x
+
x_delta
x_next
=
x
+
x_delta
return
x_next
return
x_next
...
...
src/diffusers/utils/logging.py
View file @
988369a0
...
@@ -49,16 +49,16 @@ _tqdm_active = True
...
@@ -49,16 +49,16 @@ _tqdm_active = True
def
_get_default_logging_level
():
def
_get_default_logging_level
():
"""
"""
If
TRANSFORM
ERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
If
DIFFUS
ERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
not - fall back to `_default_log_level`
"""
"""
env_level_str
=
os
.
getenv
(
"
TRANSFORM
ERS_VERBOSITY"
,
None
)
env_level_str
=
os
.
getenv
(
"
DIFFUS
ERS_VERBOSITY"
,
None
)
if
env_level_str
:
if
env_level_str
:
if
env_level_str
in
log_levels
:
if
env_level_str
in
log_levels
:
return
log_levels
[
env_level_str
]
return
log_levels
[
env_level_str
]
else
:
else
:
logging
.
getLogger
().
warning
(
logging
.
getLogger
().
warning
(
f
"Unknown option
TRANSFORM
ERS_VERBOSITY=
{
env_level_str
}
, "
f
"Unknown option
DIFFUS
ERS_VERBOSITY=
{
env_level_str
}
, "
f
"has to be one of:
{
', '
.
join
(
log_levels
.
keys
())
}
"
f
"has to be one of:
{
', '
.
join
(
log_levels
.
keys
())
}
"
)
)
return
_default_log_level
return
_default_log_level
...
@@ -126,14 +126,14 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
...
@@ -126,14 +126,14 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
def
get_verbosity
()
->
int
:
def
get_verbosity
()
->
int
:
"""
"""
Return the current level for the 🤗
Transform
ers'
s
root logger as an int.
Return the current level for the 🤗
Diffus
ers' root logger as an int.
Returns:
Returns:
`int`: The logging level.
`int`: The logging level.
<Tip>
<Tip>
🤗
Transform
ers has following logging levels:
🤗
Diffus
ers has following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR`
- 40: `diffusers.logging.ERROR`
...
@@ -149,7 +149,7 @@ def get_verbosity() -> int:
...
@@ -149,7 +149,7 @@ def get_verbosity() -> int:
def
set_verbosity
(
verbosity
:
int
)
->
None
:
def
set_verbosity
(
verbosity
:
int
)
->
None
:
"""
"""
Set the verbosity level for the 🤗
Transform
ers'
s
root logger.
Set the verbosity level for the 🤗
Diffus
ers' root logger.
Args:
Args:
verbosity (`int`):
verbosity (`int`):
...
@@ -187,7 +187,7 @@ def set_verbosity_error():
...
@@ -187,7 +187,7 @@ def set_verbosity_error():
def
disable_default_handler
()
->
None
:
def
disable_default_handler
()
->
None
:
"""Disable the default handler of the HuggingFace
Transform
ers'
s
root logger."""
"""Disable the default handler of the HuggingFace
Diffus
ers' root logger."""
_configure_library_root_logger
()
_configure_library_root_logger
()
...
@@ -196,7 +196,7 @@ def disable_default_handler() -> None:
...
@@ -196,7 +196,7 @@ def disable_default_handler() -> None:
def
enable_default_handler
()
->
None
:
def
enable_default_handler
()
->
None
:
"""Enable the default handler of the HuggingFace
Transform
ers'
s
root logger."""
"""Enable the default handler of the HuggingFace
Diffus
ers' root logger."""
_configure_library_root_logger
()
_configure_library_root_logger
()
...
@@ -205,7 +205,7 @@ def enable_default_handler() -> None:
...
@@ -205,7 +205,7 @@ def enable_default_handler() -> None:
def
add_handler
(
handler
:
logging
.
Handler
)
->
None
:
def
add_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""adds a handler to the HuggingFace
Transform
ers'
s
root logger."""
"""adds a handler to the HuggingFace
Diffus
ers' root logger."""
_configure_library_root_logger
()
_configure_library_root_logger
()
...
@@ -214,7 +214,7 @@ def add_handler(handler: logging.Handler) -> None:
...
@@ -214,7 +214,7 @@ def add_handler(handler: logging.Handler) -> None:
def
remove_handler
(
handler
:
logging
.
Handler
)
->
None
:
def
remove_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""removes given handler from the HuggingFace
Transform
ers'
s
root logger."""
"""removes given handler from the HuggingFace
Diffus
ers' root logger."""
_configure_library_root_logger
()
_configure_library_root_logger
()
...
@@ -233,7 +233,7 @@ def disable_propagation() -> None:
...
@@ -233,7 +233,7 @@ def disable_propagation() -> None:
def
enable_propagation
()
->
None
:
def
enable_propagation
()
->
None
:
"""
"""
Enable propagation of the library log outputs. Please disable the HuggingFace
Transform
ers'
s
default handler to
Enable propagation of the library log outputs. Please disable the HuggingFace
Diffus
ers' default handler to
prevent double logging if the root logger has been configured.
prevent double logging if the root logger has been configured.
"""
"""
...
@@ -243,7 +243,7 @@ def enable_propagation() -> None:
...
@@ -243,7 +243,7 @@ def enable_propagation() -> None:
def
enable_explicit_format
()
->
None
:
def
enable_explicit_format
()
->
None
:
"""
"""
Enable explicit formatting for every HuggingFace
Transform
ers'
s
logger. The explicit formatter is as follows:
Enable explicit formatting for every HuggingFace
Diffus
ers' logger. The explicit formatter is as follows:
```
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
```
...
@@ -258,7 +258,7 @@ def enable_explicit_format() -> None:
...
@@ -258,7 +258,7 @@ def enable_explicit_format() -> None:
def
reset_format
()
->
None
:
def
reset_format
()
->
None
:
"""
"""
Resets the formatting for HuggingFace
Transform
ers'
s
loggers.
Resets the formatting for HuggingFace
Diffus
ers' loggers.
All handlers currently bound to the root logger are affected by this method.
All handlers currently bound to the root logger are affected by this method.
"""
"""
...
@@ -270,10 +270,10 @@ def reset_format() -> None:
...
@@ -270,10 +270,10 @@ def reset_format() -> None:
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
"""
"""
This method is identical to `logger.warninging()`, but if env var
TRANSFORM
ERS_NO_ADVISORY_WARNINGS=1 is set, this
This method is identical to `logger.warninging()`, but if env var
DIFFUS
ERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
warning will not be printed
"""
"""
no_advisory_warnings
=
os
.
getenv
(
"
TRANSFORM
ERS_NO_ADVISORY_WARNINGS"
,
False
)
no_advisory_warnings
=
os
.
getenv
(
"
DIFFUS
ERS_NO_ADVISORY_WARNINGS"
,
False
)
if
no_advisory_warnings
:
if
no_advisory_warnings
:
return
return
self
.
warning
(
*
args
,
**
kwargs
)
self
.
warning
(
*
args
,
**
kwargs
)
...
...
tests/test_modeling_utils.py
View file @
988369a0
...
@@ -19,7 +19,18 @@ import unittest
...
@@ -19,7 +19,18 @@ import unittest
import
torch
import
torch
from
diffusers
import
DDIM
,
DDPM
,
PNDM
,
GLIDE
,
BDDM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
PNDMScheduler
,
UNetModel
from
diffusers
import
(
BDDM
,
DDIM
,
DDPM
,
GLIDE
,
PNDM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
PNDMScheduler
,
UNetModel
,
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipelines.pipeline_bddm
import
DiffWave
from
diffusers.pipelines.pipeline_bddm
import
DiffWave
...
@@ -214,6 +225,21 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -214,6 +225,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GLIDE
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
@@ -229,17 +255,3 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -229,17 +255,3 @@ class PipelineTesterMixin(unittest.TestCase):
_
=
BDDM
.
from_pretrained
(
tmpdirname
)
_
=
BDDM
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
@
slow
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
glide
=
GLIDE
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
image
=
glide
(
prompt
,
generator
=
generator
,
num_inference_steps_upscale
=
20
)
image_slice
=
image
[
0
,
:
3
,
:
3
,
-
1
].
cpu
()
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment