renzhc / diffusers_dcu · Commits

Commit 5b972fbd (unverified), authored Nov 08, 2024 by Michael Tkachuk, committed via GitHub on Nov 08, 2024

Enabling gradient checkpointing in eval() mode (#9878)

* refactored

Parent: 0be52c07
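
What the one-line commit message glosses over: in all 14 files, the gradient-checkpointing guard `if self.training and self.gradient_checkpointing:` (or `... and self.training` in the encoder variants) becomes `if torch.is_grad_enabled() and self.gradient_checkpointing:`. Checkpointing now keys off whether autograd is recording rather than off train/eval mode, so a model held in eval() can still trade recomputation for activation memory whenever gradients are requested, for example when backpropagating through a frozen model for guidance- or inversion-style techniques. A minimal sketch of the behavioral difference (the module below is illustrative, not taken from the diff):

import torch
from torch.utils.checkpoint import checkpoint


class TinyModel(torch.nn.Module):
    # Guarded the same way as the diffusers blocks after this commit.
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(4))
        self.gradient_checkpointing = True

    def forward(self, x):
        for layer in self.layers:
            # Old guard: `self.training and self.gradient_checkpointing`
            # never fired in eval() mode. New guard: checkpoint whenever
            # autograd is recording.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


model = TinyModel().eval()                 # eval() mode, e.g. a frozen model
x = torch.randn(2, 16, requires_grad=True)

with torch.enable_grad():
    out = model(x).sum()
out.backward()                             # activations were checkpointed
print(x.grad.shape)                        # torch.Size([2, 16])

Under the old guard the else branch would run here and keep every intermediate activation; under the new guard checkpointing engages because `torch.is_grad_enabled()` is True inside `torch.enable_grad()`.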
Showing 14 changed files with 42 additions and 42 deletions (+42 −42)
src/diffusers/models/transformers/transformer_mochi.py (+1 −1)
src/diffusers/models/transformers/transformer_sd3.py (+1 −1)
src/diffusers/models/transformers/transformer_temporal.py (+1 −1)
src/diffusers/models/unets/unet_2d_blocks.py (+13 −13)
src/diffusers/models/unets/unet_3d_blocks.py (+5 −5)
src/diffusers/models/unets/unet_motion_model.py (+5 −5)
src/diffusers/models/unets/unet_stable_cascade.py (+2 −2)
src/diffusers/models/unets/uvit_2d.py (+1 −1)
src/diffusers/pipelines/audioldm2/modeling_audioldm2.py (+3 −3)
src/diffusers/pipelines/blip_diffusion/modeling_blip2.py (+1 −1)
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py (+5 −5)
src/diffusers/pipelines/kolors/text_encoder.py (+2 −2)
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py (+1 −1)
src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py (+1 −1)
src/diffusers/models/transformers/transformer_mochi.py
@@ -350,7 +350,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         )
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
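
Every hunk in this commit flips the same guard around diffusers' standard checkpointing idiom, whose body the collapsed `...` context hides. For reference, a hedged sketch of that idiom (the `return_dict` variant of `create_custom_forward` matches the signature visible in several hunks below; the exact tensors passed to `checkpoint` vary per block and are assumed here):

import torch


def create_custom_forward(module, return_dict=None):
    # Wraps the block so torch.utils.checkpoint can feed it positional
    # inputs, optionally forwarding a return_dict flag.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


block = torch.nn.Linear(16, 16)            # stand-in for a transformer block
hidden_states = torch.randn(2, 16, requires_grad=True)

hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states,
    use_reentrant=False,
)
hidden_states.sum().backward()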
src/diffusers/models/transformers/transformer_sd3.py
@@ -317,7 +317,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/transformers/transformer_temporal.py
@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):
         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     block,
                     hidden_states,
...
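
This file checkpoints the block directly instead of going through `create_custom_forward`. A caveat worth knowing when driving these paths from eval mode (my note, not part of the diff): the default reentrant implementation of `torch.utils.checkpoint.checkpoint` only builds a backward graph if at least one checkpointed input requires grad, so callers working with a fully frozen model should either mark the input with `requires_grad_()` or pass `use_reentrant=False`:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(8, 8).eval()
block.requires_grad_(False)                # frozen weights, guidance-style use

x = torch.randn(1, 8).requires_grad_(True)

y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad is not None)                  # True: gradients reach the input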
src/diffusers/models/unets/unet_2d_blocks.py
@@ -859,7 +859,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1257,7 +1257,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1371,7 +1371,7 @@ class DownBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1859,7 +1859,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2011,7 +2011,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         mask = attention_mask
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2106,7 +2106,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2215,7 +2215,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         output_states = ()
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2520,7 +2520,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2653,7 +2653,7 @@ class UpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3183,7 +3183,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3341,7 +3341,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -3444,7 +3444,7 @@ class KUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3572,7 +3572,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/unets/unet_3d_blocks.py
@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1383,7 +1383,7 @@ class UpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1493,7 +1493,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/unets/unet_motion_model.py
@@ -323,7 +323,7 @@ class DownBlockMotion(nn.Module):
         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -513,7 +513,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -732,7 +732,7 @@ class CrossAttnUpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -895,7 +895,7 @@ class UpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1079,7 +1079,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                 return_dict=False,
             )[0]
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/models/unets/unet_stable_cascade.py
@@ -455,7 +455,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
@@ -504,7 +504,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
src/diffusers/models/unets/uvit_2d.py
@@ -181,7 +181,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = self.project_to_hidden(hidden_states)
         for layer in self.transformer_layers:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def layer_(*args):
                     return checkpoint(layer, *args)
...
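
uvit_2d.py uses a lighter closure instead of `create_custom_forward`, calling a bare `checkpoint` (imported at module top) on the loop variable; only the guard changes. A self-contained sketch of that variant (the single-tensor argument and the `else` fallback are assumptions; the real layer takes more inputs):

import torch
from torch.utils.checkpoint import checkpoint

transformer_layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(3)).eval()
gradient_checkpointing = True

hidden_states = torch.randn(2, 8, requires_grad=True)
for layer in transformer_layers:
    if torch.is_grad_enabled() and gradient_checkpointing:
        # Bind `layer` explicitly; the original closure is defined and used
        # within the same iteration, so late binding is harmless there.
        def layer_(*args, layer=layer):
            return checkpoint(layer, *args, use_reentrant=False)
    else:
        layer_ = layer
    hidden_states = layer_(hidden_states)

hidden_states.sum().backward()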
src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -1112,7 +1112,7 @@ class CrossAttnDownBlock2D(nn.Module):
         )
         for i in range(num_layers):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1290,7 +1290,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         )
         for i in range(len(self.resnets[1:])):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1464,7 +1464,7 @@ class CrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -167,7 +167,7 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
...
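
In the two text-encoder files (modeling_blip2.py here and kolors/text_encoder.py below), the guard also gates a `use_cache` check: KV caching keeps tensors that checkpoint recomputation would throw away, so the two features are mutually exclusive and the code falls back to `use_cache=False`. With the new guard, that warning can now also fire for eval-mode forwards running under autograd. Condensed control flow (the final assignment follows the standard transformers-style pattern and is assumed to sit in the elided context):

if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
    if use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False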
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2352,7 +2352,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
src/diffusers/pipelines/kolors/text_encoder.py
@@ -590,7 +590,7 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
...
@@ -604,7 +604,7 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer = self._get_layer(index)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 layer_ret = torch.utils.checkpoint.checkpoint(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
...
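
The second hunk above passes non-tensor arguments (`kv_caches[index]`, `use_cache`) straight through `torch.utils.checkpoint.checkpoint`; that is supported, and only tensor inputs participate in autograd. A small demonstration:

import torch
from torch.utils.checkpoint import checkpoint


def layer(x, mask, use_cache):
    out = x * mask
    return (out, None) if use_cache else out   # non-tensor output is fine too


x = torch.randn(3, requires_grad=True)
mask = torch.ones(3)

out, _ = checkpoint(layer, x, mask, True, use_reentrant=False)
out.sum().backward()
print(x.grad)                                  # equals `mask`, i.e. all ones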
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -675,7 +675,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -158,7 +158,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         c_embed = self.cond_mapper(c)
         r_embed = self.gen_r_embedding(r)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
...
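
Taken together, these guard changes let users keep a diffusers model frozen in eval() mode and still cut activation memory when backpropagating through it. A hypothetical end-to-end sketch (the model id, shapes, and loss are illustrative, not from this commit):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()   # sets gradient_checkpointing = True
unet.eval()                            # stays in eval mode throughout
unet.requires_grad_(False)             # weights frozen

latents = torch.randn(1, 4, 64, 64, requires_grad=True)
timestep = torch.tensor([10])
text_embeds = torch.randn(1, 77, 768)

# Gradients w.r.t. the *inputs*, e.g. for guidance or latent optimization.
noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample
loss = noise_pred.pow(2).mean()        # illustrative objective
loss.backward()
print(latents.grad.shape)              # torch.Size([1, 4, 64, 64])

Before this commit the checkpointing branches would have been skipped here (the model is not in training mode) and the full activation stack would be materialized.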