Unverified Commit f7e5954d authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix: VAE tiling tests when setting the right device (#7246)

* debug

* checking

* fix more

* remove device.

* fix-copies
parent 8e19c073
...@@ -99,14 +99,13 @@ class SDFunctionTesterMixin: ...@@ -99,14 +99,13 @@ class SDFunctionTesterMixin:
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2 assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self): def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk # make sure here that pndm scheduler skips prk
if "safety_checker" in components: if "safety_checker" in components:
components["safety_checker"] = None components["safety_checker"] = None
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
...@@ -126,7 +125,7 @@ class SDFunctionTesterMixin: ...@@ -126,7 +125,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes # test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)] shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes: for shape in shapes:
zeros = torch.zeros(shape).to(device) zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros) pipe.vae.decode(zeros)
def test_freeu_enabled(self): def test_freeu_enabled(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment