Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3f7c3511
Unverified
Commit
3f7c3511
authored
Nov 27, 2023
by
Sayak Paul
Committed by
GitHub
Nov 27, 2023
Browse files
[Core] add support for gradient checkpointing in transformer_2d (#5943)
add support for gradient checkpointing in transformer_2d
parent
7d6f30e8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
3 deletions
+20
-3
src/diffusers/models/transformer_2d.py
src/diffusers/models/transformer_2d.py
+20
-3
No files found.
src/diffusers/models/transformer_2d.py
View file @
3f7c3511
...
@@ -20,7 +20,7 @@ from torch import nn
...
@@ -20,7 +20,7 @@ from torch import nn
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..models.embeddings
import
ImagePositionalEmbeddings
from
..models.embeddings
import
ImagePositionalEmbeddings
from
..utils
import
USE_PEFT_BACKEND
,
BaseOutput
,
deprecate
from
..utils
import
USE_PEFT_BACKEND
,
BaseOutput
,
deprecate
,
is_torch_version
from
.attention
import
BasicTransformerBlock
from
.attention
import
BasicTransformerBlock
from
.embeddings
import
CaptionProjection
,
PatchEmbed
from
.embeddings
import
CaptionProjection
,
PatchEmbed
from
.lora
import
LoRACompatibleConv
,
LoRACompatibleLinear
from
.lora
import
LoRACompatibleConv
,
LoRACompatibleLinear
...
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
"""
_supports_gradient_checkpointing
=
True
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self
.
gradient_checkpointing
=
False
self
.
gradient_checkpointing
=
False
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
hasattr
(
module
,
"gradient_checkpointing"
):
module
.
gradient_checkpointing
=
value
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
for
block
in
self
.
transformer_blocks
:
for
block
in
self
.
transformer_blocks
:
if
self
.
training
and
self
.
gradient_checkpointing
:
if
self
.
training
and
self
.
gradient_checkpointing
:
def
create_custom_forward
(
module
,
return_dict
=
None
):
def
custom_forward
(
*
inputs
):
if
return_dict
is
not
None
:
return
module
(
*
inputs
,
return_dict
=
return_dict
)
else
:
return
module
(
*
inputs
)
return
custom_forward
ckpt_kwargs
:
Dict
[
str
,
Any
]
=
{
"use_reentrant"
:
False
}
if
is_torch_version
(
">="
,
"1.11.0"
)
else
{}
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
block
,
create_custom_forward
(
block
)
,
hidden_states
,
hidden_states
,
attention_mask
,
attention_mask
,
encoder_hidden_states
,
encoder_hidden_states
,
...
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
...
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
timestep
,
timestep
,
cross_attention_kwargs
,
cross_attention_kwargs
,
class_labels
,
class_labels
,
use_reentrant
=
False
,
**
ckpt_kwargs
,
)
)
else
:
else
:
hidden_states
=
block
(
hidden_states
=
block
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment