Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95c5ce4e
Unverified
Commit
95c5ce4e
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
PyTorch/XLA support (#10498)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c0964571
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
242 additions
and
2 deletions
+242
-2
src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
...pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+15
-0
src/diffusers/pipelines/lumina/pipeline_lumina.py
src/diffusers/pipelines/lumina/pipeline_lumina.py
+12
-0
src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+11
-0
src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
...diffusers/pipelines/marigold/pipeline_marigold_normals.py
+11
-0
src/diffusers/pipelines/musicldm/pipeline_musicldm.py
src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+15
-0
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+11
-0
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
...users/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+11
-0
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+13
-0
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
...rs/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+13
-0
src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+12
-0
src/diffusers/pipelines/pag/pipeline_pag_sana.py
src/diffusers/pipelines/pag/pipeline_pag_sana.py
+12
-0
src/diffusers/pipelines/pag/pipeline_pag_sd.py
src/diffusers/pipelines/pag/pipeline_pag_sd.py
+12
-0
src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
+12
-0
src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+12
-0
src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+12
-0
src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
...s/pipelines/paint_by_example/pipeline_paint_by_example.py
+11
-1
src/diffusers/pipelines/pia/pipeline_pia.py
src/diffusers/pipelines/pia/pipeline_pia.py
+12
-0
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
...diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+12
-0
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
...diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+12
-0
src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
...ic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+11
-1
No files found.
src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
View file @
95c5ce4e
...
...
@@ -19,6 +19,7 @@ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -29,8 +30,16 @@ from ..pipeline_utils import DiffusionPipeline
from
.pipeline_output
import
LEditsPPDiffusionPipelineOutput
,
LEditsPPInversionPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -1209,6 +1218,9 @@ class LEditsPPPipelineStableDiffusion(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post-processing
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
...
...
@@ -1378,6 +1390,9 @@ class LEditsPPPipelineStableDiffusion(
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
self
.
init_latents
=
xts
[
-
1
].
expand
(
self
.
batch_size
,
-
1
,
-
1
,
-
1
)
zs
=
zs
.
flip
(
0
)
self
.
zs
=
zs
...
...
src/diffusers/pipelines/lumina/pipeline_lumina.py
View file @
95c5ce4e
...
...
@@ -31,6 +31,7 @@ from ...utils import (
BACKENDS_MAPPING
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -38,8 +39,16 @@ from ...utils.torch_utils import randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -874,6 +883,9 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
latents
=
latents
/
self
.
vae
.
config
.
scaling_factor
image
=
self
.
vae
.
decode
(
latents
,
return_dict
=
False
)[
0
]
...
...
src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
View file @
95c5ce4e
...
...
@@ -37,6 +37,7 @@ from ...schedulers import (
)
from
...utils
import
(
BaseOutput
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -46,6 +47,13 @@ from ..pipeline_utils import DiffusionPipeline
from
.marigold_image_processing
import
MarigoldImageProcessor
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -517,6 +525,9 @@ class MarigoldDepthPipeline(DiffusionPipeline):
noise
,
t
,
batch_pred_latent
,
generator
=
generator
).
prev_sample
# [B,4,h,w]
if
XLA_AVAILABLE
:
xm
.
mark_step
()
pred_latents
.
append
(
batch_pred_latent
)
pred_latent
=
torch
.
cat
(
pred_latents
,
dim
=
0
)
# [N*E,4,h,w]
...
...
src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
View file @
95c5ce4e
...
...
@@ -36,6 +36,7 @@ from ...schedulers import (
)
from
...utils
import
(
BaseOutput
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -44,6 +45,13 @@ from ..pipeline_utils import DiffusionPipeline
from
.marigold_image_processing
import
MarigoldImageProcessor
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -493,6 +501,9 @@ class MarigoldNormalsPipeline(DiffusionPipeline):
noise
,
t
,
batch_pred_latent
,
generator
=
generator
).
prev_sample
# [B,4,h,w]
if
XLA_AVAILABLE
:
xm
.
mark_step
()
pred_latents
.
append
(
batch_pred_latent
)
pred_latent
=
torch
.
cat
(
pred_latents
,
dim
=
0
)
# [N*E,4,h,w]
...
...
src/diffusers/pipelines/musicldm/pipeline_musicldm.py
View file @
95c5ce4e
...
...
@@ -42,8 +42,20 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffu
if
is_librosa_available
():
import
librosa
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -603,6 +615,9 @@ class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
self
.
maybe_free_model_hooks
()
# 8. Post-processing
...
...
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
View file @
95c5ce4e
...
...
@@ -30,6 +30,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1293,6 +1301,9 @@ class StableDiffusionControlNetPAGPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
View file @
95c5ce4e
...
...
@@ -31,6 +31,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1505,6 +1513,9 @@ class StableDiffusionControlNetPAGInpaintPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
View file @
95c5ce4e
...
...
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1564,6 +1574,9 @@ class StableDiffusionXLControlNetPAGPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
View file @
95c5ce4e
...
...
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -1630,6 +1640,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
...
...
src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
View file @
95c5ce4e
...
...
@@ -29,6 +29,7 @@ from ...utils import (
deprecate
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -843,6 +852,9 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
if
use_resolution_binning
:
...
...
src/diffusers/pipelines/pag/pipeline_pag_sana.py
View file @
95c5ce4e
...
...
@@ -30,6 +30,7 @@ from ...utils import (
BACKENDS_MAPPING
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -867,6 +876,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
==
"latent"
:
image
=
latents
else
:
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd.py
View file @
95c5ce4e
...
...
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -39,8 +40,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -1034,6 +1043,9 @@ class StableDiffusionPAGPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
0
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
View file @
95c5ce4e
...
...
@@ -26,6 +26,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -847,6 +856,9 @@ class AnimateDiffPAGPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post processing
if
output_type
==
"latent"
:
video
=
latents
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
View file @
95c5ce4e
...
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -42,8 +43,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -1066,6 +1075,9 @@ class StableDiffusionPAGImg2ImgPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
0
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
View file @
95c5ce4e
...
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.pag_utils
import
PAGMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -1318,6 +1327,9 @@ class StableDiffusionPAGInpaintPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
condition_kwargs
=
{}
if
isinstance
(
self
.
vae
,
AsymmetricAutoencoderKL
):
...
...
src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
View file @
95c5ce4e
...
...
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor
from
...image_processor
import
VaeImageProcessor
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...schedulers
import
DDIMScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
from
...utils
import
deprecate
,
logging
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..stable_diffusion
import
StableDiffusionPipelineOutput
...
...
@@ -31,6 +31,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from
.image_encoder
import
PaintByExampleImageEncoder
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -604,6 +611,9 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
self
.
maybe_free_model_hooks
()
if
not
output_type
==
"latent"
:
...
...
src/diffusers/pipelines/pia/pipeline_pia.py
View file @
95c5ce4e
...
...
@@ -37,6 +37,7 @@ from ...schedulers import (
from
...utils
import
(
USE_PEFT_BACKEND
,
BaseOutput
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -48,8 +49,16 @@ from ..free_init_utils import FreeInitMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
Examples:
```py
...
...
@@ -928,6 +937,9 @@ class PIAPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post processing
if
output_type
==
"latent"
:
video
=
latents
...
...
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
View file @
95c5ce4e
...
...
@@ -29,6 +29,7 @@ from ...utils import (
deprecate
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -36,8 +37,16 @@ from ...utils.torch_utils import randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -943,6 +952,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
if
use_resolution_binning
:
...
...
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
View file @
95c5ce4e
...
...
@@ -29,6 +29,7 @@ from ...utils import (
deprecate
,
is_bs4_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
)
...
...
@@ -41,8 +42,16 @@ from .pipeline_pixart_alpha import (
)
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_bs4_available
():
from
bs4
import
BeautifulSoup
...
...
@@ -854,6 +863,9 @@ class PixArtSigmaPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
if
use_resolution_binning
:
...
...
src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
View file @
95c5ce4e
...
...
@@ -9,12 +9,19 @@ from ...image_processor import VaeImageProcessor
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...pipelines.stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
deprecate
,
logging
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
.pipeline_output
import
SemanticStableDiffusionPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -701,6 +708,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post-processing
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment