"docs/vscode:/vscode.git/clone" did not exist on "e4b4fcffc2716f6a74eb018084485c9e108e8952"
Unverified Commit 34c90dbb authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix OOM for test_vae_tiling (#7510)

use float16 and add torch.no_grad()
parent e49c04d5
......@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_vae_tiling(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)
with torch.no_grad():
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)
......@@ -124,9 +124,10 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
with torch.no_grad():
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):
components = self.get_dummy_components()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment