OpenDAS / diffusers / Commits / ad8f985e

Unverified commit ad8f985e, authored Jul 14, 2023 by Patrick von Platen, committed by GitHub on Jul 14, 2023.
Allow low precision vae sd xl (#4083)

* Allow low precision sd xl
* finish
* finish
* make style
Parent: ee2f2775

Showing 4 changed files with 79 additions and 62 deletions:
  src/diffusers/models/autoencoder_kl.py (+5 / -0)
  src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py (+22 / -20)
  src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (+23 / -19)
  src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py (+29 / -23)
src/diffusers/models/autoencoder_kl.py

@@ -64,6 +64,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, defaults to `True`):
+            If enabled, forces the VAE to run in float32 for high-resolution pipelines such as SD-XL. The VAE can be
+            fine-tuned / trained to a lower range without losing too much precision, in which case `force_upcast`
+            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
     """

     _supports_gradient_checkpointing = True

@@ -82,6 +86,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        force_upcast: bool = True,
     ):
         super().__init__()
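In practice, the new flag lets a fine-tuned fp16-safe VAE skip the upcast entirely. A minimal sketch (not part of the commit; it assumes the community fp16-fix VAE linked in the docstring and the standard SD-XL base checkpoint):

import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# This VAE was fine-tuned to stay numerically stable in float16, so the
# forced float32 upcast can be switched off.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
vae.config.force_upcast = False
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
)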
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

@@ -501,6 +501,25 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     def __call__(
         self,
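A hypothetical inspection (not from the commit; it assumes a pipeline object `pipe` whose VAE attention uses a torch 2.0 or xFormers processor, and standard diffusers VAE module names) of what upcast_vae leaves in half precision:

pipe.upcast_vae()
# The bulk of the VAE is now float32 ...
print(pipe.vae.decoder.up_blocks[0].resnets[0].conv1.weight.dtype)  # torch.float32
# ... but the entry points around the attention mid block were cast back
# to the original half dtype, since attention there is fp16-safe.
print(pipe.vae.post_quant_conv.weight.dtype)  # torch.float16
print(pipe.vae.decoder.conv_in.weight.dtype)  # torch.float16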
@@ -746,26 +765,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         # post-processing
         if not output_type == "latent":
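The added line reads the dtype off the first parameter of post_quant_conv because, after the selective upcast, that conv is the first module the latents actually hit. A generic sketch of the same idiom (match_module_dtype is a hypothetical helper, not code from the diff):

import torch

def match_module_dtype(tensor: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
    # Cast a tensor to the dtype of the module's first parameter so the
    # subsequent forward pass does not mix precisions.
    return tensor.to(next(iter(module.parameters())).dtype)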
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

@@ -537,6 +537,26 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -799,25 +819,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
                     callback(i, t, latents)

         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
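Net effect for SD-XL users: a pipeline loaded entirely in fp16 now upcasts its VAE on its own at decode time, so the old manual workaround of casting the VAE to float32 before calling the pipeline is no longer needed. A minimal end-to-end sketch (assuming the standard SD-XL base checkpoint and a CUDA device):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# force_upcast defaults to True, so the VAE decode runs in float32 internally
# while everything else stays in half precision.
image = pipe("a photo of an astronaut riding a horse").images[0]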
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

@@ -542,8 +542,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            image = image.float()
-            self.vae.to(dtype=torch.float32)
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -559,9 +560,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        self.vae.to(dtype)
-        init_latents = init_latents.to(dtype)
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        init_latents = init_latents.to(dtype)
         init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
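Taken together, the two hunks above implement a dtype round trip on the encode side. A self-contained sketch of the pattern (a hypothetical free function, not code from the diff):

import torch

def encode_image_fp16_safe(vae, image: torch.Tensor, generator=None) -> torch.Tensor:
    dtype = image.dtype  # e.g. torch.float16
    if vae.config.force_upcast:
        image = image.float()            # the fp16 encoder overflows otherwise
        vae.to(dtype=torch.float32)
    init_latents = vae.encode(image).latent_dist.sample(generator)
    if vae.config.force_upcast:
        vae.to(dtype)                    # cast the VAE back down afterwards
    init_latents = init_latents.to(dtype)
    return vae.config.scaling_factor * init_latents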
@@ -624,6 +626,26 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         return add_time_ids, add_neg_time_ids

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -932,25 +954,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
                     callback(i, t, latents)

         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
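For completeness, the same automatic behaviour on the img2img path, sketched with the standard SD-XL refiner checkpoint (the input image URL is a placeholder):

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/input.png")  # placeholder
# Both vae.encode and vae.decode upcast automatically under force_upcast=True.
image = pipe("a fantasy landscape", image=init_image, strength=0.3).images[0]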