Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
e29fc446
Commit
e29fc446
authored
Jun 22, 2022
by
Nathan Lambert
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
7b4e049e
6e456b7a
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
117 additions
and
48 deletions
+117
-48
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+15
-7
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+1
-1
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+38
-0
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
...s/dummy_transformers_and_inflect_and_unidecode_objects.py
+10
-0
src/diffusers/utils/dummy_transformers_objects.py
src/diffusers/utils/dummy_transformers_objects.py
+4
-7
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+47
-30
utils/check_dummies.py
utils/check_dummies.py
+2
-3
No files found.
src/diffusers/schedulers/scheduling_ddpm.py
View file @
e29fc446
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -76,7 +77,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -76,7 +77,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
elif
beta_schedule
==
"squaredcos_cap_v2"
:
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# G
LIDE
cosine schedule
# G
lide
cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
...
@@ -108,7 +109,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -108,7 +109,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif
variance_type
==
"fixed_large"
:
elif
variance_type
==
"fixed_large"
:
variance
=
self
.
betas
[
t
]
variance
=
self
.
betas
[
t
]
elif
variance_type
==
"fixed_large_log"
:
elif
variance_type
==
"fixed_large_log"
:
# G
LIDE
max_log
# G
lide
max_log
variance
=
self
.
log
(
self
.
betas
[
t
])
variance
=
self
.
log
(
self
.
betas
[
t
])
return
variance
return
variance
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
forward_step
(
self
,
original_sample
,
noise
,
t
):
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
t
]
**
0.5
if
timesteps
.
dim
()
!=
1
:
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
t
])
**
0.5
raise
ValueError
(
"`timesteps` must be a 1D tensor"
)
noisy_sample
=
sqrt_alpha_prod
*
original_sample
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_sample
device
=
original_samples
.
device
batch_size
=
original_samples
.
shape
[
0
]
timesteps
=
timesteps
.
reshape
(
batch_size
,
1
,
1
,
1
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
noisy_samples
=
sqrt_alpha_prod
.
to
(
device
)
*
original_samples
+
sqrt_one_minus_alpha_prod
.
to
(
device
)
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
timesteps
return
self
.
config
.
timesteps
src/diffusers/schedulers/scheduling_pndm.py
View file @
e29fc446
...
@@ -66,7 +66,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -66,7 +66,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
elif
beta_schedule
==
"squaredcos_cap_v2"
:
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# G
LIDE
cosine schedule
# G
lide
cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
...
...
src/diffusers/utils/__init__.py
View file @
e29fc446
...
@@ -45,10 +45,34 @@ except importlib_metadata.PackageNotFoundError:
...
@@ -45,10 +45,34 @@ except importlib_metadata.PackageNotFoundError:
_transformers_available
=
False
_transformers_available
=
False
_inflect_available
=
importlib
.
util
.
find_spec
(
"inflect"
)
is
not
None
try
:
_inflect_version
=
importlib_metadata
.
version
(
"inflect"
)
logger
.
debug
(
f
"Successfully imported inflect version
{
_inflect_version
}
"
)
except
importlib_metadata
.
PackageNotFoundError
:
_inflect_available
=
False
_unidecode_available
=
importlib
.
util
.
find_spec
(
"unidecode"
)
is
not
None
try
:
_unidecode_version
=
importlib_metadata
.
version
(
"unidecode"
)
logger
.
debug
(
f
"Successfully imported unidecode version
{
_unidecode_version
}
"
)
except
importlib_metadata
.
PackageNotFoundError
:
_unidecode_available
=
False
def
is_transformers_available
():
def
is_transformers_available
():
return
_transformers_available
return
_transformers_available
def
is_inflect_available
():
return
_inflect_available
def
is_unidecode_available
():
return
_unidecode_available
class
RepositoryNotFoundError
(
HTTPError
):
class
RepositoryNotFoundError
(
HTTPError
):
"""
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
...
@@ -70,9 +94,23 @@ TRANSFORMERS_IMPORT_ERROR = """
...
@@ -70,9 +94,23 @@ TRANSFORMERS_IMPORT_ERROR = """
"""
"""
UNIDECODE_IMPORT_ERROR
=
"""
{0} requires the unidecode library but it was not found in your environment. You can install it with pip:
`pip install Unidecode`
"""
INFLECT_IMPORT_ERROR
=
"""
{0} requires the inflect library but it was not found in your environment. You can install it with pip:
`pip install inflect`
"""
BACKENDS_MAPPING
=
OrderedDict
(
BACKENDS_MAPPING
=
OrderedDict
(
[
[
(
"transformers"
,
(
is_transformers_available
,
TRANSFORMERS_IMPORT_ERROR
)),
(
"transformers"
,
(
is_transformers_available
,
TRANSFORMERS_IMPORT_ERROR
)),
(
"unidecode"
,
(
is_unidecode_available
,
UNIDECODE_IMPORT_ERROR
)),
(
"inflect"
,
(
is_inflect_available
,
INFLECT_IMPORT_ERROR
)),
]
]
)
)
...
...
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
0 → 100644
View file @
e29fc446
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
GradTTS
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
,
"inflect"
,
"unidecode"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
,
"inflect"
,
"unidecode"
])
src/diffusers/utils/dummy_transformers_objects.py
View file @
e29fc446
...
@@ -3,21 +3,21 @@
...
@@ -3,21 +3,21 @@
from
..utils
import
DummyObject
,
requires_backends
from
..utils
import
DummyObject
,
requires_backends
class
G
LIDE
SuperResUNetModel
(
metaclass
=
DummyObject
):
class
G
lide
SuperResUNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"transformers"
])
class
G
LIDE
TextToImageUNetModel
(
metaclass
=
DummyObject
):
class
G
lide
TextToImageUNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"transformers"
])
class
G
LIDE
UNetModel
(
metaclass
=
DummyObject
):
class
G
lide
UNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
@@ -31,10 +31,7 @@ class UNetGradTTSModel(metaclass=DummyObject):
...
@@ -31,10 +31,7 @@ class UNetGradTTSModel(metaclass=DummyObject):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"transformers"
])
GLIDE
=
None
class
Glide
(
metaclass
=
DummyObject
):
class
GradTTS
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
...
tests/test_modeling_utils.py
View file @
e29fc446
...
@@ -21,18 +21,18 @@ import unittest
...
@@ -21,18 +21,18 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
pytest
from
diffusers
import
(
from
diffusers
import
(
BDDM
,
BDDMPipeline
,
DDIM
,
DDIMPipeline
,
DDPM
,
GLIDE
,
PNDM
,
DDIMScheduler
,
DDIMScheduler
,
DDPMPipeline
,
DDPMScheduler
,
DDPMScheduler
,
GLIDESuperResUNetModel
,
GlidePipeline
,
GLIDETextToImageUNetModel
,
GlideSuperResUNetModel
,
LatentDiffusion
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
...
@@ -247,13 +247,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -247,13 +247,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
G
LIDE
SuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
G
lide
SuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
G
LIDE
SuperResUNetModel
model_class
=
G
lide
SuperResUNetModel
@
property
@
property
def
dummy_input
(
self
):
def
dummy_input
(
self
):
...
@@ -309,7 +309,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -309,7 +309,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
G
LIDE
SuperResUNetModel
.
from_pretrained
(
model
,
loading_info
=
G
lide
SuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
@@ -321,7 +321,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -321,7 +321,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
G
LIDE
SuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
model
=
G
lide
SuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -342,8 +342,8 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
...
@@ -342,8 +342,8 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
G
LIDE
TextToImageUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
G
lide
TextToImageUNetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
G
LIDE
TextToImageUNetModel
model_class
=
G
lide
TextToImageUNetModel
@
property
@
property
def
dummy_input
(
self
):
def
dummy_input
(
self
):
...
@@ -401,7 +401,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -401,7 +401,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
G
LIDE
TextToImageUNetModel
.
from_pretrained
(
model
,
loading_info
=
G
lide
TextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
"fusing/unet-glide-text2im-dummy"
,
output_loading_info
=
True
)
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
@@ -413,7 +413,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -413,7 +413,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
assert
image
is
not
None
,
"Make sure output is not None"
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
def
test_output_pretrained
(
self
):
model
=
G
LIDE
TextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
)
model
=
G
lide
TextToImageUNetModel
.
from_pretrained
(
"fusing/unet-glide-text2im-dummy"
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -431,7 +431,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -431,7 +431,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
2.7766
,
-
10.3558
,
-
14.9149
,
-
0.9376
,
-
14.9175
,
-
17.7679
,
-
5.5565
,
-
12.9521
,
-
12.9845
])
expected_output_slice
=
torch
.
tensor
([
2.7766
,
-
10.3558
,
-
14.9149
,
-
0.9376
,
-
14.9175
,
-
17.7679
,
-
5.5565
,
-
12.9521
,
-
12.9845
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
@@ -571,7 +571,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -571,7 +571,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
output_slice
=
output
[
0
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
expected_output_slice
=
torch
.
tensor
([
-
0.0690
,
-
0.0531
,
0.0633
,
-
0.0660
,
-
0.0541
,
0.0650
,
-
0.0656
,
-
0.0555
,
0.0617
])
# fmt: on
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
...
@@ -583,11 +583,11 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -583,11 +583,11 @@ class PipelineTesterMixin(unittest.TestCase):
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
schedular
=
DDPMScheduler
(
timesteps
=
10
)
ddpm
=
DDPM
(
model
,
schedular
)
ddpm
=
DDPM
Pipeline
(
model
,
schedular
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
ddpm
.
save_pretrained
(
tmpdirname
)
ddpm
.
save_pretrained
(
tmpdirname
)
new_ddpm
=
DDPM
.
from_pretrained
(
tmpdirname
)
new_ddpm
=
DDPM
Pipeline
.
from_pretrained
(
tmpdirname
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -601,7 +601,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -601,7 +601,7 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model_path
=
"fusing/ddpm-cifar10"
model_path
=
"fusing/ddpm-cifar10"
ddpm
=
DDPM
.
from_pretrained
(
model_path
)
ddpm
=
DDPM
Pipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
noise_scheduler
.
num_timesteps
=
10
ddpm
.
noise_scheduler
.
num_timesteps
=
10
...
@@ -624,7 +624,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -624,7 +624,7 @@ class PipelineTesterMixin(unittest.TestCase):
noise_scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
ddpm
=
DDPM
Pipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
@@ -641,7 +641,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -641,7 +641,7 @@ class PipelineTesterMixin(unittest.TestCase):
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
noise_scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
ddim
=
DDIM
Pipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
@@ -660,7 +660,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -660,7 +660,7 @@ class PipelineTesterMixin(unittest.TestCase):
unet
=
UNetModel
.
from_pretrained
(
model_id
)
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
pndm
=
PNDM
Pipeline
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
pndm
(
generator
=
generator
)
image
=
pndm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
@@ -674,7 +674,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -674,7 +674,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ldm_text2img
(
self
):
def
test_ldm_text2img
(
self
):
model_id
=
"fusing/latent-diffusion-text2im-large"
model_id
=
"fusing/latent-diffusion-text2im-large"
ldm
=
LatentDiffusion
.
from_pretrained
(
model_id
)
ldm
=
LatentDiffusion
Pipeline
.
from_pretrained
(
model_id
)
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -689,7 +689,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -689,7 +689,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_glide_text2img
(
self
):
def
test_glide_text2img
(
self
):
model_id
=
"fusing/glide-base"
model_id
=
"fusing/glide-base"
glide
=
G
LIDE
.
from_pretrained
(
model_id
)
glide
=
G
lidePipeline
.
from_pretrained
(
model_id
)
prompt
=
"a pencil sketch of a corgi"
prompt
=
"a pencil sketch of a corgi"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
@@ -701,11 +701,28 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -701,11 +701,28 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
expected_slice
=
torch
.
tensor
([
0.7119
,
0.7073
,
0.6460
,
0.7780
,
0.7423
,
0.6926
,
0.7378
,
0.7189
,
0.7784
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
# generate mel spectograms using text
mel_spec
=
grad_tts
(
text
,
generator
=
generator
)
assert
mel_spec
.
shape
==
(
1
,
80
,
143
)
expected_slice
=
torch
.
tensor
(
[
-
6.6119
,
-
6.5963
,
-
6.2776
,
-
6.7496
,
-
6.7096
,
-
6.5131
,
-
6.4643
,
-
6.4817
,
-
6.7185
]
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
bddm
=
BDDM
(
model
,
noise_scheduler
)
bddm
=
BDDM
Pipeline
(
model
,
noise_scheduler
)
# check if the library name for the diffwave moduel is set to pipeline module
# check if the library name for the diffwave moduel is set to pipeline module
self
.
assertTrue
(
bddm
.
config
[
"diffwave"
][
0
]
==
"pipeline_bddm"
)
self
.
assertTrue
(
bddm
.
config
[
"diffwave"
][
0
]
==
"pipeline_bddm"
)
...
@@ -713,6 +730,6 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -713,6 +730,6 @@ class PipelineTesterMixin(unittest.TestCase):
# check if we can save and load the pipeline
# check if we can save and load the pipeline
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
bddm
.
save_pretrained
(
tmpdirname
)
bddm
.
save_pretrained
(
tmpdirname
)
_
=
BDDM
.
from_pretrained
(
tmpdirname
)
_
=
BDDM
Pipeline
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
utils/check_dummies.py
View file @
e29fc446
...
@@ -23,10 +23,9 @@ import re
...
@@ -23,10 +23,9 @@ import re
PATH_TO_DIFFUSERS
=
"src/diffusers"
PATH_TO_DIFFUSERS
=
"src/diffusers"
# Matches is_xxx_available()
# Matches is_xxx_available()
_re_backend
=
re
.
compile
(
r
"
if
is\_([a-z_]*)_available\(\)"
)
_re_backend
=
re
.
compile
(
r
"is\_([a-z_]*)_available\(\)"
)
# Matches from xxx import bla
# Matches from xxx import bla
_re_single_line_import
=
re
.
compile
(
r
"\s+from\s+\S*\s+import\s+([^\(\s].*)\n"
)
_re_single_line_import
=
re
.
compile
(
r
"\s+from\s+\S*\s+import\s+([^\(\s].*)\n"
)
_re_test_backend
=
re
.
compile
(
r
"^\s+if\s+not\s+is\_[a-z]*\_available\(\)"
)
DUMMY_CONSTANT
=
"""
DUMMY_CONSTANT
=
"""
...
@@ -54,7 +53,7 @@ def find_backend(line):
...
@@ -54,7 +53,7 @@ def find_backend(line):
if
len
(
backends
)
==
0
:
if
len
(
backends
)
==
0
:
return
None
return
None
return
backends
[
0
]
return
"_and_"
.
join
(
backends
)
def
read_init
():
def
read_init
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment