Commit ad8f985e (unverified), parent ee2f2775
Authored Jul 14, 2023 by Patrick von Platen; committed via GitHub on Jul 14, 2023

Allow low precision vae sd xl (#4083)

* Allow low precision sd xl
* finish
* finish
* make style
Showing 4 changed files, with 79 additions and 62 deletions:

- src/diffusers/models/autoencoder_kl.py (+5, −0)
- src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py (+22, −20)
- src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py (+23, −19)
- src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py (+29, −23)
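Taken together, these changes stop the SD-XL and upscale pipelines from unconditionally forcing the VAE into float32: the upcast now happens only when the VAE is in float16 *and* its config sets `force_upcast=True`, so a VAE fine-tuned for low precision can stay in float16 end to end. A minimal usage sketch of what this enables (the SDXL base checkpoint name is an assumption; the VAE checkpoint is the one cited in the diff below):

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# This VAE was fine-tuned for fp16; its config is expected to ship with
# force_upcast disabled, so the pipeline skips upcast_vae() at decode time.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint name
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
```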
src/diffusers/models/autoencoder_kl.py

```diff
@@ -64,6 +64,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
         / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
         Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, defaults to `True`):
+            If enabled, forces the VAE to run in float32 for high image resolution pipelines, such as SD-XL. The
+            VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
     """

     _supports_gradient_checkpointing = True
@@ -82,6 +86,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        force_upcast: bool = True,
     ):
         super().__init__()
```
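The `force_upcast` flag exists because float16 saturates at 65504, and the SD(-XL) VAE decoder produces activations large enough to overflow that range, which surfaces as inf/NaN values (typically black output images). A tiny numeric illustration of the failure mode, not the decoder itself:

```python
import torch

x = torch.tensor([70000.0])  # larger than float16's max finite value (65504)
print(x.to(torch.float16))   # tensor([inf], dtype=torch.float16)
print(x.to(torch.float32))   # tensor([70000.]) - fine in float32
```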
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

```diff
@@ -501,6 +501,25 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch 2.0 is used, the attention block does not need
+        # to be in float32, which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     def __call__(
         self,
@@ -746,26 +765,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch 2.0 is used, the attention block does not need
-        # to be in float32, which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         # post-processing
         if not output_type == "latent":
```
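The method above upcasts the whole VAE to float32 and then, when a numerically stable attention processor is active (torch 2.0 SDPA or xformers), moves `post_quant_conv`, `decoder.conv_in`, and `decoder.mid_block` back to the original dtype to save memory. A self-contained toy sketch of that selective-upcast pattern (toy modules, not the diffusers implementation):

```python
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, 3, padding=1)
        self.mid_block = nn.Conv2d(8, 8, 3, padding=1)  # stand-in for the attention mid block
        self.up_blocks = nn.Conv2d(8, 3, 3, padding=1)  # stand-in for the overflow-prone tail

def selective_upcast(decoder: nn.Module, attention_is_stable: bool) -> None:
    dtype = next(decoder.parameters()).dtype
    decoder.to(torch.float32)        # safe default: everything in float32
    if attention_is_stable:          # e.g. torch 2.0 SDPA or xformers
        decoder.conv_in.to(dtype)    # these parts tolerate the low-precision
        decoder.mid_block.to(dtype)  # dtype, which saves memory

dec = ToyDecoder().half()
selective_upcast(dec, attention_is_stable=True)
print(next(dec.conv_in.parameters()).dtype)    # torch.float16
print(next(dec.up_blocks.parameters()).dtype)  # torch.float32
```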
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

```diff
@@ -537,6 +537,26 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch 2.0 is used, the attention block does not need
+        # to be in float32, which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -799,25 +819,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
                         callback(i, t, latents)

         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch 2.0 is used, the attention block does not need
-        # to be in float32, which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
```
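For the stock SDXL VAE (which keeps `force_upcast=True`), nothing changes from the user's perspective: `__call__` now routes through `upcast_vae()` right before decoding instead of hard-coding `self.vae.to(dtype=torch.float32)`. Since the method is public, it can also be invoked manually before a hand-rolled decode. A hedged sketch (checkpoint name assumed):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint name
    torch_dtype=torch.float16,
).to("cuda")

# Manual use of the newly added method before decoding latents yourself:
pipe.upcast_vae()
# post_quant_conv ends up in float16 under torch 2.0 / xformers processors,
# float32 otherwise; match the latents to whatever it is.
decode_dtype = next(iter(pipe.vae.post_quant_conv.parameters())).dtype
print(decode_dtype)
```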
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

```diff
@@ -542,6 +542,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            image = image.float()
-            self.vae.to(dtype=torch.float32)
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)
@@ -559,9 +560,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        self.vae.to(dtype)
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)

         init_latents = init_latents.to(dtype)
         init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
@@ -624,6 +626,26 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         return add_time_ids, add_neg_time_ids

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch 2.0 is used, the attention block does not need
+        # to be in float32, which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -932,25 +954,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
                         callback(i, t, latents)

         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch 2.0 is used, the attention block does not need
-        # to be in float32, which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
```
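The img2img pipeline needs the same treatment on the *encode* side: the input image and the VAE are round-tripped through float32 only when `force_upcast` is set. A sketch of driving it with a low-precision-safe VAE (the refiner checkpoint name is an assumption; the VAE is the one cited in the autoencoder_kl.py docstring above):

```python
import torch
from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",  # assumed checkpoint name
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

init = Image.new("RGB", (1024, 1024), "gray")  # placeholder input for illustration
# With force_upcast disabled in the VAE config, encoding stays in float16:
image = pipe("turn it into a watercolor painting", image=init).images[0]
```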