renzhc / diffusers_dcu · Commits · 3f7c3511
Unverified commit 3f7c3511, authored Nov 27, 2023 by Sayak Paul, committed by GitHub on Nov 27, 2023
[Core] add support for gradient checkpointing in transformer_2d (#5943)
add support for gradient checkpointing in transformer_2d
parent 7d6f30e8
Showing 1 changed file with 20 additions and 3 deletions
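The practical effect of this change is that `ModelMixin.enable_gradient_checkpointing()` now accepts `Transformer2DModel`, so its transformer blocks recompute activations during the backward pass instead of storing them. A minimal usage sketch follows; the config values and tensor shapes are illustrative placeholders, not taken from this commit:

import torch
from diffusers import Transformer2DModel

# Small, randomly initialized model; these config values are placeholders chosen
# only so the example runs quickly, not values used anywhere in this commit.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=32,
    num_layers=2,
)

# With _supports_gradient_checkpointing = True, this call no longer raises, and
# _set_gradient_checkpointing flips the model's `gradient_checkpointing` flag.
model.enable_gradient_checkpointing()
model.train()

sample = torch.randn(1, 32, 16, 16)
out = model(sample).sample  # blocks run under torch.utils.checkpoint while training
out.mean().backward()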
src/diffusers/models/transformer_2d.py
@@ -20,7 +20,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
 from .embeddings import CaptionProjection, PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
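`_supports_gradient_checkpointing = True` is the flag that lets `ModelMixin.enable_gradient_checkpointing()` accept the model, and `_set_gradient_checkpointing` is the per-module hook that call drives. Roughly, the surrounding machinery looks like the sketch below; this is a simplified illustration of the pattern, not the verbatim ModelMixin code:

from functools import partial

import torch.nn as nn

class GradientCheckpointingMixin(nn.Module):
    # Sketch of the enable/disable plumbing around _set_gradient_checkpointing;
    # the real implementation lives in diffusers' ModelMixin.
    _supports_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply visits this module and every submodule, so any module
        # exposing a `gradient_checkpointing` attribute gets switched on.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value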
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    use_reentrant=False,
+                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
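The forward change wraps each block in `create_custom_forward` before handing it to `torch.utils.checkpoint.checkpoint`, and it only passes `use_reentrant=False` (via `**ckpt_kwargs`) when `is_torch_version(">=", "1.11.0")`, since older PyTorch releases do not accept that keyword. Below is a self-contained sketch of the same pattern on a toy block; the module, names, and shapes are illustrative only, not part of this commit:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    # Stand-in for a transformer block, used only to illustrate the pattern.
    def __init__(self, dim=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, hidden_states, return_dict=True):
        out = self.net(hidden_states)
        return {"sample": out} if return_dict else out

def create_custom_forward(module, return_dict=None):
    # checkpoint() forwards only positional arguments, so the wrapper pins the
    # keyword-style `return_dict` at wrapping time.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward

block = ToyBlock().train()
x = torch.randn(4, 32, requires_grad=True)

# Non-reentrant checkpointing (use_reentrant=False) drops the block's intermediate
# activations in the forward pass and recomputes them during backward, trading
# extra compute for lower peak memory.
out = checkpoint(create_custom_forward(block, return_dict=False), x, use_reentrant=False)
out.sum().backward()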