Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
a643c630
Unverified
Commit
a643c630
authored
Dec 08, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 08, 2022
Browse files
[K Diffusion] Add k diffusion sampler natively (#1603)
* uP * uP
parent
326de419
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
602 additions
and
2 deletions
+602
-2
examples/community/README.md
examples/community/README.md
+2
-2
examples/community/sd_text2img_k_diffusion.py
examples/community/sd_text2img_k_diffusion.py
+5
-0
hi
hi
+1
-0
setup.py
setup.py
+2
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+6
-0
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+1
-0
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+4
-0
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+4
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
...stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+462
-0
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+1
-0
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
...s/dummy_torch_and_transformers_and_k_diffusion_objects.py
+19
-0
src/diffusers/utils/import_utils.py
src/diffusers/utils/import_utils.py
+18
-0
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py
...nes/stable_diffusion/test_stable_diffusion_k_diffusion.py
+77
-0
No files found.
examples/community/README.md
View file @
a643c630
...
...
@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe
=
pipe
.
to
(
"cuda"
)
prompt
=
"an astronaut riding a horse on mars"
pipe
.
set_s
amp
ler
(
"sample_heun"
)
pipe
.
set_s
chedu
ler
(
"sample_heun"
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
20
).
images
[
0
]
...
...
@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe
.
scheduler
=
EulerDiscreteScheduler
.
from_config
(
pipe
.
scheduler
.
config
)
pipe
=
pipe
.
to
(
"cuda"
)
pipe
.
set_s
amp
ler
(
"sample_euler"
)
pipe
.
set_s
chedu
ler
(
"sample_euler"
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
50
).
images
[
0
]
```
...
...
examples/community/sd_text2img_k_diffusion.py
View file @
a643c630
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
importlib
import
warnings
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
...
...
@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
self
.
k_diffusion_model
=
CompVisDenoiser
(
model
)
def
set_sampler
(
self
,
scheduler_type
:
str
):
warnings
.
warn
(
"The `set_sampler` method is deprecated, please use `set_scheduler` instead."
)
return
self
.
set_scheduler
(
scheduler_type
)
def
set_scheduler
(
self
,
scheduler_type
:
str
):
library
=
importlib
.
import_module
(
"k_diffusion"
)
sampling
=
getattr
(
library
,
"sampling"
)
self
.
sampler
=
getattr
(
sampling
,
scheduler_type
)
...
...
hi
0 → 100644
View file @
a643c630
setup.py
View file @
a643c630
...
...
@@ -91,6 +91,7 @@ _deps = [
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2"
,
"jaxlib>=0.1.65"
,
"k-diffusion"
,
"librosa"
,
"modelcards>=0.1.4"
,
"numpy"
,
...
...
@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
extras
[
"training"
]
=
deps_list
(
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
)
extras
[
"test"
]
=
deps_list
(
"datasets"
,
"k-diffusion"
,
"librosa"
,
"parameterized"
,
"pytest"
,
...
...
src/diffusers/__init__.py
View file @
a643c630
...
...
@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
from
.utils
import
(
is_flax_available
,
is_inflect_available
,
is_k_diffusion_available
,
is_onnx_available
,
is_scipy_available
,
is_torch_available
,
...
...
@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
else
:
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_k_diffusion_available
():
from
.pipelines
import
StableDiffusionKDiffusionPipeline
else
:
from
.utils.dummy_torch_and_transformers_and_k_diffusion_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
(
OnnxStableDiffusionImg2ImgPipeline
,
...
...
src/diffusers/dependency_versions_table.py
View file @
a643c630
...
...
@@ -15,6 +15,7 @@ deps = {
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2"
,
"jaxlib"
:
"jaxlib>=0.1.65"
,
"k-diffusion"
:
"k-diffusion"
,
"librosa"
:
"librosa"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"numpy"
:
"numpy"
,
...
...
src/diffusers/pipelines/__init__.py
View file @
a643c630
from
..utils
import
(
is_flax_available
,
is_k_diffusion_available
,
is_librosa_available
,
is_onnx_available
,
is_torch_available
,
...
...
@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
StableDiffusionOnnxPipeline
,
)
if
is_torch_available
()
and
is_transformers_available
()
and
is_k_diffusion_available
():
from
.stable_diffusion
import
StableDiffusionKDiffusionPipeline
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
a643c630
...
...
@@ -9,6 +9,7 @@ from PIL import Image
from
...utils
import
(
BaseOutput
,
is_flax_available
,
is_k_diffusion_available
,
is_onnx_available
,
is_torch_available
,
is_transformers_available
,
...
...
@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
else
:
from
...utils.dummy_torch_and_transformers_objects
import
StableDiffusionImageVariationPipeline
if
is_transformers_available
()
and
is_torch_available
()
and
is_k_diffusion_available
():
from
.pipeline_stable_diffusion_k_diffusion
import
StableDiffusionKDiffusionPipeline
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_onnx_stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
from
.pipeline_onnx_stable_diffusion_img2img
import
OnnxStableDiffusionImg2ImgPipeline
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
0 → 100755
View file @
a643c630
This diff is collapsed.
Click to expand it.
src/diffusers/utils/__init__.py
View file @
a643c630
...
...
@@ -29,6 +29,7 @@ from .import_utils import (
is_accelerate_available
,
is_flax_available
,
is_inflect_available
,
is_k_diffusion_available
,
is_librosa_available
,
is_modelcards_available
,
is_onnx_available
,
...
...
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
0 → 100644
View file @
a643c630
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
StableDiffusionKDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"k_diffusion"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
src/diffusers/utils/import_utils.py
View file @
a643c630
...
...
@@ -210,6 +210,13 @@ try:
except
importlib_metadata
.
PackageNotFoundError
:
_xformers_available
=
False
_k_diffusion_available
=
importlib
.
util
.
find_spec
(
"k_diffusion"
)
is
not
None
try
:
_k_diffusion_version
=
importlib_metadata
.
version
(
"k_diffusion"
)
logger
.
debug
(
f
"Successfully imported k-diffusion version
{
_k_diffusion_version
}
"
)
except
importlib_metadata
.
PackageNotFoundError
:
_k_diffusion_available
=
False
def
is_torch_available
():
return
_torch_available
...
...
@@ -263,6 +270,10 @@ def is_accelerate_available():
return
_accelerate_available
def
is_k_diffusion_available
():
return
_k_diffusion_available
# docstyle-ignore
FLAX_IMPORT_ERROR
=
"""
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
...
...
@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
Unidecode`
"""
# docstyle-ignore
K_DIFFUSION_IMPORT_ERROR
=
"""
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""
BACKENDS_MAPPING
=
OrderedDict
(
[
...
...
@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
(
"transformers"
,
(
is_transformers_available
,
TRANSFORMERS_IMPORT_ERROR
)),
(
"unidecode"
,
(
is_unidecode_available
,
UNIDECODE_IMPORT_ERROR
)),
(
"librosa"
,
(
is_librosa_available
,
LIBROSA_IMPORT_ERROR
)),
(
"k_diffusion"
,
(
is_k_diffusion_available
,
K_DIFFUSION_IMPORT_ERROR
)),
]
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py
0 → 100644
View file @
a643c630
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
gc
import
unittest
import
numpy
as
np
import
torch
from
diffusers
import
StableDiffusionKDiffusionPipeline
from
diffusers.utils
import
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
@
slow
@
require_torch_gpu
class
StableDiffusionPipelineIntegrationTests
(
unittest
.
TestCase
):
def
tearDown
(
self
):
# clean up the VRAM after each test
super
().
tearDown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
test_stable_diffusion_1
(
self
):
sd_pipe
=
StableDiffusionKDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
sd_pipe
.
set_scheduler
(
"sample_euler"
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
output
=
sd_pipe
([
prompt
],
generator
=
generator
,
guidance_scale
=
9.0
,
num_inference_steps
=
20
,
output_type
=
"np"
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
([
0.8887
,
0.915
,
0.91
,
0.894
,
0.909
,
0.912
,
0.919
,
0.925
,
0.883
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_2
(
self
):
sd_pipe
=
StableDiffusionKDiffusionPipeline
.
from_pretrained
(
"stabilityai/stable-diffusion-2-1-base"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
sd_pipe
.
set_scheduler
(
"sample_euler"
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
output
=
sd_pipe
([
prompt
],
generator
=
generator
,
guidance_scale
=
9.0
,
num_inference_steps
=
20
,
output_type
=
"np"
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
(
[
0.826810
,
0.81958747
,
0.8510199
,
0.8376758
,
0.83958465
,
0.8682068
,
0.84370345
,
0.85251087
,
0.85884345
]
)
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment