Unverified Commit f7e5954d authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix: VAE tiling tests when setting the right device (#7246)

* debug

* checking

* fix more

* remove device.

* fix-copies
parent 8e19c073
......@@ -99,14 +99,13 @@ class SDFunctionTesterMixin:
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk
if "safety_checker" in components:
components["safety_checker"] = None
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
......@@ -126,7 +125,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(device)
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment