Unverified Commit 84b9df57 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[gradient checkpointing] lower tolerance for test (#652)

* lowe tolerance

* put model in eval mode
parent 210be4fe
......@@ -199,7 +199,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict).eval()
model.to(torch_device)
out = model(**inputs_dict).sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment