Unverified Commit 22963ed8 authored by Patrick von Platen, committed by GitHub

Fix gradient checkpointing test (#797)

* Fix gradient checkpointing test

* more tests
parent fab17528
@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
+        assert not model.is_gradient_checkpointing and model.training
+
         out = model(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model.zero_grad()
-        out.sum().backward()
 
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_not_checkpointed = out.data.clone()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
 
-        model.enable_gradient_checkpointing()
-        out = model(**inputs_dict).sample
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-        out.sum().backward()
-
-        # now we save the output and parameter gradients that we will use for comparison purposes with
-        # the non-checkpointed run.
-        output_checkpointed = out.data.clone()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
-        for name in grad_checkpointed:
-            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
     # TODO(Patrick) - Re-add this test after having cleaned up LDM
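
For context, the property the updated test relies on can be reproduced outside of diffusers: running the same weights with and without activation checkpointing should give the same loss and, up to numerical noise, the same parameter gradients, because checkpointing only changes when activations are recomputed, not the math. The following is a minimal sketch in plain PyTorch that mirrors the structure of the test (baseline model, weight-cloned second model, gradient comparison); the TinyNet module, shapes, and tolerances are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyNet(nn.Module):
    """Toy stand-in for the UNet: two blocks that can optionally be checkpointed."""

    def __init__(self, use_checkpointing: bool = False):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        if self.use_checkpointing:
            # recompute each block's activations during backward instead of storing them
            x = checkpoint(self.block1, x, use_reentrant=False)
            x = checkpoint(self.block2, x, use_reentrant=False)
        else:
            x = self.block1(x)
            x = self.block2(x)
        return x


torch.manual_seed(0)
x = torch.randn(4, 16)
labels = torch.randn(4, 16)

# baseline run without checkpointing
model = TinyNet(use_checkpointing=False)
loss = (model(x) - labels).mean()
loss.backward()

# second model with identical weights but checkpointing enabled (mirrors model_2 in the test)
model_2 = TinyNet(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())
loss_2 = (model_2(x) - labels).mean()
loss_2.backward()

# losses and per-parameter gradients must agree within a small tolerance
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name
print("checkpointed and non-checkpointed gradients match")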