Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
cd6ca9df
Unverified
Commit
cd6ca9df
authored
Nov 21, 2024
by
Aryan
Committed by
GitHub
Nov 21, 2024
Browse files
Fix prepare latent image ids and vae sample generators for flux (#9981)
* fix * update expected slice
parent
e564abe2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
9 deletions
+23
-9
src/diffusers/pipelines/flux/pipeline_flux.py
src/diffusers/pipelines/flux/pipeline_flux.py
+1
-1
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+17
-3
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
...pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+2
-2
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
...ers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+2
-2
tests/pipelines/controlnet_flux/test_controlnet_flux.py
tests/pipelines/controlnet_flux/test_controlnet_flux.py
+1
-1
No files found.
src/diffusers/pipelines/flux/pipeline_flux.py
View file @
cd6ca9df
...
...
@@ -513,7 +513,7 @@ class FluxPipeline(
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
if
latents
is
not
None
:
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
,
width
,
device
,
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
//
2
,
width
//
2
,
device
,
dtype
)
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
View file @
cd6ca9df
...
...
@@ -97,6 +97,20 @@ def calculate_shift(
return
mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def
retrieve_latents
(
encoder_output
:
torch
.
Tensor
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
sample_mode
:
str
=
"sample"
):
if
hasattr
(
encoder_output
,
"latent_dist"
)
and
sample_mode
==
"sample"
:
return
encoder_output
.
latent_dist
.
sample
(
generator
)
elif
hasattr
(
encoder_output
,
"latent_dist"
)
and
sample_mode
==
"argmax"
:
return
encoder_output
.
latent_dist
.
mode
()
elif
hasattr
(
encoder_output
,
"latents"
):
return
encoder_output
.
latents
else
:
raise
AttributeError
(
"Could not access latents of provided encoder_output"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def
retrieve_timesteps
(
scheduler
,
...
...
@@ -512,7 +526,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
if
latents
is
not
None
:
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
,
width
,
device
,
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
//
2
,
width
//
2
,
device
,
dtype
)
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
...
...
@@ -772,7 +786,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
if
self
.
controlnet
.
input_hint_block
is
None
:
# vae encode
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
...
...
@@ -810,7 +824,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
# vae encode
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
View file @
cd6ca9df
...
...
@@ -801,7 +801,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height
,
width
=
control_image
.
shape
[
-
2
:]
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
height_control_image
,
width_control_image
=
control_image
.
shape
[
2
:]
...
...
@@ -832,7 +832,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height
,
width
=
control_image_
.
shape
[
-
2
:]
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
height_control_image
,
width_control_image
=
control_image_
.
shape
[
2
:]
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
View file @
cd6ca9df
...
...
@@ -942,7 +942,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
if
self
.
controlnet
.
input_hint_block
is
None
:
# vae encode
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
...
...
@@ -979,7 +979,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
# vae encode
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
...
...
tests/pipelines/controlnet_flux/test_controlnet_flux.py
View file @
cd6ca9df
...
...
@@ -170,7 +170,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
np
.
array
(
[
0.7348633
,
0.41333008
,
0.6621094
,
0.5444336
,
0.47607422
,
0.5859375
,
0.44677734
,
0.4506836
,
0.40454102
]
[
0.47387695
,
0.63134766
,
0.5605469
,
0.61621094
,
0.7207031
,
0.7089844
,
0.70410156
,
0.6113281
,
0.64160156
]
)
assert
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment