Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
cd6ca9df
Unverified
Commit
cd6ca9df
authored
Nov 21, 2024
by
Aryan
Committed by
GitHub
Nov 21, 2024
Browse files
Fix prepare latent image ids and vae sample generators for flux (#9981)
* fix * update expected slice
parent
e564abe2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
9 deletions
+23
-9
src/diffusers/pipelines/flux/pipeline_flux.py
src/diffusers/pipelines/flux/pipeline_flux.py
+1
-1
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+17
-3
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
...pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+2
-2
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
...ers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+2
-2
tests/pipelines/controlnet_flux/test_controlnet_flux.py
tests/pipelines/controlnet_flux/test_controlnet_flux.py
+1
-1
No files found.
src/diffusers/pipelines/flux/pipeline_flux.py
View file @
cd6ca9df
...
@@ -513,7 +513,7 @@ class FluxPipeline(
...
@@ -513,7 +513,7 @@ class FluxPipeline(
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
if
latents
is
not
None
:
if
latents
is
not
None
:
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
,
width
,
device
,
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
//
2
,
width
//
2
,
device
,
dtype
)
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
View file @
cd6ca9df
...
@@ -97,6 +97,20 @@ def calculate_shift(
...
@@ -97,6 +97,20 @@ def calculate_shift(
return
mu
return
mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def
retrieve_latents
(
encoder_output
:
torch
.
Tensor
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
sample_mode
:
str
=
"sample"
):
if
hasattr
(
encoder_output
,
"latent_dist"
)
and
sample_mode
==
"sample"
:
return
encoder_output
.
latent_dist
.
sample
(
generator
)
elif
hasattr
(
encoder_output
,
"latent_dist"
)
and
sample_mode
==
"argmax"
:
return
encoder_output
.
latent_dist
.
mode
()
elif
hasattr
(
encoder_output
,
"latents"
):
return
encoder_output
.
latents
else
:
raise
AttributeError
(
"Could not access latents of provided encoder_output"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def
retrieve_timesteps
(
def
retrieve_timesteps
(
scheduler
,
scheduler
,
...
@@ -512,7 +526,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
...
@@ -512,7 +526,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
shape
=
(
batch_size
,
num_channels_latents
,
height
,
width
)
if
latents
is
not
None
:
if
latents
is
not
None
:
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
,
width
,
device
,
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
batch_size
,
height
//
2
,
width
//
2
,
device
,
dtype
)
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
return
latents
.
to
(
device
=
device
,
dtype
=
dtype
),
latent_image_ids
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
if
isinstance
(
generator
,
list
)
and
len
(
generator
)
!=
batch_size
:
...
@@ -772,7 +786,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
...
@@ -772,7 +786,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
if
self
.
controlnet
.
input_hint_block
is
None
:
if
self
.
controlnet
.
input_hint_block
is
None
:
# vae encode
# vae encode
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
# pack
...
@@ -810,7 +824,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
...
@@ -810,7 +824,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
# vae encode
# vae encode
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
# pack
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
View file @
cd6ca9df
...
@@ -801,7 +801,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
...
@@ -801,7 +801,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
)
height
,
width
=
control_image
.
shape
[
-
2
:]
height
,
width
=
control_image
.
shape
[
-
2
:]
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
height_control_image
,
width_control_image
=
control_image
.
shape
[
2
:]
height_control_image
,
width_control_image
=
control_image
.
shape
[
2
:]
...
@@ -832,7 +832,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
...
@@ -832,7 +832,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
)
height
,
width
=
control_image_
.
shape
[
-
2
:]
height
,
width
=
control_image_
.
shape
[
-
2
:]
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
height_control_image
,
width_control_image
=
control_image_
.
shape
[
2
:]
height_control_image
,
width_control_image
=
control_image_
.
shape
[
2
:]
...
...
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
View file @
cd6ca9df
...
@@ -942,7 +942,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
...
@@ -942,7 +942,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
controlnet_blocks_repeat
=
False
if
self
.
controlnet
.
input_hint_block
is
None
else
True
if
self
.
controlnet
.
input_hint_block
is
None
:
if
self
.
controlnet
.
input_hint_block
is
None
:
# vae encode
# vae encode
control_image
=
self
.
vae
.
encode
(
control_image
)
.
latent_dist
.
sample
(
)
control_image
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image
)
,
generator
=
generator
)
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image
=
(
control_image
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
# pack
...
@@ -979,7 +979,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
...
@@ -979,7 +979,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
if
self
.
controlnet
.
nets
[
0
].
input_hint_block
is
None
:
# vae encode
# vae encode
control_image_
=
self
.
vae
.
encode
(
control_image_
)
.
latent_dist
.
sample
(
)
control_image_
=
retrieve_latents
(
self
.
vae
.
encode
(
control_image_
)
,
generator
=
generator
)
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
control_image_
=
(
control_image_
-
self
.
vae
.
config
.
shift_factor
)
*
self
.
vae
.
config
.
scaling_factor
# pack
# pack
...
...
tests/pipelines/controlnet_flux/test_controlnet_flux.py
View file @
cd6ca9df
...
@@ -170,7 +170,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
...
@@ -170,7 +170,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
np
.
array
(
expected_slice
=
np
.
array
(
[
0.73
48633
,
0.41333008
,
0.6621094
,
0.
5444336
,
0.47607422
,
0.5859375
,
0.44677734
,
0.4506836
,
0.40454102
]
[
0.
4
73
87695
,
0.63134766
,
0.5605469
,
0.6
1
621094
,
0.
7207031
,
0.7089844
,
0.70410156
,
0.6113281
,
0.64160156
]
)
)
assert
(
assert
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment