renzhc / diffusers_dcu · Commits · 4adf6aff

Commit 4adf6aff (unverified), authored Oct 31, 2024 by Sayak Paul, committed by GitHub on Oct 31, 2024
Parent: 8ce37ab0

[Tests] clean up and refactor gradient checkpointing tests (#9494)

* check.
* fixes
* fixes
* updates
* fixes
* fixes
Showing 15 changed files with 180 additions and 273 deletions (+180 −273).
tests/models/autoencoders/test_models_vae.py (+25 −84)
tests/models/test_modeling_common.py (+97 −0)
tests/models/transformers/test_models_dit_transformer2d.py (+7 −0)
tests/models/transformers/test_models_pixart_transformer2d.py (+4 −0)
tests/models/transformers/test_models_transformer_allegro.py (+4 −0)
tests/models/transformers/test_models_transformer_aura_flow.py (+4 −0)
tests/models/transformers/test_models_transformer_cogvideox.py (+4 −0)
tests/models/transformers/test_models_transformer_cogview3plus.py (+4 −0)
tests/models/transformers/test_models_transformer_flux.py (+4 −0)
tests/models/transformers/test_models_transformer_latte.py (+4 −0)
tests/models/transformers/test_models_transformer_sd3.py (+8 −0)
tests/models/unets/test_models_unet_2d_condition.py (+6 −70)
tests/models/unets/test_models_unet_controlnetxs.py (+2 −26)
tests/models/unets/test_models_unet_motion.py (+2 −24)
tests/models/unets/test_models_unet_spatiotemporal.py (+5 −69)
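Taken together, the diff deletes the near-identical per-file test_gradient_checkpointing implementations and centralizes them in ModelTesterMixin (tests/models/test_modeling_common.py), so each model test class only declares which of its submodule classes should end up with gradient checkpointing enabled, and optionally loosens the numeric tolerances. A minimal sketch of the resulting pattern follows; MyModelTests and MyDiffusersModel are placeholders rather than classes from this diff, and the relative import mirrors the other test files:

import unittest

from ..test_modeling_common import ModelTesterMixin  # relative import, as in the other test files


class MyModelTests(ModelTesterMixin, unittest.TestCase):
    # Placeholder: point this at the diffusers model class under test.
    model_class = None  # e.g. MyDiffusersModel

    def prepare_init_args_and_inputs_for_common(self):
        # Return (init_dict, inputs_dict) describing a tiny model configuration.
        raise NotImplementedError

    def test_gradient_checkpointing_is_applied(self):
        # Submodule class names whose gradient_checkpointing flag must flip on.
        expected_set = {"MyDiffusersModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_effective_gradient_checkpointing(self):
        # Override only to loosen the default loss tolerance, if needed.
        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)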
tests/models/autoencoders/test_models_vae.py

@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -170,52 +169,17 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_training(self):
         pass

-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Decoder", "Encoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
@@ -364,9 +330,20 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_outputs_equivalence(self):
         pass

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DecoderTiny", "EncoderTiny"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip(
+        "Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
+    )
+    def test_effective_gradient_checkpointing(self):
+        pass

 class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_training(self):
         pass

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            if "post_quant_conv" in name:
-                continue
-
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Encoder", "TemporalDecoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

 class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
tests/models/test_modeling_common.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import inspect
 import json
 import os
@@ -57,6 +58,7 @@ from diffusers.utils.testing_utils import (
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    torch_all_close,
     torch_device,
 )
@@ -785,6 +787,101 @@ class ModelTesterMixin:
         model.disable_gradient_checkpointing()
         self.assertFalse(model.is_gradient_checkpointing)

+    @require_torch_accelerator_with_training
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        inputs_dict_copy = copy.deepcopy(inputs_dict)
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model now enabling gradient checkpointing
+        torch.manual_seed(0)
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict_copy).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the output and parameters gradients
+        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+
+        for name, param in named_params.items():
+            if "post_quant_conv" in name:
+                continue
+
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
+
+    @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
+    def test_gradient_checkpointing_is_applied(
+        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+    ):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        if self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
+            return
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        if attention_head_dim is not None:
+            init_dict["attention_head_dim"] = attention_head_dim
+        if num_attention_heads is not None:
+            init_dict["num_attention_heads"] = num_attention_heads
+        if block_out_channels is not None:
+            init_dict["block_out_channels"] = block_out_channels
+
+        model_class_copy = copy.copy(self.model_class)
+
+        modules_with_gc_enabled = {}
+
+        # now monkey patch the following function:
+        #     def _set_gradient_checkpointing(self, module, value=False):
+        #         if hasattr(module, "gradient_checkpointing"):
+        #             module.gradient_checkpointing = value
+
+        def _set_gradient_checkpointing_new(self, module, value=False):
+            if hasattr(module, "gradient_checkpointing"):
+                module.gradient_checkpointing = value
+                modules_with_gc_enabled[module.__class__.__name__] = True
+
+        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+        model = model_class_copy(**init_dict)
+        model.enable_gradient_checkpointing()
+
+        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
+
+        assert set(modules_with_gc_enabled.keys()) == expected_set
+        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
     def test_deprecated_kwargs(self):
         has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
         has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
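The assertion in the new test_gradient_checkpointing_is_applied hook works by intercepting _set_gradient_checkpointing and recording every submodule class whose flag it flips. A standalone toy sketch of that recording trick, with Block and TinyModel as made-up stand-ins rather than diffusers classes:

import copy
from functools import partial

import torch


class Block(torch.nn.Module):
    # Stand-in for a diffusers block that opts into gradient checkpointing.
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.linear = torch.nn.Linear(4, 4)


class TinyModel(torch.nn.Module):
    # Loosely mimics the hooks on diffusers' ModelMixin that the real test monkey patches.
    def __init__(self):
        super().__init__()
        self.block = Block()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model_class_copy = copy.copy(TinyModel)  # the real test copies self.model_class the same way
modules_with_gc_enabled = {}


def _set_gradient_checkpointing_new(self, module, value=False):
    # Same as the original hook, but also record which module classes were touched.
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value
        modules_with_gc_enabled[module.__class__.__name__] = True


model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

model = model_class_copy()
model.enable_gradient_checkpointing()

assert set(modules_with_gc_enabled.keys()) == {"Block"}
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"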
tests/models/transformers/test_models_dit_transformer2d.py

@@ -84,6 +84,13 @@ class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
         model = Transformer2DModel.from_config(init_dict)
         assert isinstance(model, DiTTransformer2DModel)

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DiTTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
+
     def test_correct_class_remapping_from_pretrained_config(self):
         config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
         model = Transformer2DModel.from_config(config)
tests/models/transformers/test_models_pixart_transformer2d.py

@@ -92,6 +92,10 @@ class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"PixArtTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_correct_class_remapping_from_dict_config(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = Transformer2DModel.from_config(init_dict)
tests/models/transformers/test_models_transformer_allegro.py

@@ -77,3 +77,7 @@ class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AllegroTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/transformers/test_models_transformer_aura_flow.py

@@ -74,6 +74,10 @@ class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AuraFlowTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
     def test_set_attn_processor_for_determinism(self):
         pass
tests/models/transformers/test_models_transformer_cogvideox.py

@@ -81,3 +81,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogVideoXTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/transformers/test_models_transformer_cogview3plus.py

@@ -83,3 +83,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogView3PlusTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/transformers/test_models_transformer_flux.py

@@ -111,3 +111,7 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
             torch.allclose(output_1, output_2, atol=1e-5),
             msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"FluxTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/transformers/test_models_transformer_latte.py

@@ -86,3 +86,7 @@ class LatteTransformerTests(ModelTesterMixin, unittest.TestCase):
         super().test_output(
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"LatteTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/transformers/test_models_transformer_sd3.py

@@ -84,6 +84,10 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
     def test_set_attn_processor_for_determinism(self):
         pass

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+

 class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SD3Transformer2DModel
@@ -139,3 +143,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/models/unets/test_models_unet_2d_condition.py

@@ -43,7 +43,6 @@ from diffusers.utils.testing_utils import (
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -406,47 +405,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"

-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -599,31 +557,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
                     check_sliceable_dim_attr(module)

     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["block_out_channels"] = (16, 32)
-        init_dict["attention_head_dim"] = (8, 16)
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        #     def _set_gradient_checkpointing(self, module, value=False):
-        #         if hasattr(module, "gradient_checkpointing"):
-        #             module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "CrossAttnUpBlock2D",
             "CrossAttnDownBlock2D",
             "UNetMidBlock2DCrossAttn",
@@ -631,9 +565,11 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             "Transformer2DModel",
             "DownBlock2D",
         }
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        attention_head_dim = (8, 16)
+        block_out_channels = (16, 32)
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )

     def test_special_attn_proc(self):
         class AttnEasyProc(torch.nn.Module):
tests/models/unets/test_models_unet_controlnetxs.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
 import unittest

 import numpy as np
@@ -269,37 +268,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             assert_unfrozen(u.ctrl_to_base)

     def test_gradient_checkpointing_is_applied(self):
-        model_class_copy = copy.copy(UNetControlNetXSModel)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        #     def _set_gradient_checkpointing(self, module, value=False):
-        #         if hasattr(module, "gradient_checkpointing"):
-        #             module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = model_class_copy(**init_dict)
-
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "Transformer2DModel",
             "UNetMidBlock2DCrossAttn",
             "ControlNetXSCrossAttnDownBlock2D",
             "ControlNetXSCrossAttnMidBlock2D",
             "ControlNetXSCrossAttnUpBlock2D",
         }
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     @is_flaky
     def test_forward_no_control(self):
tests/models/unets/test_models_unet_motion.py

@@ -161,27 +161,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             ),
         ), "xformers is not enabled"

     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        #     def _set_gradient_checkpointing(self, module, value=False):
-        #         if hasattr(module, "gradient_checkpointing"):
-        #             module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "CrossAttnUpBlockMotion",
             "CrossAttnDownBlockMotion",
             "UNetMidBlockCrossAttnMotion",
@@ -189,9 +169,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             "Transformer2DModel",
             "DownBlockMotion",
         }
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
tests/models/unets/test_models_unet_spatiotemporal.py

@@ -25,7 +25,6 @@ from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
     skip_mps,
-    torch_all_close,
     torch_device,
 )
@@ -160,47 +159,6 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_model_with_num_attention_heads_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -239,30 +197,7 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["num_attention_heads"] = (8, 16)
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        #     def _set_gradient_checkpointing(self, module, value=False):
-        #         if hasattr(module, "gradient_checkpointing"):
-        #             module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "TransformerSpatioTemporalModel",
             "CrossAttnDownBlockSpatioTemporal",
             "DownBlockSpatioTemporal",
@@ -270,9 +205,10 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
             "CrossAttnUpBlockSpatioTemporal",
             "UNetMidBlockSpatioTemporal",
         }
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        num_attention_heads = (8, 16)
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, num_attention_heads=num_attention_heads
+        )

     def test_pickle(self):
         # enable deterministic behavior for gradient checkpointing
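Not part of the diff, but as a usage note: the touched tests can be filtered by name with pytest, either from the command line or through its Python API. A hypothetical local invocation, with two of the changed files chosen as examples:

# Roughly equivalent to: pytest <paths> -k "gradient_checkpointing"
import pytest

exit_code = pytest.main(
    [
        "tests/models/autoencoders/test_models_vae.py",
        "tests/models/unets/test_models_unet_2d_condition.py",
        "-k",
        "gradient_checkpointing",
    ]
)
raise SystemExit(exit_code)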