Unverified Commit 79a7ab92 authored by Dhruv Nair, committed by GitHub

Fix clearing backend cache from device agnostic testing (#6075)

update
parent c2717317
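The fix threads the active `torch_device` through every `backend_empty_cache()` call so the helper can clear the cache of whichever backend the tests actually run on, rather than implicitly assuming CUDA. A minimal sketch of such a device-keyed dispatch follows; the table name and the fallback behavior are assumptions for illustration, not diffusers' exact implementation:

```python
import torch

# Hypothetical dispatch table from device type to its cache-clearing
# function; CPU has no device cache, so it maps to a no-op.
BACKEND_EMPTY_CACHE = {
    "cuda": torch.cuda.empty_cache,
    "cpu": None,
}


def backend_empty_cache(device: str):
    """Clear the memory cache of the backend named by `device`."""
    device_type = torch.device(device).type  # "cuda:0" -> "cuda"
    fn = BACKEND_EMPTY_CACHE.get(device_type)
    if fn is not None:
        fn()
```

Passing `torch_device` explicitly, as each hunk below does, keeps the test files free of backend-specific branching.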
@@ -164,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     @parameterized.expand(
         [
@@ -869,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
         dtype = torch.float16 if fp16 else torch.float32
@@ -485,7 +485,7 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_file_format(self, seed, shape):
         return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -565,7 +565,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
         dtype = torch.float16 if fp16 else torch.float32
@@ -820,7 +820,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
         dtype = torch.float16 if fp16 else torch.float32
@@ -310,7 +310,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
@@ -531,7 +531,7 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        backend_empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
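After this change, every affected integration-test class follows the same teardown pattern. A self-contained sketch of that pattern is below; the import path matches diffusers' testing utilities as used by these tests, but treat it as an assumption rather than a guaranteed stable API:

```python
import gc
import unittest

# torch_device and backend_empty_cache come from diffusers' testing
# utilities; the exact module path here is an assumption for illustration.
from diffusers.utils.testing_utils import backend_empty_cache, torch_device


class ExampleIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM (or other backend cache) after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
```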