OpenDAS / diffusers, commit 22963ed8 (unverified)
Authored Oct 10, 2022 by Patrick von Platen; committed by GitHub, Oct 10, 2022
Parent commit: fab17528

Fix gradient checkpointing test (#797)

* Fix gradient checkpointing test
* more tests
Changes: 1 changed file with 23 additions and 21 deletions

tests/test_models_unet.py (+23, -21)
@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):

         model = self.model_class(**init_dict)
         model.to(torch_device)

+        assert not model.is_gradient_checkpointing and model.training
+
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()

-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()

-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()

         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

     # TODO(Patrick) - Re-add this test after having cleaned up LDM
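What the fixed test checks, in brief: it runs one model normally, clones its weights into a second model with enable_gradient_checkpointing() turned on, feeds both the same inputs and labels, and asserts that the loss and every parameter gradient agree within a small tolerance, since gradient checkpointing only trades compute for memory and should not change the result. The sketch below reproduces that verification pattern on a tiny hypothetical model using torch.utils.checkpoint; TinyModel and its layer sizes are made up for illustration and are not part of this commit or of diffusers.

# Minimal sketch (not from the commit): verify that a gradient-checkpointed
# copy of a model produces the same loss and parameter gradients as the
# original. "TinyModel" is a hypothetical stand-in for the UNet in the test;
# only torch and torch.utils.checkpoint are assumed.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    def __init__(self, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.block_1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.block_2 = nn.Linear(8, 8)

    def forward(self, x):
        if self.use_checkpointing:
            # drop block_1's activations and recompute them during backward
            x = checkpoint(self.block_1, x, use_reentrant=False)
        else:
            x = self.block_1(x)
        return self.block_2(x)


torch.manual_seed(0)
model = TinyModel(use_checkpointing=False)

# clone the weights into a checkpointed copy, mirroring
# model_2.load_state_dict(model.state_dict()) in the test
model_2 = TinyModel(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())

x = torch.randn(4, 8)
labels = torch.randn(4, 8)

loss = (model(x) - labels).mean()
loss.backward()

loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# checkpointing trades compute for memory; losses and gradients should agree
# (a small tolerance leaves room for non-deterministic GPU kernels)
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name

The tolerances mirror the test itself (1e-5 on the loss difference, atol=5e-5 on gradients) rather than demanding exact equality.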