renzhc / diffusers_dcu · Commit a643c630 (unverified)

[K Diffusion] Add k diffusion sampler natively (#1603)

* uP * uP

Authored Dec 08, 2022 by Patrick von Platen; committed via GitHub on Dec 08, 2022
Parent: 326de419

Showing 13 changed files with 602 additions and 2 deletions (+602 −2)
examples/community/README.md (+2 −2)
examples/community/sd_text2img_k_diffusion.py (+5 −0)
hi (+1 −0)
setup.py (+2 −0)
src/diffusers/__init__.py (+6 −0)
src/diffusers/dependency_versions_table.py (+1 −0)
src/diffusers/pipelines/__init__.py (+4 −0)
src/diffusers/pipelines/stable_diffusion/__init__.py (+4 −0)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py (+462 −0)
src/diffusers/utils/__init__.py (+1 −0)
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py (+19 −0)
src/diffusers/utils/import_utils.py (+18 −0)
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py (+77 −0)
examples/community/README.md

@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
 pipe = pipe.to("cuda")

 prompt = "an astronaut riding a horse on mars"
-pipe.set_sampler("sample_heun")
+pipe.set_scheduler("sample_heun")
 generator = torch.Generator(device="cuda").manual_seed(seed)
 image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to("cuda")
-pipe.set_sampler("sample_euler")
+pipe.set_scheduler("sample_euler")
 generator = torch.Generator(device="cuda").manual_seed(seed)
 image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
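For context, a consolidated version of the snippet this README section documents might look as follows. It is a sketch: the `custom_pipeline` id (assumed to match the file name `sd_text2img_k_diffusion`) and the `seed` value are assumptions, since the surrounding README lines are truncated in the hunk headers above.

import torch
from diffusers import DiffusionPipeline

# Load the community pipeline (id assumed to match examples/community/sd_text2img_k_diffusion.py).
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion"
)
pipe = pipe.to("cuda")

# Select any sampler exposed by k_diffusion.sampling by name, e.g. the Heun sampler from the diff.
pipe.set_scheduler("sample_heun")

prompt = "an astronaut riding a horse on mars"
seed = 0  # placeholder; the README defines `seed` outside the shown hunks
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]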
examples/community/sd_text2img_k_diffusion.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import importlib
+import warnings
 from typing import Callable, List, Optional, Union

 import torch
@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
         self.k_diffusion_model = CompVisDenoiser(model)

     def set_sampler(self, scheduler_type: str):
+        warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
+        return self.set_scheduler(scheduler_type)
+
+    def set_scheduler(self, scheduler_type: str):
         library = importlib.import_module("k_diffusion")
         sampling = getattr(library, "sampling")
         self.sampler = getattr(sampling, scheduler_type)
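The old entry point is kept as a thin shim, so existing scripts that call `set_sampler` keep working while emitting a warning. A minimal sketch of that behaviour, assuming a `pipe` loaded as in the README example above:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pipe.set_sampler("sample_heun")  # deprecated spelling; forwards to set_scheduler("sample_heun")

print(caught[0].message)  # "The `set_sampler` method is deprecated, please use `set_scheduler` instead."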
hi (new file, mode 100644; contents not shown)
setup.py

@@ -91,6 +91,7 @@ _deps = [
     "isort>=5.5.4",
     "jax>=0.2.8,!=0.3.2",
     "jaxlib>=0.1.65",
+    "k-diffusion",
     "librosa",
     "modelcards>=0.1.4",
     "numpy",
@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
 extras["test"] = deps_list(
     "datasets",
+    "k-diffusion",
     "librosa",
     "parameterized",
     "pytest",
src/diffusers/__init__.py

@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
 from .utils import (
     is_flax_available,
     is_inflect_available,
+    is_k_diffusion_available,
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
 else:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403

+if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
+    from .pipelines import StableDiffusionKDiffusionPipeline
+else:
+    from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+
 if is_torch_available() and is_transformers_available() and is_onnx_available():
     from .pipelines import (
         OnnxStableDiffusionImg2ImgPipeline,
src/diffusers/dependency_versions_table.py

@@ -15,6 +15,7 @@ deps = {
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.8,!=0.3.2",
     "jaxlib": "jaxlib>=0.1.65",
+    "k-diffusion": "k-diffusion",
     "librosa": "librosa",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
src/diffusers/pipelines/__init__.py

 from ..utils import (
     is_flax_available,
+    is_k_diffusion_available,
     is_librosa_available,
     is_onnx_available,
     is_torch_available,
@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
         StableDiffusionOnnxPipeline,
     )

+if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
+    from .stable_diffusion import StableDiffusionKDiffusionPipeline
+
 if is_transformers_available() and is_flax_available():
     from .stable_diffusion import FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py

@@ -9,6 +9,7 @@ from PIL import Image
 from ...utils import (
     BaseOutput,
     is_flax_available,
+    is_k_diffusion_available,
     is_onnx_available,
     is_torch_available,
     is_transformers_available,
@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
 else:
     from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline

+if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
+    from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+
 if is_transformers_available() and is_onnx_available():
     from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
     from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py (new file, mode 100755)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Callable, List, Optional, Union

import torch

from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

from ... import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging
from . import StableDiffusionPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ModelWrapper:
    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            encoder_hidden_states = kwargs.pop("cond")
        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample


class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    <Tip warning={true}>

        This is an experimental pipeline and is likely to change in the future.

    </Tip>

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        logger.info(
            f"{self.__class__} is an experimental pipeline and is likely to change in the future. We recommend to use"
            " this pipeline for fast experimentation / iteration if needed, but advise to rely on existing pipelines"
            " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for"
            " production settings."
        )

        # get correct sigmas from LMS
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        model = ModelWrapper(unet, scheduler.alphas_cumprod)
        if scheduler.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def set_scheduler(self, scheduler_type: str):
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
        self.sampler = getattr(sampling, scheduler_type)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
            # fix by only offloading self.safety_checker for now
            cpu_offload(self.safety_checker.vision_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

        if not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
        sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(text_embeddings.dtype)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 6. Define model function
        def model_fn(x, t):
            latent_model_input = torch.cat([x] * 2)

            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred

        # 7. Run k-diffusion solver
        latents = self.sampler(model_fn, latents, sigmas)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
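With k-diffusion installed, the new native pipeline is used much like the other Stable Diffusion pipelines; a minimal usage sketch (model id and sampler name taken from the tests below, the prompt and output path are arbitrary):

import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

# Any sampler function from k_diffusion.sampling can be selected by name.
pipe.set_scheduler("sample_euler")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("an astronaut riding a horse on mars", generator=generator, num_inference_steps=20).images[0]
image.save("astronaut.png")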
src/diffusers/utils/__init__.py

@@ -29,6 +29,7 @@ from .import_utils import (
     is_accelerate_available,
     is_flax_available,
     is_inflect_available,
+    is_k_diffusion_available,
     is_librosa_available,
     is_modelcards_available,
     is_onnx_available,
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py (new file, mode 100644)
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])
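When the k_diffusion backend is missing, the top-level `diffusers` import still succeeds and this dummy class raises a descriptive error only when it is actually used. Roughly, as a sketch (assuming torch and transformers are installed but k-diffusion is not):

from diffusers import StableDiffusionKDiffusionPipeline  # resolves to the dummy object above

try:
    StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
except ImportError as err:
    print(err)  # requires_backends points at the missing k_diffusion backend and how to install it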
src/diffusers/utils/import_utils.py

@@ -210,6 +210,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _xformers_available = False

+_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
+try:
+    _k_diffusion_version = importlib_metadata.version("k_diffusion")
+    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
+except importlib_metadata.PackageNotFoundError:
+    _k_diffusion_available = False
+

 def is_torch_available():
     return _torch_available
@@ -263,6 +270,10 @@ def is_accelerate_available():
     return _accelerate_available


+def is_k_diffusion_available():
+    return _k_diffusion_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
 Unidecode`
 """

+# docstyle-ignore
+K_DIFFUSION_IMPORT_ERROR = """
+{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
+install k-diffusion`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
         ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
         ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
+        ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
     ]
 )
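Downstream code can use the new helper the same way as the other `is_*_available` checks, for example to guard optional imports; a small sketch:

from diffusers.utils import is_k_diffusion_available

if is_k_diffusion_available():
    # Safe to import the native pipeline and the k-diffusion sampling functions.
    import k_diffusion.sampling as sampling
    from diffusers import StableDiffusionKDiffusionPipeline
else:
    raise ImportError("k-diffusion is not installed; install it with `pip install k-diffusion`.")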
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py (new file, mode 100644)
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch

from diffusers import StableDiffusionKDiffusionPipeline
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_1(self):
        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        sd_pipe.set_scheduler("sample_euler")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_2(self):
        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        sd_pipe.set_scheduler("sample_euler")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array(
            [0.826810, 0.81958747, 0.8510199, 0.8376758, 0.83958465, 0.8682068, 0.84370345, 0.85251087, 0.85884345]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
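These integration tests are gated by `@slow` and `@require_torch_gpu`, so they only run on a CUDA machine with the slow-test switch enabled. A sketch of a local invocation; the `RUN_SLOW` environment variable is an assumption based on how the diffusers test decorators are usually gated:

import os

# Assumed gating flag for @slow tests; must be set before diffusers is imported by the test module.
os.environ["RUN_SLOW"] = "1"

import pytest

pytest.main(
    [
        "tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py",
        "-k", "StableDiffusionPipelineIntegrationTests",
        "-x",
    ]
)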