renzhc / diffusers_dcu · Commit 5b972fbd

Unverified commit, authored Nov 08, 2024 by Michael Tkachuk; committed by GitHub on Nov 08, 2024.
Enabling gradient checkpointing in eval() mode (#9878)
* refactored
Parent: 0be52c07
Changes: showing 20 of 34 changed files, with 42 additions and 42 deletions (+42 −42); page 1 of 2.
examples/community/matryoshka.py (+4 −4)
examples/research_projects/pixart/controlnet_pixart_alpha.py (+1 −1)
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py (+2 −2)
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py (+5 −5)
src/diffusers/models/autoencoders/autoencoder_kl_mochi.py (+5 −5)
src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py (+1 −1)
src/diffusers/models/autoencoders/vae.py (+5 −5)
src/diffusers/models/controlnets/controlnet_flux.py (+2 −2)
src/diffusers/models/controlnets/controlnet_sd3.py (+1 −1)
src/diffusers/models/controlnets/controlnet_xs.py (+3 −3)
src/diffusers/models/transformers/auraflow_transformer_2d.py (+2 −2)
src/diffusers/models/transformers/cogvideox_transformer_3d.py (+1 −1)
src/diffusers/models/transformers/dit_transformer_2d.py (+1 −1)
src/diffusers/models/transformers/latte_transformer_3d.py (+2 −2)
src/diffusers/models/transformers/pixart_transformer_2d.py (+1 −1)
src/diffusers/models/transformers/stable_audio_transformer.py (+1 −1)
src/diffusers/models/transformers/transformer_2d.py (+1 −1)
src/diffusers/models/transformers/transformer_allegro.py (+1 −1)
src/diffusers/models/transformers/transformer_cogview3plus.py (+1 −1)
src/diffusers/models/transformers/transformer_flux.py (+2 −2)
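Every hunk on this page makes the same one-line change: the guard in front of gradient checkpointing switches from `self.training and self.gradient_checkpointing` (in a few files, plain `self.gradient_checkpointing`) to `torch.is_grad_enabled() and self.gradient_checkpointing`, so activation checkpointing also engages when a module is in eval() mode while autograd is still recording. A minimal sketch of the pattern; `Block` and `TinyModel` are hypothetical and not part of this commit:

import torch
import torch.nn as nn
import torch.utils.checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for one of the checkpointed sub-blocks."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


class TinyModel(nn.Module):
    def __init__(self, num_blocks: int = 4, dim: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(num_blocks)])
        self.gradient_checkpointing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Before this commit: `if self.training and self.gradient_checkpointing:`
            # After: checkpoint whenever autograd is recording, even under eval().
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x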
examples/community/matryoshka.py

@@ -868,7 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -1029,7 +1029,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -1191,7 +1191,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -1364,7 +1364,7 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
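Why the guard change matters: `eval()` only flips layers such as dropout and norm statistics; it does not disable autograd, so a model held in eval() mode can still be backpropagated through (for example, when optimizing an input or an adapter against a frozen network). Under the old `self.training` check those workflows got no checkpointing. A usage sketch against the hypothetical `TinyModel` above:

# Gradients still flow through an eval()-mode model.
model = TinyModel().eval()                 # eval() does not turn autograd off
x = torch.randn(2, 64, requires_grad=True)

out = model(x)                             # torch.is_grad_enabled() is True -> checkpointed
out.sum().backward()                       # activations are recomputed during backward
print(x.grad.shape)                        # torch.Size([2, 64])

with torch.no_grad():                      # pure inference: checkpointing is skipped
    _ = model(x)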
examples/research_projects/pixart/controlnet_pixart_alpha.py

@@ -215,7 +215,7 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block_index, block in enumerate(self.transformer.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 # rc todo: for training and gradient checkpointing
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
...
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

@@ -506,7 +506,7 @@ class AllegroEncoder3D(nn.Module):
         sample = self.temp_conv_in(sample)
         sample = sample + residual
-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -646,7 +646,7 @@ class AllegroDecoder3D(nn.Module):
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

@@ -420,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -522,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -636,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -773,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -939,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...
src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

@@ -206,7 +206,7 @@ class MochiDownBlock3D(nn.Module):
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -311,7 +311,7 @@ class MochiMidBlock3D(nn.Module):
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -392,7 +392,7 @@ class MochiUpBlock3D(nn.Module):
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
...

@@ -529,7 +529,7 @@ class MochiEncoder3D(nn.Module):
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):
...

@@ -646,7 +646,7 @@ class MochiDecoder3D(nn.Module):
         hidden_states = self.conv_in(hidden_states)
         # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):
...
src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

@@ -95,7 +95,7 @@ class TemporalDecoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...
src/diffusers/models/autoencoders/vae.py

@@ -142,7 +142,7 @@ class Encoder(nn.Module):
         sample = self.conv_in(sample)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -291,7 +291,7 @@ class Decoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -544,7 +544,7 @@ class MaskConditionDecoder(nn.Module):
         sample = self.conv_in(sample)
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -876,7 +876,7 @@ class EncoderTiny(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...

@@ -962,7 +962,7 @@ class DecoderTiny(nn.Module):
         # Clamp.
         x = torch.tanh(x / 3) * 3
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
...
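Most of these files route the checkpointed call through a `create_custom_forward` closure, because `torch.utils.checkpoint.checkpoint` re-invokes a plain callable over tensors during backward. A simplified sketch of that wrapper as it appears throughout these diffs (the `return_dict` handling is abbreviated):

def create_custom_forward(module, return_dict=None):
    # checkpoint() replays this callable on recompute, so it must depend
    # only on its tensor inputs plus the captured module and kwargs.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward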
src/diffusers/models/controlnets/controlnet_flux.py

@@ -329,7 +329,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -363,7 +363,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/controlnets/controlnet_sd3.py

@@ -324,7 +324,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         block_res_samples = ()
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/controlnets/controlnet_xs.py

@@ -1466,7 +1466,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
             # apply base subblock
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 h_base = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(b_res),
...

@@ -1489,7 +1489,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             # apply ctrl subblock
             if apply_control:
-                if self.training and self.gradient_checkpointing:
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                     h_ctrl = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(c_res),
...

@@ -1898,7 +1898,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
                 hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
             hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
...
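controlnet_xs.py (and latte_transformer_3d.py below) call `torch.utils.checkpoint.checkpoint` directly at the guard, passing `use_reentrant=False` on PyTorch >= 1.11 via `ckpt_kwargs` (`is_torch_version` is a diffusers utility). The non-reentrant variant, unlike the legacy reentrant one, still produces parameter gradients when none of the checkpointed inputs has `requires_grad=True`, which composes well with the looser `torch.is_grad_enabled()` guard. A standalone sketch of that guard-plus-kwargs shape, not the file's exact code:

from typing import Any, Dict

import torch
import torch.utils.checkpoint


def run_checkpointed(module: torch.nn.Module, *args: torch.Tensor) -> torch.Tensor:
    # Checkpoint only when autograd is recording AND the module opted in.
    if torch.is_grad_enabled() and getattr(module, "gradient_checkpointing", False):
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}  # needs torch >= 1.11
        return torch.utils.checkpoint.checkpoint(module, *args, **ckpt_kwargs)
    return module(*args)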
src/diffusers/models/transformers/auraflow_transformer_2d.py

@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -452,7 +452,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/dit_transformer_2d.py

@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/latte_transformer_3d.py

@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
...

@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
...
src/diffusers/models/transformers/pixart_transformer_2d.py

@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/stable_audio_transformer.py

@@ -414,7 +414,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/transformer_2d.py

@@ -415,7 +415,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/transformer_allegro.py

@@ -371,7 +371,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            if self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/transformer_cogview3plus.py

@@ -341,7 +341,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         hidden_states = hidden_states[:, text_seq_length:]
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/transformer_flux.py

@@ -480,7 +480,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         image_rotary_emb = self.pos_embed(ids)
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...

@@ -525,7 +525,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...