Unverified Commit f64d52db authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

fix custom diffusion tests (#4996)

parent 4d897aaf
......@@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
new_model.to(torch_device)
with torch.no_grad():
new_sample = new_model(**inputs_dict).sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment