Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e8114bd0
Unverified
Commit
e8114bd0
authored
Jan 16, 2025
by
Daniel Regado
Committed by
GitHub
Jan 16, 2025
Browse files
IP-Adapter for `StableDiffusion3Img2ImgPipeline` (#10589)
Added support for IP-Adapter
parent
b0c89738
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
7 deletions
+116
-7
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
...stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+114
-7
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
...e_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+2
-0
No files found.
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
View file @
e8114bd0
...
...
@@ -18,14 +18,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
import
PIL.Image
import
torch
from
transformers
import
(
BaseImageProcessor
,
CLIPTextModelWithProjection
,
CLIPTokenizer
,
PreTrainedModel
,
T5EncoderModel
,
T5TokenizerFast
,
)
from
...image_processor
import
PipelineImageInput
,
VaeImageProcessor
from
...loaders
import
FromSingleFileMixin
,
SD3LoraLoaderMixin
from
...loaders
import
FromSingleFileMixin
,
SD3IPAdapterMixin
,
SD3LoraLoaderMixin
from
...models.autoencoders
import
AutoencoderKL
from
...models.transformers
import
SD3Transformer2DModel
from
...schedulers
import
FlowMatchEulerDiscreteScheduler
...
...
@@ -163,7 +165,7 @@ def retrieve_timesteps(
return
timesteps
,
num_inference_steps
class
StableDiffusion3Img2ImgPipeline
(
DiffusionPipeline
,
SD3LoraLoaderMixin
,
FromSingleFileMixin
):
class
StableDiffusion3Img2ImgPipeline
(
DiffusionPipeline
,
SD3LoraLoaderMixin
,
FromSingleFileMixin
,
SD3IPAdapterMixin
):
r
"""
Args:
transformer ([`SD3Transformer2DModel`]):
...
...
@@ -197,8 +199,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq
=
"text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
_optional_components
=
[]
model_cpu_offload_seq
=
"text_encoder->text_encoder_2->text_encoder_3->
image_encoder->
transformer->vae"
_optional_components
=
[
"image_encoder"
,
"feature_extractor"
]
_callback_tensor_inputs
=
[
"latents"
,
"prompt_embeds"
,
"negative_prompt_embeds"
,
"negative_pooled_prompt_embeds"
]
def
__init__
(
...
...
@@ -212,6 +214,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2
:
CLIPTokenizer
,
text_encoder_3
:
T5EncoderModel
,
tokenizer_3
:
T5TokenizerFast
,
image_encoder
:
PreTrainedModel
=
None
,
feature_extractor
:
BaseImageProcessor
=
None
,
):
super
().
__init__
()
...
...
@@ -225,6 +229,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3
=
tokenizer_3
,
transformer
=
transformer
,
scheduler
=
scheduler
,
image_encoder
=
image_encoder
,
feature_extractor
=
feature_extractor
,
)
self
.
vae_scale_factor
=
2
**
(
len
(
self
.
vae
.
config
.
block_out_channels
)
-
1
)
if
getattr
(
self
,
"vae"
,
None
)
else
8
latent_channels
=
self
.
vae
.
config
.
latent_channels
if
getattr
(
self
,
"vae"
,
None
)
else
16
...
...
@@ -738,6 +744,84 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
def
interrupt
(
self
):
return
self
.
_interrupt
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
    """Encode the given image into a feature representation with the pipeline's image encoder.

    Args:
        image (`PipelineImageInput`):
            Input image to be encoded. Raw (non-tensor) inputs are first run through
            `self.feature_extractor` to obtain pixel values.
        device: (`torch.device`):
            Torch device to place the pixel values on.

    Returns:
        `torch.Tensor`: The encoded image feature representation (penultimate hidden state
        of the image encoder, the layer conventionally used by IP-Adapters).
    """
    # Convert PIL/array inputs into a pixel-value tensor; tensors pass through untouched.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    pixel_values = image.to(device=device, dtype=self.dtype)
    encoder_output = self.image_encoder(pixel_values, output_hidden_states=True)
    # The second-to-last hidden state is the standard IP-Adapter feature layer.
    return encoder_output.hidden_states[-2]
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
    """Prepares image embeddings for use in the IP-Adapter.

    Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.

    Args:
        ip_adapter_image (`PipelineImageInput`, *optional*):
            The input image to extract features from for IP-Adapter.
        ip_adapter_image_embeds (`torch.Tensor`, *optional*):
            Precomputed image embeddings. When `do_classifier_free_guidance` is `True`, the
            first half along dim 0 must be the negative embeddings and the second half the
            positive ones (they are split with `.chunk(2)`).
        device: (`torch.device`, *optional*):
            Torch device. Defaults to the pipeline's execution device.
        num_images_per_prompt (`int`, defaults to 1):
            Number of images that should be generated per prompt.
        do_classifier_free_guidance (`bool`, defaults to True):
            Whether to use classifier free guidance or not.

    Returns:
        `torch.Tensor`: Embeddings repeated `num_images_per_prompt` times, with negative
        embeddings concatenated in front when classifier-free guidance is enabled.

    Raises:
        ValueError: If neither `ip_adapter_image` nor `ip_adapter_image_embeds` is provided.
    """
    device = device or self._execution_device

    if ip_adapter_image_embeds is not None:
        if do_classifier_free_guidance:
            single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
        else:
            single_image_embeds = ip_adapter_image_embeds
    elif ip_adapter_image is not None:
        single_image_embeds = self.encode_image(ip_adapter_image, device)
        if do_classifier_free_guidance:
            # Unconditional branch uses all-zero image features.
            single_negative_image_embeds = torch.zeros_like(single_image_embeds)
    else:
        # Fixed: the original message named `ip_adapter_image_embeds` twice and used
        # "Neither ... or"; it now correctly names both arguments.
        raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")

    image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)

    if do_classifier_free_guidance:
        negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

    return image_embeds.to(device=device)
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, *args, **kwargs):
    """Enable sequential CPU offloading, warning about a known `image_encoder` limitation.

    Delegates to the parent implementation, but first warns when an `image_encoder` is
    loaded and has not been excluded from offloading: sequential offload can fail for
    encoders built on `torch.nn.MultiheadAttention`. Callers can opt out by appending
    `"image_encoder"` to `self._exclude_from_cpu_offload` before calling this method.
    """
    if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
        # Warn (rather than raise) so offloading still proceeds for the other components.
        logger.warning(
            "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
            "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
            "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
        )
    super().enable_sequential_cpu_offload(*args, **kwargs)
@
torch
.
no_grad
()
@
replace_example_docstring
(
EXAMPLE_DOC_STRING
)
def
__call__
(
...
...
@@ -763,6 +847,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
pooled_prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
negative_pooled_prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
ip_adapter_image
:
Optional
[
PipelineImageInput
]
=
None
,
ip_adapter_image_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
return_dict
:
bool
=
True
,
joint_attention_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
clip_skip
:
Optional
[
int
]
=
None
,
...
...
@@ -784,9 +870,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
will be used instead
height (`int`, *optional*, defaults to self.
unet
.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to self.
transformer
.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.
unet
.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to self.
transformer
.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...
...
@@ -834,6 +920,12 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...
...
@@ -969,7 +1061,22 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
generator
,
)
# 6. Denoising loop
# 6. Prepare image embeddings
if
(
ip_adapter_image
is
not
None
and
self
.
is_ip_adapter_active
)
or
ip_adapter_image_embeds
is
not
None
:
ip_adapter_image_embeds
=
self
.
prepare_ip_adapter_image_embeds
(
ip_adapter_image
,
ip_adapter_image_embeds
,
device
,
batch_size
*
num_images_per_prompt
,
self
.
do_classifier_free_guidance
,
)
if
self
.
joint_attention_kwargs
is
None
:
self
.
_joint_attention_kwargs
=
{
"ip_adapter_image_embeds"
:
ip_adapter_image_embeds
}
else
:
self
.
_joint_attention_kwargs
.
update
(
ip_adapter_image_embeds
=
ip_adapter_image_embeds
)
# 7. Denoising loop
num_warmup_steps
=
max
(
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
,
0
)
self
.
_num_timesteps
=
len
(
timesteps
)
with
self
.
progress_bar
(
total
=
num_inference_steps
)
as
progress_bar
:
...
...
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
View file @
e8114bd0
...
...
@@ -105,6 +105,8 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
"tokenizer_3"
:
tokenizer_3
,
"transformer"
:
transformer
,
"vae"
:
vae
,
"image_encoder"
:
None
,
"feature_extractor"
:
None
,
}
def
get_dummy_inputs
(
self
,
device
,
seed
=
0
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment