renzhc / diffusers_dcu · Commits · 95c5ce4e

Unverified commit 95c5ce4e, authored Jan 08, 2025 by hlky, committed via GitHub on Jan 08, 2025

PyTorch/XLA support (#10498)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Parent: c0964571
Changes: 111 files in total. Showing 20 changed files on this page, with 240 additions and 5 deletions (+240 −5); the changed-file list is paginated across six pages.
Changed files on this page:

- src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py (+12 −1)
- src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py (+12 −0)
- src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py (+12 −0)
- src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py (+12 −0)
- src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py (+12 −0)
- src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py (+12 −0)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py (+12 −1)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py (+13 −0)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py (+12 −0)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py (+12 −1)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py (+12 −1)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py (+12 −0)
- src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py (+12 −0)
- src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py (+12 −0)
- src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py (+12 −0)
- src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py (+11 −0)
- src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py (+12 −0)
- src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py (+12 −0)
- src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py (+12 −1)
- src/diffusers/pipelines/latte/pipeline_latte.py (+12 −0)
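Every file below receives the same two-part change: a module-level guard that imports `torch_xla` only when it is available, and an `xm.mark_step()` call at the end of each denoising iteration. `mark_step()` tells XLA's lazy tensor runtime to compile and dispatch the graph traced so far, so each sampling step executes as its own graph instead of accumulating one huge deferred graph. A minimal standalone sketch of the pattern (the `toy_denoising_loop` function is illustrative, not part of diffusers):

```python
import torch

from diffusers.utils import is_torch_xla_available

# Import-time guard, as this commit adds to each pipeline module.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def toy_denoising_loop(latents: torch.Tensor, steps: int) -> torch.Tensor:
    """Illustrative stand-in for a pipeline's __call__ loop (not diffusers API)."""
    for _ in range(steps):
        latents = latents - 0.1 * latents  # placeholder for unet + scheduler.step

        if XLA_AVAILABLE:
            # Materializes the pending lazily-traced XLA graph once per step,
            # which is where this commit places the call in every pipeline.
            xm.mark_step()
    return latents
```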
src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

````diff
@@ -23,15 +23,23 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanVideoPipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```python
@@ -667,6 +675,9 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             video = self.vae.decode(latents, return_dict=False)[0]
````
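Nothing changes in how a pipeline is invoked; placing the model on an XLA device is enough for the new `mark_step()` calls to take effect. A hedged usage sketch, assuming a working `torch_xla` install (e.g. a TPU VM); the checkpoint id is a placeholder, not taken from this commit:

```python
import torch
import torch_xla.core.xla_model as xm

from diffusers import DiffusionPipeline

device = xm.xla_device()  # handle to the XLA device, e.g. one TPU core

# Placeholder checkpoint id; any pipeline touched by this commit works the same way.
pipe = DiffusionPipeline.from_pretrained("<repo-id>", torch_dtype=torch.bfloat16)
pipe.to(device)

# Inside __call__, the loop now hits `if XLA_AVAILABLE: xm.mark_step()` once per
# scheduler step, so each iteration is compiled and executed as its own graph.
result = pipe(prompt="an astronaut riding a horse")
```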
src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py

````diff
@@ -27,6 +27,7 @@ from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -35,8 +36,16 @@ from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -711,6 +720,9 @@ class I2VGenXLPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post processing
         if output_type == "latent":
             video = latents
````
src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py

````diff
@@ -22,6 +22,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler, DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -30,8 +31,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -385,6 +394,9 @@ class KandinskyPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
````
src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py

````diff
@@ -25,6 +25,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -478,6 +487,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # 7. post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
````
src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py

````diff
@@ -29,6 +29,7 @@ from ... import __version__
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .text_encoder import MultilingualCLIP
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -613,6 +622,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
````
src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py

````diff
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
     BaseOutput,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -31,8 +32,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -519,6 +528,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
                 prev_timestep=prev_timestep,
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)
 
         image_embeddings = latents
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py

````diff
@@ -18,13 +18,21 @@ import torch
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -296,6 +304,9 @@ class KandinskyV22Pipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
             )
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py

````diff
@@ -19,14 +19,23 @@ import torch
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -297,6 +306,10 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py

````diff
@@ -22,14 +22,23 @@ from PIL import Image
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -358,6 +367,9 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         image = self.movq.decode(latents, force_not_quantize=True)["sample"]
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py

````diff
@@ -21,13 +21,21 @@ from PIL import Image
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -372,6 +380,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
                 f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py

````diff
@@ -25,13 +25,21 @@ from PIL import Image
 from ... import __version__
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -526,6 +534,9 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py

````diff
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
 from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
 from ..pipeline_utils import DiffusionPipeline
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -524,6 +533,9 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
                 )
                 text_mask = callback_outputs.pop("text_mask", text_mask)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)
 
         image_embeddings = latents
````
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py

````diff
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
 from ...models import PriorTransformer
 from ...schedulers import UnCLIPScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
 from ..pipeline_utils import DiffusionPipeline
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -538,6 +547,9 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
                 prev_timestep=prev_timestep,
             ).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         latents = self.prior.post_process_latents(latents)
 
         image_embeddings = latents
````
src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py

````diff
@@ -8,6 +8,7 @@ from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -15,8 +16,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -549,6 +558,9 @@ class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
````
src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py

````diff
@@ -12,6 +12,7 @@ from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -19,8 +20,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -617,6 +626,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # post-processing
         if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
````
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

````diff
@@ -30,6 +30,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,6 +41,13 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -952,6 +960,9 @@ class LatentConsistencyModelImg2ImgPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
````
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

````diff
@@ -29,6 +29,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -881,6 +890,9 @@ class LatentConsistencyModelPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
````
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

````diff
@@ -25,10 +25,19 @@ from transformers.utils import logging
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 class LDMTextToImagePipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # scale and decode the image latents with vae
         latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
````
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py

````diff
@@ -15,11 +15,19 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 def preprocess(image):
     w, h = image.size
     w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
@@ -174,6 +182,9 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # decode the image latents with the VQVAE
         image = self.vqvae.decode(latents).sample
         image = torch.clamp(image, -1.0, 1.0)
````
src/diffusers/pipelines/latte/pipeline_latte.py

````diff
@@ -32,6 +32,7 @@ from ...utils import (
     BaseOutput,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -39,8 +40,16 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
     from bs4 import BeautifulSoup
@@ -836,6 +845,9 @@ class LattePipeline(DiffusionPipeline):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latents":
             video = self.decode_latents(latents, video_length, decode_chunk_size=14)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
````
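On installs without `torch_xla`, all of the above is inert: `is_torch_xla_available()` returns False, `XLA_AVAILABLE` is set to False in every touched module, and the `mark_step()` branch is never taken, so CUDA and CPU behavior is unchanged. A quick sanity check:

```python
from diffusers.utils import is_torch_xla_available

# False on a plain CUDA/CPU install; True only when torch_xla can be imported.
print(is_torch_xla_available())
```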