renzhc / diffusers_dcu · Commits

Unverified commit 5588725e, authored Nov 07, 2024 by Sayak Paul, committed by GitHub Nov 06, 2024
Parent: ded3db16

[Flux] reduce explicit device transfers and typecasting in flux. (#9817)

reduce explicit device transfers and typecasting in flux.

Showing 6 changed files with 17 additions and 17 deletions (+17 -17)
src/diffusers/pipelines/flux/pipeline_flux.py                              +3 -3
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py                   +2 -2
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py    +3 -3
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py        +3 -3
src/diffusers/pipelines/flux/pipeline_flux_img2img.py                      +3 -3
src/diffusers/pipelines/flux/pipeline_flux_inpaint.py                      +3 -3
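The same two-part pattern is applied in every file below: a tensor that used to be allocated on the default device in float32 and then moved with `.to(device=..., dtype=...)` is now allocated directly with the target `device` and `dtype`, so the intermediate allocation and copy are never made. A minimal before/after sketch (illustrative values, not taken from the diff):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
seq_len = 512  # hypothetical prompt sequence length

# Before: allocated on the default device in float32, then copied and cast by .to(...)
text_ids_old = torch.zeros(seq_len, 3).to(device=device, dtype=dtype)

# After: allocated once, directly with the target device and dtype
text_ids_new = torch.zeros(seq_len, 3, dtype=dtype, device=device)

assert torch.equal(text_ids_old, text_ids_new)  # same values, one fewer allocation and copy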
src/diffusers/pipelines/flux/pipeline_flux.py

@@ -371,7 +371,7 @@ class FluxPipeline(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -427,7 +427,7 @@ class FluxPipeline(
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -437,7 +437,7 @@ class FluxPipeline(
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
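For reference, here is a simplified standalone sketch that mirrors the updated `_prepare_latent_image_ids` hunks above (a hypothetical free function, not the pipeline method itself, exercised on CPU):

import torch

def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    # Post-commit construction: the ids tensor starts out on the requested
    # device/dtype, so no trailing .to(device=..., dtype=...) is needed.
    latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    # Row index goes into channel 1, column index into channel 2 (channel 0 stays zero).
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
    # Flatten the spatial grid into (height * width, 3), as in the final hunk.
    return latent_image_ids.reshape(height * width, 3)

ids = prepare_latent_image_ids(1, 4, 4, torch.device("cpu"), torch.bfloat16)
print(ids.shape, ids.dtype, ids.device)  # torch.Size([16, 3]) torch.bfloat16 cpu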
src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

@@ -452,7 +452,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -462,7 +462,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

@@ -407,7 +407,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -495,7 +495,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -505,7 +505,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

@@ -417,7 +417,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -522,7 +522,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -532,7 +532,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
src/diffusers/pipelines/flux/pipeline_flux_img2img.py

@@ -391,7 +391,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -479,7 +479,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -489,7 +489,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

@@ -395,7 +395,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -500,7 +500,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -510,7 +510,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids

     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
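In each of the encode_prompt hunks above, the dtype passed to the new torch.zeros(...) call comes from the text encoder when it is loaded and otherwise falls back to the transformer, so text_ids always matches the dtype of the module that will consume it. A minimal sketch of that selection (hypothetical helper and stand-in objects, not the pipeline method itself):

import torch
from types import SimpleNamespace

def select_ids_dtype(text_encoder, transformer):
    # Mirrors: dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
    return text_encoder.dtype if text_encoder is not None else transformer.dtype

# Stand-ins exposing only a .dtype attribute (real pipelines use the loaded modules).
text_encoder = None
transformer = SimpleNamespace(dtype=torch.bfloat16)

dtype = select_ids_dtype(text_encoder, transformer)
text_ids = torch.zeros(512, 3, dtype=dtype, device="cpu")  # constructed directly in that dtype
print(text_ids.dtype)  # torch.bfloat16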