Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95c5ce4e
Unverified
Commit
95c5ce4e
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
PyTorch/XLA support (#10498)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c0964571
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
248 additions
and
1 deletion
+248
-1
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
...users/pipelines/controlnet/pipeline_controlnet_img2img.py
+11
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
...users/pipelines/controlnet/pipeline_controlnet_inpaint.py
+11
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
...pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+13
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
...ffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+13
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
...pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+13
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
...nes/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+13
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
...s/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
+14
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
...nes/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+14
-0
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
...ffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+11
-0
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
...s/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+13
-0
src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
...ers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+11
-1
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+12
-0
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+12
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+12
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+12
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
...lines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+13
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
...iffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+12
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
...es/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+13
-0
src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
...ers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+13
-0
src/diffusers/pipelines/dit/pipeline_dit.py
src/diffusers/pipelines/dit/pipeline_dit.py
+12
-0
No files found.
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
View file @
95c5ce4e
...
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -41,6 +42,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1294,6 +1302,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
View file @
95c5ce4e
...
...
@@ -32,6 +32,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -43,6 +44,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1476,6 +1484,9 @@ class StableDiffusionControlNetInpaintPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
View file @
95c5ce4e
...
...
@@ -60,6 +60,16 @@ if is_invisible_watermark_available():
from
diffusers.pipelines.stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1833,6 +1843,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# make sure the VAE is in float32 mode, as it overflows in float16
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
self
.
upcast_vae
()
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
View file @
95c5ce4e
...
...
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1552,6 +1562,9 @@ class StableDiffusionXLControlNetPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
View file @
95c5ce4e
...
...
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1612,6 +1622,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
View file @
95c5ce4e
...
...
@@ -60,6 +60,16 @@ if is_invisible_watermark_available():
from
diffusers.pipelines.stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1759,6 +1769,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# make sure the VAE is in float32 mode, as it overflows in float16
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
self
.
upcast_vae
()
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
View file @
95c5ce4e
...
...
@@ -60,6 +60,17 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutpu
if
is_invisible_watermark_available
():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1458,6 +1469,9 @@ class StableDiffusionXLControlNetUnionPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
View file @
95c5ce4e
...
...
@@ -61,6 +61,17 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutpu
if
is_invisible_watermark_available
():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1577,6 +1588,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
View file @
95c5ce4e
...
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -41,6 +42,13 @@ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -884,6 +892,9 @@ class StableDiffusionControlNetXSPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
View file @
95c5ce4e
...
...
@@ -54,6 +54,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1078,6 +1088,9 @@ class StableDiffusionXLControlNetXSPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# manually for max memory savings
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
self
.
upcast_vae
()
...
...
src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
View file @
95c5ce4e
...
...
@@ -17,11 +17,18 @@ from typing import List, Optional, Tuple, Union
import
torch
from
...utils
import
logging
from
...utils
import
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
AudioPipelineOutput
,
DiffusionPipeline
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -146,6 +153,9 @@ class DanceDiffusionPipeline(DiffusionPipeline):
# 2. compute previous audio sample: x_t -> t_t-1
audio
=
self
.
scheduler
.
step
(
model_output
,
t
,
audio
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
audio
=
audio
.
clamp
(
-
1
,
1
).
float
().
cpu
().
numpy
()
audio
=
audio
[:,
:,
:
original_sample_size
]
...
...
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
95c5ce4e
...
...
@@ -17,10 +17,19 @@ from typing import List, Optional, Tuple, Union
import
torch
from
...schedulers
import
DDIMScheduler
from
...utils
import
is_torch_xla_available
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
class
DDIMPipeline
(
DiffusionPipeline
):
r
"""
Pipeline for image generation.
...
...
@@ -143,6 +152,9 @@ class DDIMPipeline(DiffusionPipeline):
model_output
,
t
,
image
,
eta
=
eta
,
use_clipped_model_output
=
use_clipped_model_output
,
generator
=
generator
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
95c5ce4e
...
...
@@ -17,10 +17,19 @@ from typing import List, Optional, Tuple, Union
import
torch
from
...utils
import
is_torch_xla_available
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
class
DDPMPipeline
(
DiffusionPipeline
):
r
"""
Pipeline for image generation.
...
...
@@ -116,6 +125,9 @@ class DDPMPipeline(DiffusionPipeline):
# 2. compute previous image: x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
,
generator
=
generator
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
View file @
95c5ce4e
...
...
@@ -14,6 +14,7 @@ from ...utils import (
BACKENDS_MAPPING
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -24,8 +25,16 @@ from .safety_checker import IFSafetyChecker
from
.watermark
import
IFWatermarker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -735,6 +744,9 @@ class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
View file @
95c5ce4e
...
...
@@ -17,6 +17,7 @@ from ...utils import (
PIL_INTERPOLATION
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -27,8 +28,16 @@ from .safety_checker import IFSafetyChecker
from
.watermark
import
IFWatermarker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -856,6 +865,9 @@ class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
View file @
95c5ce4e
...
...
@@ -35,6 +35,16 @@ if is_ftfy_available():
import
ftfy
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -974,6 +984,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoa
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
View file @
95c5ce4e
...
...
@@ -17,6 +17,7 @@ from ...utils import (
PIL_INTERPOLATION
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -27,8 +28,16 @@ from .safety_checker import IFSafetyChecker
from
.watermark
import
IFWatermarker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -975,6 +984,9 @@ class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
View file @
95c5ce4e
...
...
@@ -35,6 +35,16 @@ if is_ftfy_available():
import
ftfy
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1085,6 +1095,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLora
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
View file @
95c5ce4e
...
...
@@ -34,6 +34,16 @@ if is_ftfy_available():
import
ftfy
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -831,6 +841,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
intermediate_images
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
intermediate_images
if
output_type
==
"pil"
:
...
...
src/diffusers/pipelines/dit/pipeline_dit.py
View file @
95c5ce4e
...
...
@@ -24,10 +24,19 @@ import torch
from
...models
import
AutoencoderKL
,
DiTTransformer2DModel
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
is_torch_xla_available
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
class
DiTPipeline
(
DiffusionPipeline
):
r
"""
Pipeline for image generation based on a Transformer backbone instead of a UNet.
...
...
@@ -211,6 +220,9 @@ class DiTPipeline(DiffusionPipeline):
# compute previous image: x_t -> x_t-1
latent_model_input
=
self
.
scheduler
.
step
(
model_output
,
t
,
latent_model_input
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
guidance_scale
>
1
:
latents
,
_
=
latent_model_input
.
chunk
(
2
,
dim
=
0
)
else
:
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment