renzhc / diffusers_dcu · Commits · 22963ed8

Unverified commit 22963ed8, authored Oct 10, 2022 by Patrick von Platen, committed by GitHub on Oct 10, 2022

Fix gradient checkpointing test (#797)

* Fix gradient checkpointing test
* more tests

parent fab17528
Showing 1 changed file with 23 additions and 21 deletions.

tests/test_models_unet.py (+23, -21) @ 22963ed8
...
@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
         assert not model.is_gradient_checkpointing and model.training
 
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()
 
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
 
-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
     # TODO(Patrick) - Re-add this test after having cleaned up LDM
...
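For context, the verification pattern the updated test relies on can be reproduced outside the diffusers test suite. The sketch below is a minimal, self-contained approximation using plain PyTorch's torch.utils.checkpoint instead of diffusers' enable_gradient_checkpointing(); the ToyModel class, its layer sizes, and the input shapes are illustrative assumptions, not part of the commit. It runs the same forward/backward on two weight-identical copies of a model, one with activation checkpointing enabled, and checks that the loss and per-parameter gradients agree within the same tolerances the test uses.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyModel(nn.Module):
    # hypothetical two-block model; stands in for self.model_class(**init_dict)
    def __init__(self, use_checkpointing: bool = False):
        super().__init__()
        self.block_1 = nn.Sequential(nn.Linear(16, 64), nn.ReLU())
        self.block_2 = nn.Sequential(nn.Linear(64, 16), nn.ReLU())
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        if self.use_checkpointing:
            # recompute block activations during backward instead of storing them
            x = checkpoint(self.block_1, x, use_reentrant=False)
            x = checkpoint(self.block_2, x, use_reentrant=False)
        else:
            x = self.block_1(x)
            x = self.block_2(x)
        return x


torch.manual_seed(0)
model = ToyModel(use_checkpointing=False)
model_2 = ToyModel(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())  # clone the weights, as in the test

x = torch.randn(4, 16)
labels = torch.randn(4, 16)

loss = (model(x) - labels).mean()
loss.backward()

loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# same acceptance criteria as the updated test
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name
print("checkpointed and non-checkpointed runs match")

use_reentrant=False is chosen so gradients still flow even when the checkpointed inputs themselves do not require grad; diffusers' enable_gradient_checkpointing() wraps this same torch.utils.checkpoint machinery inside the model's blocks.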