Unverified Commit 34c90dbb authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix OOM for test_vae_tiling (#7510)

use float16 and add torch.no_grad()
parent e49c04d5
...@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): ...@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
assert torch_all_close(actual_output, expected_output, atol=5e-3) assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_vae_tiling(self): def test_vae_tiling(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None) pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): ...@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
# test that tiled decode works with various shapes # test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)] shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
with torch.no_grad():
for shape in shapes: for shape in shapes:
image = torch.zeros(shape, device=torch_device) image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image) pipe.vae.decode(image)
...@@ -124,6 +124,7 @@ class SDFunctionTesterMixin: ...@@ -124,6 +124,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes # test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)] shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
with torch.no_grad():
for shape in shapes: for shape in shapes:
zeros = torch.zeros(shape).to(torch_device) zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros) pipe.vae.decode(zeros)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment