Unverified Commit 5d21d4a2 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix FreeU tests (#7540)

update
parent 73ba8109
...@@ -133,11 +133,15 @@ class SDFunctionTesterMixin: ...@@ -133,11 +133,15 @@ class SDFunctionTesterMixin:
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False inputs["return_dict"] = False
inputs["output_type"] = "np"
output = pipe(**inputs)[0] output = pipe(**inputs)[0]
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False inputs["return_dict"] = False
inputs["output_type"] = "np"
output_freeu = pipe(**inputs)[0] output_freeu = pipe(**inputs)[0]
assert not np.allclose( assert not np.allclose(
...@@ -152,6 +156,8 @@ class SDFunctionTesterMixin: ...@@ -152,6 +156,8 @@ class SDFunctionTesterMixin:
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False inputs["return_dict"] = False
inputs["output_type"] = "np"
output = pipe(**inputs)[0] output = pipe(**inputs)[0]
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
...@@ -164,6 +170,8 @@ class SDFunctionTesterMixin: ...@@ -164,6 +170,8 @@ class SDFunctionTesterMixin:
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False inputs["return_dict"] = False
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0] output_no_freeu = pipe(**inputs)[0]
assert np.allclose( assert np.allclose(
output, output_no_freeu, atol=1e-2 output, output_no_freeu, atol=1e-2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment