Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
a643c630
Unverified
Commit
a643c630
authored
Dec 08, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 08, 2022
Browse files
[K Diffusion] Add k diffusion sampler natively (#1603)
* uP * uP
parent
326de419
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
602 additions
and
2 deletions
+602
-2
examples/community/README.md
examples/community/README.md
+2
-2
examples/community/sd_text2img_k_diffusion.py
examples/community/sd_text2img_k_diffusion.py
+5
-0
hi
hi
+1
-0
setup.py
setup.py
+2
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+6
-0
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+1
-0
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+4
-0
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+4
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
...stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+462
-0
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+1
-0
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
...s/dummy_torch_and_transformers_and_k_diffusion_objects.py
+19
-0
src/diffusers/utils/import_utils.py
src/diffusers/utils/import_utils.py
+18
-0
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py
...nes/stable_diffusion/test_stable_diffusion_k_diffusion.py
+77
-0
No files found.
examples/community/README.md
View file @
a643c630
...
@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
...
@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe
=
pipe
.
to
(
"cuda"
)
pipe
=
pipe
.
to
(
"cuda"
)
prompt
=
"an astronaut riding a horse on mars"
prompt
=
"an astronaut riding a horse on mars"
pipe
.
set_s
amp
ler
(
"sample_heun"
)
pipe
.
set_s
chedu
ler
(
"sample_heun"
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
20
).
images
[
0
]
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
20
).
images
[
0
]
...
@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
...
@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe
.
scheduler
=
EulerDiscreteScheduler
.
from_config
(
pipe
.
scheduler
.
config
)
pipe
.
scheduler
=
EulerDiscreteScheduler
.
from_config
(
pipe
.
scheduler
.
config
)
pipe
=
pipe
.
to
(
"cuda"
)
pipe
=
pipe
.
to
(
"cuda"
)
pipe
.
set_s
amp
ler
(
"sample_euler"
)
pipe
.
set_s
chedu
ler
(
"sample_euler"
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
generator
=
torch
.
Generator
(
device
=
"cuda"
).
manual_seed
(
seed
)
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
50
).
images
[
0
]
image
=
pipe
(
prompt
,
generator
=
generator
,
num_inference_steps
=
50
).
images
[
0
]
```
```
...
...
examples/community/sd_text2img_k_diffusion.py
View file @
a643c630
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
importlib
import
importlib
import
warnings
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
import
torch
...
@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
...
@@ -111,6 +112,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
self
.
k_diffusion_model
=
CompVisDenoiser
(
model
)
self
.
k_diffusion_model
=
CompVisDenoiser
(
model
)
def
set_sampler
(
self
,
scheduler_type
:
str
):
def
set_sampler
(
self
,
scheduler_type
:
str
):
warnings
.
warn
(
"The `set_sampler` method is deprecated, please use `set_scheduler` instead."
)
return
self
.
set_scheduler
(
scheduler_type
)
def
set_scheduler
(
self
,
scheduler_type
:
str
):
library
=
importlib
.
import_module
(
"k_diffusion"
)
library
=
importlib
.
import_module
(
"k_diffusion"
)
sampling
=
getattr
(
library
,
"sampling"
)
sampling
=
getattr
(
library
,
"sampling"
)
self
.
sampler
=
getattr
(
sampling
,
scheduler_type
)
self
.
sampler
=
getattr
(
sampling
,
scheduler_type
)
...
...
hi
0 → 100644
View file @
a643c630
setup.py
View file @
a643c630
...
@@ -91,6 +91,7 @@ _deps = [
...
@@ -91,6 +91,7 @@ _deps = [
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2"
,
"jax>=0.2.8,!=0.3.2"
,
"jaxlib>=0.1.65"
,
"jaxlib>=0.1.65"
,
"k-diffusion"
,
"librosa"
,
"librosa"
,
"modelcards>=0.1.4"
,
"modelcards>=0.1.4"
,
"numpy"
,
"numpy"
,
...
@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
...
@@ -182,6 +183,7 @@ extras["docs"] = deps_list("hf-doc-builder")
extras
[
"training"
]
=
deps_list
(
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
)
extras
[
"training"
]
=
deps_list
(
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
)
extras
[
"test"
]
=
deps_list
(
extras
[
"test"
]
=
deps_list
(
"datasets"
,
"datasets"
,
"k-diffusion"
,
"librosa"
,
"librosa"
,
"parameterized"
,
"parameterized"
,
"pytest"
,
"pytest"
,
...
...
src/diffusers/__init__.py
View file @
a643c630
...
@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
...
@@ -5,6 +5,7 @@ from .onnx_utils import OnnxRuntimeModel
from
.utils
import
(
from
.utils
import
(
is_flax_available
,
is_flax_available
,
is_inflect_available
,
is_inflect_available
,
is_k_diffusion_available
,
is_onnx_available
,
is_onnx_available
,
is_scipy_available
,
is_scipy_available
,
is_torch_available
,
is_torch_available
,
...
@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
...
@@ -90,6 +91,11 @@ if is_torch_available() and is_transformers_available():
else
:
else
:
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_k_diffusion_available
():
from
.pipelines
import
StableDiffusionKDiffusionPipeline
else
:
from
.utils.dummy_torch_and_transformers_and_k_diffusion_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
(
from
.pipelines
import
(
OnnxStableDiffusionImg2ImgPipeline
,
OnnxStableDiffusionImg2ImgPipeline
,
...
...
src/diffusers/dependency_versions_table.py
View file @
a643c630
...
@@ -15,6 +15,7 @@ deps = {
...
@@ -15,6 +15,7 @@ deps = {
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2"
,
"jax"
:
"jax>=0.2.8,!=0.3.2"
,
"jaxlib"
:
"jaxlib>=0.1.65"
,
"jaxlib"
:
"jaxlib>=0.1.65"
,
"k-diffusion"
:
"k-diffusion"
,
"librosa"
:
"librosa"
,
"librosa"
:
"librosa"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"numpy"
:
"numpy"
,
"numpy"
:
"numpy"
,
...
...
src/diffusers/pipelines/__init__.py
View file @
a643c630
from
..utils
import
(
from
..utils
import
(
is_flax_available
,
is_flax_available
,
is_k_diffusion_available
,
is_librosa_available
,
is_librosa_available
,
is_onnx_available
,
is_onnx_available
,
is_torch_available
,
is_torch_available
,
...
@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
...
@@ -56,5 +57,8 @@ if is_transformers_available() and is_onnx_available():
StableDiffusionOnnxPipeline
,
StableDiffusionOnnxPipeline
,
)
)
if
is_torch_available
()
and
is_transformers_available
()
and
is_k_diffusion_available
():
from
.stable_diffusion
import
StableDiffusionKDiffusionPipeline
if
is_transformers_available
()
and
is_flax_available
():
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
a643c630
...
@@ -9,6 +9,7 @@ from PIL import Image
...
@@ -9,6 +9,7 @@ from PIL import Image
from
...utils
import
(
from
...utils
import
(
BaseOutput
,
BaseOutput
,
is_flax_available
,
is_flax_available
,
is_k_diffusion_available
,
is_onnx_available
,
is_onnx_available
,
is_torch_available
,
is_torch_available
,
is_transformers_available
,
is_transformers_available
,
...
@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
...
@@ -48,6 +49,9 @@ if is_transformers_available() and is_torch_available() and is_transformers_vers
else
:
else
:
from
...utils.dummy_torch_and_transformers_objects
import
StableDiffusionImageVariationPipeline
from
...utils.dummy_torch_and_transformers_objects
import
StableDiffusionImageVariationPipeline
if
is_transformers_available
()
and
is_torch_available
()
and
is_k_diffusion_available
():
from
.pipeline_stable_diffusion_k_diffusion
import
StableDiffusionKDiffusionPipeline
if
is_transformers_available
()
and
is_onnx_available
():
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_onnx_stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
from
.pipeline_onnx_stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
from
.pipeline_onnx_stable_diffusion_img2img
import
OnnxStableDiffusionImg2ImgPipeline
from
.pipeline_onnx_stable_diffusion_img2img
import
OnnxStableDiffusionImg2ImgPipeline
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
0 → 100755
View file @
a643c630
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
from
k_diffusion.external
import
CompVisDenoiser
,
CompVisVDenoiser
from
...
import
DiffusionPipeline
from
...schedulers
import
LMSDiscreteScheduler
from
...utils
import
is_accelerate_available
,
logging
from
.
import
StableDiffusionPipelineOutput
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
class
ModelWrapper
:
def
__init__
(
self
,
model
,
alphas_cumprod
):
self
.
model
=
model
self
.
alphas_cumprod
=
alphas_cumprod
def
apply_model
(
self
,
*
args
,
**
kwargs
):
if
len
(
args
)
==
3
:
encoder_hidden_states
=
args
[
-
1
]
args
=
args
[:
2
]
if
kwargs
.
get
(
"cond"
,
None
)
is
not
None
:
encoder_hidden_states
=
kwargs
.
pop
(
"cond"
)
return
self
.
model
(
*
args
,
encoder_hidden_states
=
encoder_hidden_states
,
**
kwargs
).
sample
class
StableDiffusionKDiffusionPipeline
(
DiffusionPipeline
):
r
"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
<Tip warning={true}>
This is an experimental pipeline and is likely to change in the future.
</Tip>
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components
=
[
"safety_checker"
,
"feature_extractor"
]
def
__init__
(
self
,
vae
,
text_encoder
,
tokenizer
,
unet
,
scheduler
,
safety_checker
,
feature_extractor
,
requires_safety_checker
:
bool
=
True
,
):
super
().
__init__
()
logger
.
info
(
f
"
{
self
.
__class__
}
is an experimntal pipeline and is likely to change in the future. We recommend to use"
" this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines"
" as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for"
" production settings."
)
# get correct sigmas from LMS
scheduler
=
LMSDiscreteScheduler
.
from_config
(
scheduler
.
config
)
self
.
register_modules
(
vae
=
vae
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
)
self
.
register_to_config
(
requires_safety_checker
=
requires_safety_checker
)
self
.
vae_scale_factor
=
2
**
(
len
(
self
.
vae
.
config
.
block_out_channels
)
-
1
)
model
=
ModelWrapper
(
unet
,
scheduler
.
alphas_cumprod
)
if
scheduler
.
prediction_type
==
"v_prediction"
:
self
.
k_diffusion_model
=
CompVisVDenoiser
(
model
)
else
:
self
.
k_diffusion_model
=
CompVisDenoiser
(
model
)
def
set_scheduler
(
self
,
scheduler_type
:
str
):
library
=
importlib
.
import_module
(
"k_diffusion"
)
sampling
=
getattr
(
library
,
"sampling"
)
self
.
sampler
=
getattr
(
sampling
,
scheduler_type
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def
enable_sequential_cpu_offload
(
self
,
gpu_id
=
0
):
r
"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if
is_accelerate_available
():
from
accelerate
import
cpu_offload
else
:
raise
ImportError
(
"Please install accelerate via `pip install accelerate`"
)
device
=
torch
.
device
(
f
"cuda:
{
gpu_id
}
"
)
for
cpu_offloaded_model
in
[
self
.
unet
,
self
.
text_encoder
,
self
.
vae
]:
if
cpu_offloaded_model
is
not
None
:
cpu_offload
(
cpu_offloaded_model
,
device
)
if
self
.
safety_checker
is
not
None
:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload
(
self
.
safety_checker
.
vision_model
,
device
)
@
property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def
_execution_device
(
self
):
r
"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if
self
.
device
!=
torch
.
device
(
"meta"
)
or
not
hasattr
(
self
.
unet
,
"_hf_hook"
):
return
self
.
device
for
module
in
self
.
unet
.
modules
():
if
(
hasattr
(
module
,
"_hf_hook"
)
and
hasattr
(
module
.
_hf_hook
,
"execution_device"
)
and
module
.
_hf_hook
.
execution_device
is
not
None
):
return
torch
.
device
(
module
.
_hf_hook
.
execution_device
)
return
self
.
device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def
_encode_prompt
(
self
,
prompt
,
device
,
num_images_per_prompt
,
do_classifier_free_guidance
,
negative_prompt
):
r
"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size
=
len
(
prompt
)
if
isinstance
(
prompt
,
list
)
else
1
text_inputs
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
self
.
tokenizer
.
model_max_length
,
truncation
=
True
,
return_tensors
=
"pt"
,
)
text_input_ids
=
text_inputs
.
input_ids
untruncated_ids
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
return_tensors
=
"pt"
).
input_ids
if
not
torch
.
equal
(
text_input_ids
,
untruncated_ids
):
removed_text
=
self
.
tokenizer
.
batch_decode
(
untruncated_ids
[:,
self
.
tokenizer
.
model_max_length
-
1
:
-
1
])
logger
.
warning
(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f
"
{
self
.
tokenizer
.
model_max_length
}
tokens:
{
removed_text
}
"
)
if
hasattr
(
self
.
text_encoder
.
config
,
"use_attention_mask"
)
and
self
.
text_encoder
.
config
.
use_attention_mask
:
attention_mask
=
text_inputs
.
attention_mask
.
to
(
device
)
else
:
attention_mask
=
None
text_embeddings
=
self
.
text_encoder
(
text_input_ids
.
to
(
device
),
attention_mask
=
attention_mask
,
)
text_embeddings
=
text_embeddings
[
0
]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed
,
seq_len
,
_
=
text_embeddings
.
shape
text_embeddings
=
text_embeddings
.
repeat
(
1
,
num_images_per_prompt
,
1
)
text_embeddings
=
text_embeddings
.
view
(
bs_embed
*
num_images_per_prompt
,
seq_len
,
-
1
)
# get unconditional embeddings for classifier free guidance
if
do_classifier_free_guidance
:
uncond_tokens
:
List
[
str
]
if
negative_prompt
is
None
:
uncond_tokens
=
[
""
]
*
batch_size
elif
type
(
prompt
)
is
not
type
(
negative_prompt
):
raise
TypeError
(
f
"`negative_prompt` should be the same type to `prompt`, but got
{
type
(
negative_prompt
)
}
!="
f
"
{
type
(
prompt
)
}
."
)
elif
isinstance
(
negative_prompt
,
str
):
uncond_tokens
=
[
negative_prompt
]
elif
batch_size
!=
len
(
negative_prompt
):
raise
ValueError
(
f
"`negative_prompt`:
{
negative_prompt
}
has batch size
{
len
(
negative_prompt
)
}
, but `prompt`:"
f
"
{
prompt
}
has batch size
{
batch_size
}
. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else
:
uncond_tokens
=
negative_prompt
max_length
=
text_input_ids
.
shape
[
-
1
]
uncond_input
=
self
.
tokenizer
(
uncond_tokens
,
padding
=
"max_length"
,
max_length
=
max_length
,
truncation
=
True
,
return_tensors
=
"pt"
,
)
if
hasattr
(
self
.
text_encoder
.
config
,
"use_attention_mask"
)
and
self
.
text_encoder
.
config
.
use_attention_mask
:
attention_mask
=
uncond_input
.
attention_mask
.
to
(
device
)
else
:
attention_mask
=
None
uncond_embeddings
=
self
.
text_encoder
(
uncond_input
.
input_ids
.
to
(
device
),
attention_mask
=
attention_mask
,
)
uncond_embeddings
=
uncond_embeddings
[
0
]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len
=
uncond_embeddings
.
shape
[
1
]
uncond_embeddings
=
uncond_embeddings
.
repeat
(
1
,
num_images_per_prompt
,
1
)
uncond_embeddings
=
uncond_embeddings
.
view
(
batch_size
*
num_images_per_prompt
,
seq_len
,
-
1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings
=
torch
.
cat
([
uncond_embeddings
,
text_embeddings
])
return
text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def
run_safety_checker
(
self
,
image
,
device
,
dtype
):
if
self
.
safety_checker
is
not
None
:
safety_checker_input
=
self
.
feature_extractor
(
self
.
numpy_to_pil
(
image
),
return_tensors
=
"pt"
).
to
(
device
)
image
,
has_nsfw_concept
=
self
.
safety_checker
(
images
=
image
,
clip_input
=
safety_checker_input
.
pixel_values
.
to
(
dtype
)
)
else
:
has_nsfw_concept
=
None
return
image
,
has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def
decode_latents
(
self
,
latents
):
latents
=
1
/
0.18215
*
latents
image
=
self
.
vae
.
decode
(
latents
).
sample
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
float
().
numpy
()
return
image
def
check_inputs
(
self
,
prompt
,
height
,
width
,
callback_steps
):
if
not
isinstance
(
prompt
,
str
)
and
not
isinstance
(
prompt
,
list
):
raise
ValueError
(
f
"`prompt` has to be of type `str` or `list` but is
{
type
(
prompt
)
}
"
)
if
height
%
8
!=
0
or
width
%
8
!=
0
:
raise
ValueError
(
f
"`height` and `width` have to be divisible by 8 but are
{
height
}
and
{
width
}
."
)
if
(
callback_steps
is
None
)
or
(
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
)
):
raise
ValueError
(
f
"`callback_steps` has to be a positive integer but is
{
callback_steps
}
of type"
f
"
{
type
(
callback_steps
)
}
."
)
def
prepare_latents
(
self
,
batch_size
,
num_channels_latents
,
height
,
width
,
dtype
,
device
,
generator
,
latents
=
None
):
shape
=
(
batch_size
,
num_channels_latents
,
height
//
self
.
vae_scale_factor
,
width
//
self
.
vae_scale_factor
)
if
latents
is
None
:
if
device
.
type
==
"mps"
:
# randn does not work reproducibly on mps
latents
=
torch
.
randn
(
shape
,
generator
=
generator
,
device
=
"cpu"
,
dtype
=
dtype
).
to
(
device
)
else
:
latents
=
torch
.
randn
(
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
dtype
)
else
:
if
latents
.
shape
!=
shape
:
raise
ValueError
(
f
"Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
shape
}
"
)
latents
=
latents
.
to
(
device
)
# scale the initial noise by the standard deviation required by the scheduler
return
latents
@
torch
.
no_grad
()
def
__call__
(
self
,
prompt
:
Union
[
str
,
List
[
str
]],
height
:
int
=
512
,
width
:
int
=
512
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
float
=
7.5
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
num_images_per_prompt
:
Optional
[
int
]
=
1
,
eta
:
float
=
0.0
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
latents
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
callback
:
Optional
[
Callable
[[
int
,
int
,
torch
.
FloatTensor
],
None
]]
=
None
,
callback_steps
:
Optional
[
int
]
=
1
,
):
r
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs. Raise error if not correct
self
.
check_inputs
(
prompt
,
height
,
width
,
callback_steps
)
# 2. Define call parameters
batch_size
=
1
if
isinstance
(
prompt
,
str
)
else
len
(
prompt
)
device
=
self
.
_execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance
=
True
if
guidance_scale
<=
1.0
:
raise
ValueError
(
"has to use guidance_scale"
)
# 3. Encode input prompt
text_embeddings
=
self
.
_encode_prompt
(
prompt
,
device
,
num_images_per_prompt
,
do_classifier_free_guidance
,
negative_prompt
)
# 4. Prepare timesteps
self
.
scheduler
.
set_timesteps
(
num_inference_steps
,
device
=
text_embeddings
.
device
)
sigmas
=
self
.
scheduler
.
sigmas
sigmas
=
sigmas
.
to
(
text_embeddings
.
dtype
)
# 5. Prepare latent variables
num_channels_latents
=
self
.
unet
.
in_channels
latents
=
self
.
prepare_latents
(
batch_size
*
num_images_per_prompt
,
num_channels_latents
,
height
,
width
,
text_embeddings
.
dtype
,
device
,
generator
,
latents
,
)
latents
=
latents
*
sigmas
[
0
]
self
.
k_diffusion_model
.
sigmas
=
self
.
k_diffusion_model
.
sigmas
.
to
(
latents
.
device
)
self
.
k_diffusion_model
.
log_sigmas
=
self
.
k_diffusion_model
.
log_sigmas
.
to
(
latents
.
device
)
# 6. Define model function
def
model_fn
(
x
,
t
):
latent_model_input
=
torch
.
cat
([
x
]
*
2
)
noise_pred
=
self
.
k_diffusion_model
(
latent_model_input
,
t
,
cond
=
text_embeddings
)
noise_pred_uncond
,
noise_pred_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_pred_text
-
noise_pred_uncond
)
return
noise_pred
# 7. Run k-diffusion solver
latents
=
self
.
sampler
(
model_fn
,
latents
,
sigmas
)
# 8. Post-processing
image
=
self
.
decode_latents
(
latents
)
# 9. Run safety checker
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
text_embeddings
.
dtype
)
# 10. Convert to PIL
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
if
not
return_dict
:
return
(
image
,
has_nsfw_concept
)
return
StableDiffusionPipelineOutput
(
images
=
image
,
nsfw_content_detected
=
has_nsfw_concept
)
src/diffusers/utils/__init__.py
View file @
a643c630
...
@@ -29,6 +29,7 @@ from .import_utils import (
...
@@ -29,6 +29,7 @@ from .import_utils import (
is_accelerate_available
,
is_accelerate_available
,
is_flax_available
,
is_flax_available
,
is_inflect_available
,
is_inflect_available
,
is_k_diffusion_available
,
is_librosa_available
,
is_librosa_available
,
is_modelcards_available
,
is_modelcards_available
,
is_onnx_available
,
is_onnx_available
,
...
...
src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
0 → 100644
View file @
a643c630
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
StableDiffusionKDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"k_diffusion"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"k_diffusion"
])
src/diffusers/utils/import_utils.py
View file @
a643c630
...
@@ -210,6 +210,13 @@ try:
...
@@ -210,6 +210,13 @@ try:
except
importlib_metadata
.
PackageNotFoundError
:
except
importlib_metadata
.
PackageNotFoundError
:
_xformers_available
=
False
_xformers_available
=
False
_k_diffusion_available
=
importlib
.
util
.
find_spec
(
"k_diffusion"
)
is
not
None
try
:
_k_diffusion_version
=
importlib_metadata
.
version
(
"k_diffusion"
)
logger
.
debug
(
f
"Successfully imported k-diffusion version
{
_k_diffusion_version
}
"
)
except
importlib_metadata
.
PackageNotFoundError
:
_k_diffusion_available
=
False
def
is_torch_available
():
def
is_torch_available
():
return
_torch_available
return
_torch_available
...
@@ -263,6 +270,10 @@ def is_accelerate_available():
...
@@ -263,6 +270,10 @@ def is_accelerate_available():
return
_accelerate_available
return
_accelerate_available
def
is_k_diffusion_available
():
return
_k_diffusion_available
# docstyle-ignore
# docstyle-ignore
FLAX_IMPORT_ERROR
=
"""
FLAX_IMPORT_ERROR
=
"""
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
...
@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
...
@@ -317,6 +328,12 @@ UNIDECODE_IMPORT_ERROR = """
Unidecode`
Unidecode`
"""
"""
# docstyle-ignore
K_DIFFUSION_IMPORT_ERROR
=
"""
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""
BACKENDS_MAPPING
=
OrderedDict
(
BACKENDS_MAPPING
=
OrderedDict
(
[
[
...
@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
...
@@ -329,6 +346,7 @@ BACKENDS_MAPPING = OrderedDict(
(
"transformers"
,
(
is_transformers_available
,
TRANSFORMERS_IMPORT_ERROR
)),
(
"transformers"
,
(
is_transformers_available
,
TRANSFORMERS_IMPORT_ERROR
)),
(
"unidecode"
,
(
is_unidecode_available
,
UNIDECODE_IMPORT_ERROR
)),
(
"unidecode"
,
(
is_unidecode_available
,
UNIDECODE_IMPORT_ERROR
)),
(
"librosa"
,
(
is_librosa_available
,
LIBROSA_IMPORT_ERROR
)),
(
"librosa"
,
(
is_librosa_available
,
LIBROSA_IMPORT_ERROR
)),
(
"k_diffusion"
,
(
is_k_diffusion_available
,
K_DIFFUSION_IMPORT_ERROR
)),
]
]
)
)
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py
0 → 100644
View file @
a643c630
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
gc
import
unittest
import
numpy
as
np
import
torch
from
diffusers
import
StableDiffusionKDiffusionPipeline
from
diffusers.utils
import
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
@
slow
@
require_torch_gpu
class
StableDiffusionPipelineIntegrationTests
(
unittest
.
TestCase
):
def
tearDown
(
self
):
# clean up the VRAM after each test
super
().
tearDown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
test_stable_diffusion_1
(
self
):
sd_pipe
=
StableDiffusionKDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
sd_pipe
.
set_scheduler
(
"sample_euler"
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
output
=
sd_pipe
([
prompt
],
generator
=
generator
,
guidance_scale
=
9.0
,
num_inference_steps
=
20
,
output_type
=
"np"
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
([
0.8887
,
0.915
,
0.91
,
0.894
,
0.909
,
0.912
,
0.919
,
0.925
,
0.883
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_2
(
self
):
sd_pipe
=
StableDiffusionKDiffusionPipeline
.
from_pretrained
(
"stabilityai/stable-diffusion-2-1-base"
)
sd_pipe
=
sd_pipe
.
to
(
torch_device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
sd_pipe
.
set_scheduler
(
"sample_euler"
)
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
output
=
sd_pipe
([
prompt
],
generator
=
generator
,
guidance_scale
=
9.0
,
num_inference_steps
=
20
,
output_type
=
"np"
)
image
=
output
.
images
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
(
[
0.826810
,
0.81958747
,
0.8510199
,
0.8376758
,
0.83958465
,
0.8682068
,
0.84370345
,
0.85251087
,
0.85884345
]
)
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment