Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95c5ce4e
Unverified
Commit
95c5ce4e
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
PyTorch/XLA support (#10498)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c0964571
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
238 additions
and
11 deletions
+238
-11
src/diffusers/pipelines/allegro/pipeline_allegro.py
src/diffusers/pipelines/allegro/pipeline_allegro.py
+12
-0
src/diffusers/pipelines/amused/pipeline_amused.py
src/diffusers/pipelines/amused/pipeline_amused.py
+12
-1
src/diffusers/pipelines/amused/pipeline_amused_img2img.py
src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+12
-1
src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+12
-1
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+12
-0
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
.../pipelines/animatediff/pipeline_animatediff_controlnet.py
+12
-1
src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
...fusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+12
-0
src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
.../pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+12
-0
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
...pipelines/animatediff/pipeline_animatediff_video2video.py
+12
-1
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
...nimatediff/pipeline_animatediff_video2video_controlnet.py
+12
-1
src/diffusers/pipelines/audioldm/pipeline_audioldm.py
src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+12
-1
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+15
-0
src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
...users/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+12
-0
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+11
-1
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
...sers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+11
-1
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
...sers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+11
-0
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
...sers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+11
-1
src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+11
-1
src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
...pelines/consistency_models/pipeline_consistency_models.py
+11
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
...ipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+13
-0
No files found.
src/diffusers/pipelines/allegro/pipeline_allegro.py
View file @
95c5ce4e
...
@@ -33,6 +33,7 @@ from ...utils import (
...
@@ -33,6 +33,7 @@ from ...utils import (
deprecate
,
deprecate
,
is_bs4_available
,
is_bs4_available
,
is_ftfy_available
,
is_ftfy_available
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
...
@@ -41,6 +42,14 @@ from ...video_processor import VideoProcessor
from
.pipeline_output
import
AllegroPipelineOutput
from
.pipeline_output
import
AllegroPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
if
is_bs4_available
():
if
is_bs4_available
():
...
@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
...
@@ -921,6 +930,9 @@ class AllegroPipeline(DiffusionPipeline):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
latents
=
latents
.
to
(
self
.
vae
.
dtype
)
latents
=
latents
.
to
(
self
.
vae
.
dtype
)
video
=
self
.
decode_latents
(
latents
)
video
=
self
.
decode_latents
(
latents
)
...
...
src/diffusers/pipelines/amused/pipeline_amused.py
View file @
95c5ce4e
...
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
...
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from
...image_processor
import
VaeImageProcessor
from
...image_processor
import
VaeImageProcessor
from
...models
import
UVit2DModel
,
VQModel
from
...models
import
UVit2DModel
,
VQModel
from
...schedulers
import
AmusedScheduler
from
...schedulers
import
AmusedScheduler
from
...utils
import
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
replace_example_docstring
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -299,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
...
@@ -299,6 +307,9 @@ class AmusedPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
timestep
,
latents
)
callback
(
step_idx
,
timestep
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
output
=
latents
output
=
latents
else
:
else
:
...
...
src/diffusers/pipelines/amused/pipeline_amused_img2img.py
View file @
95c5ce4e
...
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
...
@@ -20,10 +20,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from
...image_processor
import
PipelineImageInput
,
VaeImageProcessor
from
...image_processor
import
PipelineImageInput
,
VaeImageProcessor
from
...models
import
UVit2DModel
,
VQModel
from
...models
import
UVit2DModel
,
VQModel
from
...schedulers
import
AmusedScheduler
from
...schedulers
import
AmusedScheduler
from
...utils
import
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
replace_example_docstring
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -325,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
...
@@ -325,6 +333,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
timestep
,
latents
)
callback
(
step_idx
,
timestep
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
output
=
latents
output
=
latents
else
:
else
:
...
...
src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
View file @
95c5ce4e
...
@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
...
@@ -21,10 +21,18 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from
...image_processor
import
PipelineImageInput
,
VaeImageProcessor
from
...image_processor
import
PipelineImageInput
,
VaeImageProcessor
from
...models
import
UVit2DModel
,
VQModel
from
...models
import
UVit2DModel
,
VQModel
from
...schedulers
import
AmusedScheduler
from
...schedulers
import
AmusedScheduler
from
...utils
import
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
replace_example_docstring
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -356,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
...
@@ -356,6 +364,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
timestep
,
latents
)
callback
(
step_idx
,
timestep
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
output
=
latents
output
=
latents
else
:
else
:
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
View file @
95c5ce4e
...
@@ -34,6 +34,7 @@ from ...schedulers import (
...
@@ -34,6 +34,7 @@ from ...schedulers import (
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -47,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
...
@@ -844,6 +853,9 @@ class AnimateDiffPipeline(
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
callback
(
i
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post processing
# 9. Post processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
View file @
95c5ce4e
...
@@ -32,7 +32,7 @@ from ...models import (
...
@@ -32,7 +32,7 @@ from ...models import (
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.unets.unet_motion_model
import
MotionAdapter
from
...models.unets.unet_motion_model
import
MotionAdapter
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
USE_PEFT_BACKEND
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
..free_init_utils
import
FreeInitMixin
from
..free_init_utils
import
FreeInitMixin
...
@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -41,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
...
@@ -1090,6 +1098,9 @@ class AnimateDiffControlNetPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post processing
# 9. Post processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
View file @
95c5ce4e
...
@@ -48,6 +48,7 @@ from ...schedulers import (
...
@@ -48,6 +48,7 @@ from ...schedulers import (
)
)
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -60,8 +61,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -60,8 +61,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1265,6 +1274,9 @@ class AnimateDiffSDXLPipeline(
...
@@ -1265,6 +1274,9 @@ class AnimateDiffSDXLPipeline(
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# make sure the VAE is in float32 mode, as it overflows in float16
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
View file @
95c5ce4e
...
@@ -30,6 +30,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
...
@@ -30,6 +30,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -42,8 +43,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -42,8 +43,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```python
```python
...
@@ -994,6 +1003,9 @@ class AnimateDiffSparseControlNetPipeline(
...
@@ -994,6 +1003,9 @@ class AnimateDiffSparseControlNetPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 11. Post processing
# 11. Post processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
View file @
95c5ce4e
...
@@ -31,7 +31,7 @@ from ...schedulers import (
...
@@ -31,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
,
PNDMScheduler
,
)
)
from
...utils
import
USE_PEFT_BACKEND
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
..free_init_utils
import
FreeInitMixin
from
..free_init_utils
import
FreeInitMixin
...
@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1037,6 +1045,9 @@ class AnimateDiffVideoToVideoPipeline(
...
@@ -1037,6 +1045,9 @@ class AnimateDiffVideoToVideoPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 10. Post-processing
# 10. Post-processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
View file @
95c5ce4e
...
@@ -39,7 +39,7 @@ from ...schedulers import (
...
@@ -39,7 +39,7 @@ from ...schedulers import (
LMSDiscreteScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
,
PNDMScheduler
,
)
)
from
...utils
import
USE_PEFT_BACKEND
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
..free_init_utils
import
FreeInitMixin
from
..free_init_utils
import
FreeInitMixin
...
@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.pipeline_output
import
AnimateDiffPipelineOutput
from
.pipeline_output
import
AnimateDiffPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1325,6 +1333,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
...
@@ -1325,6 +1333,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 11. Post-processing
# 11. Post-processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/audioldm/pipeline_audioldm.py
View file @
95c5ce4e
...
@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
...
@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
logging
,
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
AudioPipelineOutput
,
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
AudioPipelineOutput
,
DiffusionPipeline
,
StableDiffusionMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
...
@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post-processing
# 8. Post-processing
mel_spectrogram
=
self
.
decode_latents
(
latents
)
mel_spectrogram
=
self
.
decode_latents
(
latents
)
...
...
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
View file @
95c5ce4e
...
@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi
...
@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditi
if
is_librosa_available
():
if
is_librosa_available
():
import
librosa
import
librosa
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1033,6 +1045,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
...
@@ -1033,6 +1045,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
self
.
maybe_free_model_hooks
()
self
.
maybe_free_model_hooks
()
# 8. Post-processing
# 8. Post-processing
...
...
src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
View file @
95c5ce4e
...
@@ -20,6 +20,7 @@ from transformers import CLIPTokenizer
...
@@ -20,6 +20,7 @@ from transformers import CLIPTokenizer
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...schedulers
import
PNDMScheduler
from
...schedulers
import
PNDMScheduler
from
...utils
import
(
from
...utils
import
(
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -30,8 +31,16 @@ from .modeling_blip2 import Blip2QFormerModel
...
@@ -30,8 +31,16 @@ from .modeling_blip2 import Blip2QFormerModel
from
.modeling_ctx_clip
import
ContextCLIPTextModel
from
.modeling_ctx_clip
import
ContextCLIPTextModel
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -336,6 +345,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
...
@@ -336,6 +345,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
latents
,
latents
,
)[
"prev_sample"
]
)[
"prev_sample"
]
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
image_processor
.
postprocess
(
image
,
output_type
=
output_type
)
image
=
self
.
image_processor
.
postprocess
(
image
,
output_type
=
output_type
)
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
View file @
95c5ce4e
...
@@ -26,12 +26,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
...
@@ -26,12 +26,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...utils
import
logging
,
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
.pipeline_output
import
CogVideoXPipelineOutput
from
.pipeline_output
import
CogVideoXPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -753,6 +760,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
...
@@ -753,6 +760,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
# Discard any padding frames that were added for CogVideoX 1.5
# Discard any padding frames that were added for CogVideoX 1.5
latents
=
latents
[:,
additional_frames
:]
latents
=
latents
[:,
additional_frames
:]
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
View file @
95c5ce4e
...
@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
...
@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
logging
,
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
.pipeline_output
import
CogVideoXPipelineOutput
from
.pipeline_output
import
CogVideoXPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -808,6 +815,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
...
@@ -808,6 +815,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
video
=
self
.
decode_latents
(
latents
)
video
=
self
.
decode_latents
(
latents
)
video
=
self
.
video_processor
.
postprocess_video
(
video
=
video
,
output_type
=
output_type
)
video
=
self
.
video_processor
.
postprocess_video
(
video
=
video
,
output_type
=
output_type
)
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
View file @
95c5ce4e
...
@@ -29,6 +29,7 @@ from ...models.embeddings import get_3d_rotary_pos_embed
...
@@ -29,6 +29,7 @@ from ...models.embeddings import get_3d_rotary_pos_embed
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...utils
import
(
from
...utils
import
(
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -37,6 +38,13 @@ from ...video_processor import VideoProcessor
...
@@ -37,6 +38,13 @@ from ...video_processor import VideoProcessor
from
.pipeline_output
import
CogVideoXPipelineOutput
from
.pipeline_output
import
CogVideoXPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -866,6 +874,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
...
@@ -866,6 +874,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
# Discard any padding frames that were added for CogVideoX 1.5
# Discard any padding frames that were added for CogVideoX 1.5
latents
=
latents
[:,
additional_frames
:]
latents
=
latents
[:,
additional_frames
:]
...
...
src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
View file @
95c5ce4e
...
@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
...
@@ -27,12 +27,19 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...models.embeddings
import
get_3d_rotary_pos_embed
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...utils
import
logging
,
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
.pipeline_output
import
CogVideoXPipelineOutput
from
.pipeline_output
import
CogVideoXPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -834,6 +841,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
...
@@ -834,6 +841,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
video
=
self
.
decode_latents
(
latents
)
video
=
self
.
decode_latents
(
latents
)
video
=
self
.
video_processor
.
postprocess_video
(
video
=
video
,
output_type
=
output_type
)
video
=
self
.
video_processor
.
postprocess_video
(
video
=
video
,
output_type
=
output_type
)
...
...
src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
View file @
95c5ce4e
...
@@ -24,11 +24,18 @@ from ...image_processor import VaeImageProcessor
...
@@ -24,11 +24,18 @@ from ...image_processor import VaeImageProcessor
from
...models
import
AutoencoderKL
,
CogView3PlusTransformer2DModel
from
...models
import
AutoencoderKL
,
CogView3PlusTransformer2DModel
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...pipelines.pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...schedulers
import
CogVideoXDDIMScheduler
,
CogVideoXDPMScheduler
from
...utils
import
logging
,
replace_example_docstring
from
...utils
import
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
.pipeline_output
import
CogView3PipelineOutput
from
.pipeline_output
import
CogView3PipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -654,6 +661,9 @@ class CogView3PlusPipeline(DiffusionPipeline):
...
@@ -654,6 +661,9 @@ class CogView3PlusPipeline(DiffusionPipeline):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
0
0
...
...
src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
View file @
95c5ce4e
...
@@ -19,6 +19,7 @@ import torch
...
@@ -19,6 +19,7 @@ import torch
from
...models
import
UNet2DModel
from
...models
import
UNet2DModel
from
...schedulers
import
CMStochasticIterativeScheduler
from
...schedulers
import
CMStochasticIterativeScheduler
from
...utils
import
(
from
...utils
import
(
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -26,6 +27,13 @@ from ...utils.torch_utils import randn_tensor
...
@@ -26,6 +27,13 @@ from ...utils.torch_utils import randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -263,6 +271,9 @@ class ConsistencyModelPipeline(DiffusionPipeline):
...
@@ -263,6 +271,9 @@ class ConsistencyModelPipeline(DiffusionPipeline):
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
sample
)
callback
(
i
,
t
,
sample
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 6. Post-process image sample
# 6. Post-process image sample
image
=
self
.
postprocess_image
(
sample
,
output_type
=
output_type
)
image
=
self
.
postprocess_image
(
sample
,
output_type
=
output_type
)
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
View file @
95c5ce4e
...
@@ -21,6 +21,7 @@ from transformers import CLIPTokenizer
...
@@ -21,6 +21,7 @@ from transformers import CLIPTokenizer
from
...models
import
AutoencoderKL
,
ControlNetModel
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
ControlNetModel
,
UNet2DConditionModel
from
...schedulers
import
PNDMScheduler
from
...schedulers
import
PNDMScheduler
from
...utils
import
(
from
...utils
import
(
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -31,8 +32,16 @@ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
...
@@ -31,8 +32,16 @@ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -401,6 +410,10 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
...
@@ -401,6 +410,10 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
t
,
t
,
latents
,
latents
,
)[
"prev_sample"
]
)[
"prev_sample"
]
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
image_processor
.
postprocess
(
image
,
output_type
=
output_type
)
image
=
self
.
image_processor
.
postprocess
(
image
,
output_type
=
output_type
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment