Unverified Commit 20fd00b1 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Tests] Add single file tester mixin for Models and remove unittest dependency (#12352)

* update

* update

* update

* update

* update
parent 76d4e416
import gc import gc
import unittest
import torch import torch
...@@ -25,7 +24,7 @@ enable_full_determinism() ...@@ -25,7 +24,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin): class TestStableDiffusionXLImg2ImgPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0" repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
...@@ -33,13 +32,11 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX ...@@ -33,13 +32,11 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
) )
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -66,7 +63,7 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX ...@@ -66,7 +63,7 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase): class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests:
pipeline_class = StableDiffusionXLImg2ImgPipeline pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors" "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
......
import gc import gc
import unittest
import torch import torch
...@@ -19,19 +18,17 @@ enable_full_determinism() ...@@ -19,19 +18,17 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase): class StableDiffusionXLInstructPix2PixPipeline:
pipeline_class = StableDiffusionXLInstructPix2PixPipeline pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors" ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
original_config = None original_config = None
repo_id = "diffusers/sdxl-instructpix2pix-768" repo_id = "diffusers/sdxl-instructpix2pix-768"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import unittest
import torch import torch
...@@ -22,7 +21,7 @@ enable_full_determinism() ...@@ -22,7 +21,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin): class TestStableDiffusionXLPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0" repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
...@@ -30,13 +29,11 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle ...@@ -30,13 +29,11 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
) )
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment