Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95c5ce4e
Unverified
Commit
95c5ce4e
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
PyTorch/XLA support (#10498)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c0964571
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
139 additions
and
6 deletions
+139
-6
src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
...stable_video_diffusion/pipeline_stable_video_diffusion.py
+12
-1
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
...ipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+13
-0
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
...lines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+12
-0
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
...s/text_to_video_synthesis/pipeline_text_to_video_synth.py
+12
-0
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
...o_video_synthesis/pipeline_text_to_video_synth_img2img.py
+12
-0
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
...xt_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+14
-0
src/diffusers/pipelines/unclip/pipeline_unclip.py
src/diffusers/pipelines/unclip/pipeline_unclip.py
+11
-1
src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
...users/pipelines/unclip/pipeline_unclip_image_variation.py
+11
-1
src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+18
-1
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+12
-1
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
...ffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+12
-1
No files found.
src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
View file @
95c5ce4e
...
@@ -24,14 +24,22 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
...
@@ -24,14 +24,22 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from
...image_processor
import
PipelineImageInput
from
...image_processor
import
PipelineImageInput
from
...models
import
AutoencoderKLTemporalDecoder
,
UNetSpatioTemporalConditionModel
from
...models
import
AutoencoderKLTemporalDecoder
,
UNetSpatioTemporalConditionModel
from
...schedulers
import
EulerDiscreteScheduler
from
...schedulers
import
EulerDiscreteScheduler
from
...utils
import
BaseOutput
,
logging
,
replace_example_docstring
from
...utils
import
BaseOutput
,
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...utils.torch_utils
import
is_compiled_module
,
randn_tensor
from
...video_processor
import
VideoProcessor
from
...video_processor
import
VideoProcessor
from
..pipeline_utils
import
DiffusionPipeline
from
..pipeline_utils
import
DiffusionPipeline
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -600,6 +608,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
...
@@ -600,6 +608,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
# cast back to fp16 if needed
# cast back to fp16 if needed
if
needs_upcasting
:
if
needs_upcasting
:
...
...
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
View file @
95c5ce4e
...
@@ -31,6 +31,7 @@ from ...utils import (
...
@@ -31,6 +31,7 @@ from ...utils import (
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
BaseOutput
,
BaseOutput
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -41,6 +42,14 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -41,6 +42,14 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
@
dataclass
@
dataclass
class
StableDiffusionAdapterPipelineOutput
(
BaseOutput
):
class
StableDiffusionAdapterPipelineOutput
(
BaseOutput
):
"""
"""
...
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
...
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -915,6 +925,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
...
@@ -915,6 +925,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
image
=
latents
image
=
latents
has_nsfw_concept
=
None
has_nsfw_concept
=
None
...
...
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
View file @
95c5ce4e
...
@@ -43,6 +43,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -43,6 +43,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
PIL_INTERPOLATION
,
PIL_INTERPOLATION
,
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -53,8 +54,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -53,8 +54,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
..stable_diffusion_xl.pipeline_output
import
StableDiffusionXLPipelineOutput
from
..stable_diffusion_xl.pipeline_output
import
StableDiffusionXLPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1266,6 +1275,9 @@ class StableDiffusionXLAdapterPipeline(
...
@@ -1266,6 +1275,9 @@ class StableDiffusionXLAdapterPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
# make sure the VAE is in float32 mode, as it overflows in float16
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
View file @
95c5ce4e
...
@@ -25,6 +25,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -25,6 +25,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -36,8 +37,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -36,8 +37,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.
import
TextToVideoSDPipelineOutput
from
.
import
TextToVideoSDPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -627,6 +636,9 @@ class TextToVideoSDPipeline(
...
@@ -627,6 +636,9 @@ class TextToVideoSDPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post processing
# 8. Post processing
if
output_type
==
"latent"
:
if
output_type
==
"latent"
:
video
=
latents
video
=
latents
...
...
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
View file @
95c5ce4e
...
@@ -26,6 +26,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -26,6 +26,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
.
import
TextToVideoSDPipelineOutput
from
.
import
TextToVideoSDPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -679,6 +688,9 @@ class VideoToVideoSDPipeline(
...
@@ -679,6 +688,9 @@ class VideoToVideoSDPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# manually for max memory savings
# manually for max memory savings
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
if
hasattr
(
self
,
"final_offload_hook"
)
and
self
.
final_offload_hook
is
not
None
:
self
.
unet
.
to
(
"cpu"
)
self
.
unet
.
to
(
"cpu"
)
...
...
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
View file @
95c5ce4e
...
@@ -42,6 +42,16 @@ if is_invisible_watermark_available():
...
@@ -42,6 +42,16 @@ if is_invisible_watermark_available():
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
..stable_diffusion_xl.watermark
import
StableDiffusionXLWatermarker
from
...utils
import
is_torch_xla_available
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -926,6 +936,10 @@ class TextToVideoZeroSDXLPipeline(
...
@@ -926,6 +936,10 @@ class TextToVideoZeroSDXLPipeline(
progress_bar
.
update
()
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
callback
(
i
,
t
,
latents
)
callback
(
i
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
return
latents
.
clone
().
detach
()
return
latents
.
clone
().
detach
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
...
src/diffusers/pipelines/unclip/pipeline_unclip.py
View file @
95c5ce4e
...
@@ -22,12 +22,19 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
...
@@ -22,12 +22,19 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from
...models
import
PriorTransformer
,
UNet2DConditionModel
,
UNet2DModel
from
...models
import
PriorTransformer
,
UNet2DConditionModel
,
UNet2DModel
from
...schedulers
import
UnCLIPScheduler
from
...schedulers
import
UnCLIPScheduler
from
...utils
import
logging
from
...utils
import
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
.text_proj
import
UnCLIPTextProjModel
from
.text_proj
import
UnCLIPTextProjModel
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -474,6 +481,9 @@ class UnCLIPPipeline(DiffusionPipeline):
...
@@ -474,6 +481,9 @@ class UnCLIPPipeline(DiffusionPipeline):
noise_pred
,
t
,
super_res_latents
,
prev_timestep
=
prev_timestep
,
generator
=
generator
noise_pred
,
t
,
super_res_latents
,
prev_timestep
=
prev_timestep
,
generator
=
generator
).
prev_sample
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
super_res_latents
image
=
super_res_latents
# done super res
# done super res
...
...
src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
View file @
95c5ce4e
...
@@ -27,12 +27,19 @@ from transformers import (
...
@@ -27,12 +27,19 @@ from transformers import (
from
...models
import
UNet2DConditionModel
,
UNet2DModel
from
...models
import
UNet2DConditionModel
,
UNet2DModel
from
...schedulers
import
UnCLIPScheduler
from
...schedulers
import
UnCLIPScheduler
from
...utils
import
logging
from
...utils
import
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
.text_proj
import
UnCLIPTextProjModel
from
.text_proj
import
UnCLIPTextProjModel
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -400,6 +407,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
...
@@ -400,6 +407,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
noise_pred
,
t
,
super_res_latents
,
prev_timestep
=
prev_timestep
,
generator
=
generator
noise_pred
,
t
,
super_res_latents
,
prev_timestep
=
prev_timestep
,
generator
=
generator
).
prev_sample
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
image
=
super_res_latents
image
=
super_res_latents
# done super res
# done super res
...
...
src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
View file @
95c5ce4e
...
@@ -18,7 +18,14 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
...
@@ -18,7 +18,14 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
from
...models
import
AutoencoderKL
from
...models
import
AutoencoderKL
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
,
)
from
...utils.outputs
import
BaseOutput
from
...utils.outputs
import
BaseOutput
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
from
..pipeline_utils
import
DiffusionPipeline
...
@@ -26,6 +33,13 @@ from .modeling_text_decoder import UniDiffuserTextDecoder
...
@@ -26,6 +33,13 @@ from .modeling_text_decoder import UniDiffuserTextDecoder
from
.modeling_uvit
import
UniDiffuserModel
from
.modeling_uvit
import
UniDiffuserModel
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -1378,6 +1392,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
...
@@ -1378,6 +1392,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post-processing
# 9. Post-processing
image
=
None
image
=
None
text
=
None
text
=
None
...
...
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
View file @
95c5ce4e
...
@@ -19,15 +19,23 @@ import torch
...
@@ -19,15 +19,23 @@ import torch
from
transformers
import
CLIPTextModel
,
CLIPTokenizer
from
transformers
import
CLIPTextModel
,
CLIPTokenizer
from
...schedulers
import
DDPMWuerstchenScheduler
from
...schedulers
import
DDPMWuerstchenScheduler
from
...utils
import
deprecate
,
logging
,
replace_example_docstring
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
.modeling_paella_vq_model
import
PaellaVQModel
from
.modeling_paella_vq_model
import
PaellaVQModel
from
.modeling_wuerstchen_diffnext
import
WuerstchenDiffNeXt
from
.modeling_wuerstchen_diffnext
import
WuerstchenDiffNeXt
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
...
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
not
in
[
"pt"
,
"np"
,
"pil"
,
"latent"
]:
if
output_type
not
in
[
"pt"
,
"np"
,
"pil"
,
"latent"
]:
raise
ValueError
(
raise
ValueError
(
f
"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type=
{
output_type
}
"
f
"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type=
{
output_type
}
"
...
...
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
View file @
95c5ce4e
...
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
...
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
from
...loaders
import
StableDiffusionLoraLoaderMixin
from
...loaders
import
StableDiffusionLoraLoaderMixin
from
...schedulers
import
DDPMWuerstchenScheduler
from
...schedulers
import
DDPMWuerstchenScheduler
from
...utils
import
BaseOutput
,
deprecate
,
logging
,
replace_example_docstring
from
...utils
import
BaseOutput
,
deprecate
,
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
from
..pipeline_utils
import
DiffusionPipeline
from
.modeling_wuerstchen_prior
import
WuerstchenPrior
from
.modeling_wuerstchen_prior
import
WuerstchenPrior
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
DEFAULT_STAGE_C_TIMESTEPS
=
list
(
np
.
linspace
(
1.0
,
2
/
3
,
20
))
+
list
(
np
.
linspace
(
2
/
3
,
0.0
,
11
))[
1
:]
DEFAULT_STAGE_C_TIMESTEPS
=
list
(
np
.
linspace
(
1.0
,
2
/
3
,
20
))
+
list
(
np
.
linspace
(
2
/
3
,
0.0
,
11
))[
1
:]
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
...
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
...
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 10. Denormalize the latents
# 10. Denormalize the latents
latents
=
latents
*
self
.
config
.
latent_mean
-
self
.
config
.
latent_std
latents
=
latents
*
self
.
config
.
latent_mean
-
self
.
config
.
latent_std
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment