Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95c5ce4e
Unverified
Commit
95c5ce4e
authored
Jan 08, 2025
by
hlky
Committed by
GitHub
Jan 08, 2025
Browse files
PyTorch/XLA support (#10498)
Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c0964571
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
262 additions
and
9 deletions
+262
-9
src/diffusers/pipelines/shap_e/pipeline_shap_e.py
src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+12
-0
src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
+12
-0
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
...users/pipelines/stable_cascade/pipeline_stable_cascade.py
+12
-1
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
...pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+12
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
...s/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+19
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
...le_diffusion/pipeline_stable_diffusion_image_variation.py
+11
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+12
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+18
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
...ble_diffusion/pipeline_stable_diffusion_latent_upscale.py
+11
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
...nes/stable_diffusion/pipeline_stable_diffusion_upscale.py
+18
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
...sers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+12
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
...elines/stable_diffusion/pipeline_stable_unclip_img2img.py
+12
-0
src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
...and_excite/pipeline_stable_diffusion_attend_and_excite.py
+12
-0
src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
..._diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+11
-0
src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
...able_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+12
-0
src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
...ion_gligen/pipeline_stable_diffusion_gligen_text_image.py
+19
-1
src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
...stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+12
-0
src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
..._diffusion_panorama/pipeline_stable_diffusion_panorama.py
+12
-0
src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
...s/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+11
-1
src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
...nes/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+12
-0
No files found.
src/diffusers/pipelines/shap_e/pipeline_shap_e.py
View file @
95c5ce4e
...
@@ -25,6 +25,7 @@ from ...models import PriorTransformer
...
@@ -25,6 +25,7 @@ from ...models import PriorTransformer
from
...schedulers
import
HeunDiscreteScheduler
from
...schedulers
import
HeunDiscreteScheduler
from
...utils
import
(
from
...utils
import
(
BaseOutput
,
BaseOutput
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline
...
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline
from
.renderer
import
ShapERenderer
from
.renderer
import
ShapERenderer
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -291,6 +300,9 @@ class ShapEPipeline(DiffusionPipeline):
...
@@ -291,6 +300,9 @@ class ShapEPipeline(DiffusionPipeline):
sample
=
latents
,
sample
=
latents
,
).
prev_sample
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# Offload all models
# Offload all models
self
.
maybe_free_model_hooks
()
self
.
maybe_free_model_hooks
()
...
...
src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
View file @
95c5ce4e
...
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
...
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
from
...schedulers
import
HeunDiscreteScheduler
from
...schedulers
import
HeunDiscreteScheduler
from
...utils
import
(
from
...utils
import
(
BaseOutput
,
BaseOutput
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
)
)
...
@@ -32,8 +33,16 @@ from ..pipeline_utils import DiffusionPipeline
...
@@ -32,8 +33,16 @@ from ..pipeline_utils import DiffusionPipeline
from
.renderer
import
ShapERenderer
from
.renderer
import
ShapERenderer
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -278,6 +287,9 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
...
@@ -278,6 +287,9 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
sample
=
latents
,
sample
=
latents
,
).
prev_sample
).
prev_sample
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
not
in
[
"np"
,
"pil"
,
"latent"
,
"mesh"
]:
if
output_type
not
in
[
"np"
,
"pil"
,
"latent"
,
"mesh"
]:
raise
ValueError
(
raise
ValueError
(
f
"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type=
{
output_type
}
"
f
"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type=
{
output_type
}
"
...
...
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
View file @
95c5ce4e
...
@@ -19,14 +19,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
...
@@ -19,14 +19,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
from
...models
import
StableCascadeUNet
from
...models
import
StableCascadeUNet
from
...schedulers
import
DDPMWuerstchenScheduler
from
...schedulers
import
DDPMWuerstchenScheduler
from
...utils
import
is_torch_version
,
logging
,
replace_example_docstring
from
...utils
import
is_torch_version
,
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..wuerstchen.modeling_paella_vq_model
import
PaellaVQModel
from
..wuerstchen.modeling_paella_vq_model
import
PaellaVQModel
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -503,6 +511,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
...
@@ -503,6 +511,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
prompt_embeds
=
callback_outputs
.
pop
(
"prompt_embeds"
,
prompt_embeds
)
prompt_embeds
=
callback_outputs
.
pop
(
"prompt_embeds"
,
prompt_embeds
)
negative_prompt_embeds
=
callback_outputs
.
pop
(
"negative_prompt_embeds"
,
negative_prompt_embeds
)
negative_prompt_embeds
=
callback_outputs
.
pop
(
"negative_prompt_embeds"
,
negative_prompt_embeds
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
not
in
[
"pt"
,
"np"
,
"pil"
,
"latent"
]:
if
output_type
not
in
[
"pt"
,
"np"
,
"pil"
,
"latent"
]:
raise
ValueError
(
raise
ValueError
(
f
"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type=
{
output_type
}
"
f
"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type=
{
output_type
}
"
...
...
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
View file @
95c5ce4e
...
@@ -23,13 +23,21 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
...
@@ -23,13 +23,21 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from
...models
import
StableCascadeUNet
from
...models
import
StableCascadeUNet
from
...schedulers
import
DDPMWuerstchenScheduler
from
...schedulers
import
DDPMWuerstchenScheduler
from
...utils
import
BaseOutput
,
logging
,
replace_example_docstring
from
...utils
import
BaseOutput
,
is_torch_xla_available
,
logging
,
replace_example_docstring
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
from
..pipeline_utils
import
DiffusionPipeline
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
DEFAULT_STAGE_C_TIMESTEPS
=
list
(
np
.
linspace
(
1.0
,
2
/
3
,
20
))
+
list
(
np
.
linspace
(
2
/
3
,
0.0
,
11
))[
1
:]
DEFAULT_STAGE_C_TIMESTEPS
=
list
(
np
.
linspace
(
1.0
,
2
/
3
,
20
))
+
list
(
np
.
linspace
(
2
/
3
,
0.0
,
11
))[
1
:]
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
...
@@ -611,6 +619,9 @@ class StableCascadePriorPipeline(DiffusionPipeline):
...
@@ -611,6 +619,9 @@ class StableCascadePriorPipeline(DiffusionPipeline):
prompt_embeds
=
callback_outputs
.
pop
(
"prompt_embeds"
,
prompt_embeds
)
prompt_embeds
=
callback_outputs
.
pop
(
"prompt_embeds"
,
prompt_embeds
)
negative_prompt_embeds
=
callback_outputs
.
pop
(
"negative_prompt_embeds"
,
negative_prompt_embeds
)
negative_prompt_embeds
=
callback_outputs
.
pop
(
"negative_prompt_embeds"
,
negative_prompt_embeds
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# Offload all models
# Offload all models
self
.
maybe_free_model_hooks
()
self
.
maybe_free_model_hooks
()
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
View file @
95c5ce4e
...
@@ -28,11 +28,26 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
...
@@ -28,11 +28,26 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
PIL_INTERPOLATION
,
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
(
PIL_INTERPOLATION
,
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
,
)
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -861,6 +876,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
...
@@ -861,6 +876,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
else
:
else
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
View file @
95c5ce4e
...
@@ -24,13 +24,20 @@ from ...configuration_utils import FrozenDict
...
@@ -24,13 +24,20 @@ from ...configuration_utils import FrozenDict
from
...image_processor
import
VaeImageProcessor
from
...image_processor
import
VaeImageProcessor
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
deprecate
,
logging
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
.
import
StableDiffusionPipelineOutput
from
.
import
StableDiffusionPipelineOutput
from
.safety_checker
import
StableDiffusionSafetyChecker
from
.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -401,6 +408,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
...
@@ -401,6 +408,9 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
self
.
maybe_free_model_hooks
()
self
.
maybe_free_model_hooks
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
95c5ce4e
...
@@ -32,6 +32,7 @@ from ...utils import (
...
@@ -32,6 +32,7 @@ from ...utils import (
PIL_INTERPOLATION
,
PIL_INTERPOLATION
,
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -43,8 +44,16 @@ from . import StableDiffusionPipelineOutput
...
@@ -43,8 +44,16 @@ from . import StableDiffusionPipelineOutput
from
.safety_checker
import
StableDiffusionSafetyChecker
from
.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1120,6 +1129,9 @@ class StableDiffusionImg2ImgPipeline(
...
@@ -1120,6 +1129,9 @@ class StableDiffusionImg2ImgPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
,
generator
=
generator
)[
0
0
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
95c5ce4e
...
@@ -27,13 +27,27 @@ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraL
...
@@ -27,13 +27,27 @@ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraL
from
...models
import
AsymmetricAutoencoderKL
,
AutoencoderKL
,
ImageProjection
,
UNet2DConditionModel
from
...models
import
AsymmetricAutoencoderKL
,
AutoencoderKL
,
ImageProjection
,
UNet2DConditionModel
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
,
)
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
.
import
StableDiffusionPipelineOutput
from
.
import
StableDiffusionPipelineOutput
from
.safety_checker
import
StableDiffusionSafetyChecker
from
.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -1303,6 +1317,9 @@ class StableDiffusionInpaintPipeline(
...
@@ -1303,6 +1317,9 @@ class StableDiffusionInpaintPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
condition_kwargs
=
{}
condition_kwargs
=
{}
if
isinstance
(
self
.
vae
,
AsymmetricAutoencoderKL
):
if
isinstance
(
self
.
vae
,
AsymmetricAutoencoderKL
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
View file @
95c5ce4e
...
@@ -25,11 +25,18 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
...
@@ -25,11 +25,18 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from
...loaders
import
FromSingleFileMixin
from
...loaders
import
FromSingleFileMixin
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
UNet2DConditionModel
from
...schedulers
import
EulerDiscreteScheduler
from
...schedulers
import
EulerDiscreteScheduler
from
...utils
import
deprecate
,
logging
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
ImagePipelineOutput
,
StableDiffusionMixin
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -640,6 +647,9 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
...
@@ -640,6 +647,9 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
else
:
else
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
View file @
95c5ce4e
...
@@ -30,12 +30,26 @@ from ...models.attention_processor import (
...
@@ -30,12 +30,26 @@ from ...models.attention_processor import (
)
)
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...schedulers
import
DDPMScheduler
,
KarrasDiffusionSchedulers
from
...schedulers
import
DDPMScheduler
,
KarrasDiffusionSchedulers
from
...utils
import
USE_PEFT_BACKEND
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
is_torch_xla_available
,
logging
,
scale_lora_layers
,
unscale_lora_layers
,
)
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
.
import
StableDiffusionPipelineOutput
from
.
import
StableDiffusionPipelineOutput
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -769,6 +783,9 @@ class StableDiffusionUpscalePipeline(
...
@@ -769,6 +783,9 @@ class StableDiffusionUpscalePipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
# make sure the VAE is in float32 mode, as it overflows in float16
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
needs_upcasting
=
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
View file @
95c5ce4e
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
...
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
from
.stable_unclip_image_normalizer
import
StableUnCLIPImageNormalizer
from
.stable_unclip_image_normalizer
import
StableUnCLIPImageNormalizer
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -924,6 +933,9 @@ class StableUnCLIPPipeline(
...
@@ -924,6 +933,9 @@ class StableUnCLIPPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
else
:
else
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
View file @
95c5ce4e
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
...
@@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
from
.stable_unclip_image_normalizer
import
StableUnCLIPImageNormalizer
from
.stable_unclip_image_normalizer
import
StableUnCLIPImageNormalizer
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -829,6 +838,9 @@ class StableUnCLIPImg2ImgPipeline(
...
@@ -829,6 +838,9 @@ class StableUnCLIPImg2ImgPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 9. Post-processing
# 9. Post-processing
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
...
...
src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
View file @
95c5ce4e
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -41,6 +42,14 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
...
@@ -41,6 +42,14 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
...
@@ -1008,6 +1017,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
...
@@ -1008,6 +1017,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post-processing
# 8. Post-processing
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
...
...
src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
View file @
95c5ce4e
...
@@ -33,6 +33,7 @@ from ...utils import (
...
@@ -33,6 +33,7 @@ from ...utils import (
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
BaseOutput
,
BaseOutput
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -44,6 +45,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
...
@@ -44,6 +45,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -1508,6 +1516,9 @@ class StableDiffusionDiffEditPipeline(
...
@@ -1508,6 +1516,9 @@ class StableDiffusionDiffEditPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
...
...
src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
View file @
95c5ce4e
...
@@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -40,8 +41,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
...
@@ -40,8 +41,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -828,6 +837,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
...
@@ -828,6 +837,9 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
...
...
src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
View file @
95c5ce4e
...
@@ -32,7 +32,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
...
@@ -32,7 +32,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from
...models.attention
import
GatedSelfAttentionDense
from
...models.attention
import
GatedSelfAttentionDense
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...models.lora
import
adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
USE_PEFT_BACKEND
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
unscale_lora_layers
from
...utils
import
(
USE_PEFT_BACKEND
,
is_torch_xla_available
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
unscale_lora_layers
,
)
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..stable_diffusion
import
StableDiffusionPipelineOutput
from
..stable_diffusion
import
StableDiffusionPipelineOutput
...
@@ -40,8 +47,16 @@ from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
...
@@ -40,8 +47,16 @@ from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1010,6 +1025,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
...
@@ -1010,6 +1025,9 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
...
...
src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
View file @
95c5ce4e
...
@@ -30,6 +30,7 @@ from ...utils import (
...
@@ -30,6 +30,7 @@ from ...utils import (
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
BaseOutput
,
BaseOutput
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```python
```python
...
@@ -1002,6 +1011,9 @@ class StableDiffusionLDM3DPipeline(
...
@@ -1002,6 +1011,9 @@ class StableDiffusionLDM3DPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
...
...
src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
View file @
95c5ce4e
...
@@ -26,6 +26,7 @@ from ...schedulers import DDIMScheduler
...
@@ -26,6 +26,7 @@ from ...schedulers import DDIMScheduler
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -37,8 +38,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
...
@@ -37,8 +38,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -1155,6 +1164,9 @@ class StableDiffusionPanoramaPipeline(
...
@@ -1155,6 +1164,9 @@ class StableDiffusionPanoramaPipeline(
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
output_type
!=
"latent"
:
if
output_type
!=
"latent"
:
if
circular_padding
:
if
circular_padding
:
image
=
self
.
decode_latents_with_padding
(
latents
)
image
=
self
.
decode_latents_with_padding
(
latents
)
...
...
src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
View file @
95c5ce4e
...
@@ -12,13 +12,20 @@ from ...image_processor import PipelineImageInput
...
@@ -12,13 +12,20 @@ from ...image_processor import PipelineImageInput
from
...loaders
import
IPAdapterMixin
from
...loaders
import
IPAdapterMixin
from
...models
import
AutoencoderKL
,
ImageProjection
,
UNet2DConditionModel
from
...models
import
AutoencoderKL
,
ImageProjection
,
UNet2DConditionModel
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
deprecate
,
logging
from
...utils
import
deprecate
,
is_torch_xla_available
,
logging
from
...utils.torch_utils
import
randn_tensor
from
...utils.torch_utils
import
randn_tensor
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
..pipeline_utils
import
DiffusionPipeline
,
StableDiffusionMixin
from
.
import
StableDiffusionSafePipelineOutput
from
.
import
StableDiffusionSafePipelineOutput
from
.safety_checker
import
SafeStableDiffusionSafetyChecker
from
.safety_checker
import
SafeStableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -739,6 +746,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
...
@@ -739,6 +746,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
# 8. Post-processing
# 8. Post-processing
image
=
self
.
decode_latents
(
latents
)
image
=
self
.
decode_latents
(
latents
)
...
...
src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
View file @
95c5ce4e
...
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
...
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
deprecate
,
is_torch_xla_available
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -38,8 +39,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
...
@@ -38,8 +39,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
from
..stable_diffusion.safety_checker
import
StableDiffusionSafetyChecker
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
XLA_AVAILABLE
=
True
else
:
XLA_AVAILABLE
=
False
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
EXAMPLE_DOC_STRING
=
"""
EXAMPLE_DOC_STRING
=
"""
Examples:
Examples:
```py
```py
...
@@ -840,6 +849,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
...
@@ -840,6 +849,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
callback
(
step_idx
,
t
,
latents
)
if
XLA_AVAILABLE
:
xm
.
mark_step
()
if
not
output_type
==
"latent"
:
if
not
output_type
==
"latent"
:
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
=
self
.
vae
.
decode
(
latents
/
self
.
vae
.
config
.
scaling_factor
,
return_dict
=
False
)[
0
]
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
image
,
has_nsfw_concept
=
self
.
run_safety_checker
(
image
,
device
,
prompt_embeds
.
dtype
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment