Unverified Commit 20fd00b1 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Tests] Add single file tester mixin for Models and remove unittest dependency (#12352)

* update

* update

* update

* update

* update
parent 76d4e416
import gc
import tempfile import tempfile
from io import BytesIO from io import BytesIO
...@@ -9,7 +10,10 @@ from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_nam ...@@ -9,7 +10,10 @@ from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_nam
from diffusers.models.attention_processor import AttnProcessor from diffusers.models.attention_processor import AttnProcessor
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
nightly,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
require_torch_accelerator,
torch_device, torch_device,
) )
...@@ -47,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir): ...@@ -47,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
return path return path
@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    """Mixin verifying that a model loads identically via ``from_pretrained`` and ``from_single_file``.

    Subclasses must define:
        model_class: the diffusers model class under test.
        repo_id: hub repo id used with ``from_pretrained``.
        ckpt_path: single-file checkpoint URL/path used with ``from_single_file``.

    Optional class attributes:
        subfolder: subfolder within ``repo_id`` for ``from_pretrained``.
        torch_dtype: dtype forwarded to both loading paths.
        alternate_keys_ckpt_paths: checkpoints whose state-dict keys differ from
            the canonical layout; loading each must still succeed.
    """

    def setup_method(self):
        # Release accelerator memory before each test so large models from a
        # previous test do not cause spurious OOMs.
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_loading_kwargs(self):
        """Build the kwargs for ``from_pretrained`` / ``from_single_file`` from optional attributes.

        Returns:
            tuple(dict, dict): ``(pretrained_kwargs, single_file_kwargs)``.
        """
        pretrained_kwargs = {}
        single_file_kwargs = {}

        if getattr(self, "subfolder", None):
            pretrained_kwargs["subfolder"] = self.subfolder
        if getattr(self, "torch_dtype", None):
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        return pretrained_kwargs, single_file_kwargs

    def test_single_file_model_config(self):
        # The two loading paths must agree on every config entry except
        # bookkeeping fields that legitimately differ between them.
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        # Beyond the config, the actual weights must match key-for-key,
        # shape-for-shape, and value-for-value (within float tolerance).
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )
            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        # Test loading with checkpoints that have altered keys
        if not getattr(self, "alternate_keys_ckpt_paths", None):
            return

        _, single_file_kwargs = self._get_loading_kwargs()
        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            del model
            gc.collect()
            backend_empty_cache(torch_device)
class SDSingleFileTesterMixin: class SDSingleFileTesterMixin:
single_file_kwargs = {} single_file_kwargs = {}
......
...@@ -13,26 +13,21 @@ ...@@ -13,26 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
from diffusers import ( from diffusers import (
Lumina2Transformer2DModel, Lumina2Transformer2DModel,
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_torch_accelerator,
torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@require_torch_accelerator class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
model_class = Lumina2Transformer2DModel model_class = Lumina2Transformer2DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors" ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
alternate_keys_ckpt_paths = [ alternate_keys_ckpt_paths = [
...@@ -40,34 +35,4 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase): ...@@ -40,34 +35,4 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
] ]
repo_id = "Alpha-VLLM/Lumina-Image-2.0" repo_id = "Alpha-VLLM/Lumina-Image-2.0"
subfolder = "transformer"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
import torch import torch
...@@ -23,38 +21,24 @@ from diffusers import ( ...@@ -23,38 +21,24 @@ from diffusers import (
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
load_hf_numpy, load_hf_numpy,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device, torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@slow class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
@require_torch_accelerator
class AutoencoderDCSingleFileTests(unittest.TestCase):
model_class = AutoencoderDC model_class = AutoencoderDC
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors" ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape): def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
...@@ -80,18 +64,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase): ...@@ -80,18 +64,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4 assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading"
)
def test_single_file_in_type_variant_components(self): def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter # `in` variant checkpoints require passing in a `config` parameter
# in order to set the scaling factor correctly. # in order to set the scaling factor correctly.
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
import torch import torch
...@@ -23,46 +21,19 @@ from diffusers import ( ...@@ -23,46 +21,19 @@ from diffusers import (
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@slow class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
@require_torch_accelerator
class ControlNetModelSingleFileTests(unittest.TestCase):
model_class = ControlNetModel model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
repo_id = "lllyasviel/control_v11p_sd15_canny" repo_id = "lllyasviel/control_v11p_sd15_canny"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
def test_single_file_arguments(self): def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path) model_default = self.model_class.from_single_file(self.ckpt_path)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import gc import gc
import unittest
from diffusers import ( from diffusers import (
FluxTransformer2DModel, FluxTransformer2DModel,
...@@ -23,52 +22,21 @@ from diffusers import ( ...@@ -23,52 +22,21 @@ from diffusers import (
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache, backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_torch_accelerator,
torch_device, torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@require_torch_accelerator class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
model_class = FluxTransformer2DModel model_class = FluxTransformer2DModel
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
repo_id = "black-forest-labs/FLUX.1-dev" repo_id = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)
def test_device_map_cuda(self): def test_device_map_cuda(self):
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
from diffusers import ( from diffusers import (
MotionAdapter, MotionAdapter,
...@@ -27,7 +26,7 @@ from ..testing_utils import ( ...@@ -27,7 +26,7 @@ from ..testing_utils import (
enable_full_determinism() enable_full_determinism()
class MotionAdapterSingleFileTests(unittest.TestCase): class MotionAdapterSingleFileTests:
model_class = MotionAdapter model_class = MotionAdapter
def test_single_file_components_version_v1_5(self): def test_single_file_components_version_v1_5(self):
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import gc import gc
import unittest
import torch import torch
...@@ -37,14 +36,12 @@ enable_full_determinism() ...@@ -37,14 +36,12 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableCascadeUNetSingleFileTest(unittest.TestCase): class StableCascadeUNetSingleFileTest:
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
import torch import torch
...@@ -23,22 +21,18 @@ from diffusers import ( ...@@ -23,22 +21,18 @@ from diffusers import (
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
load_hf_numpy, load_hf_numpy,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device, torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@slow class TestAutoencoderKLSingleFile(SingleFileModelTesterMixin):
@require_torch_accelerator
class AutoencoderKLSingleFileTests(unittest.TestCase):
model_class = AutoencoderKL model_class = AutoencoderKL
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
...@@ -47,16 +41,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase): ...@@ -47,16 +41,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape): def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
...@@ -84,18 +68,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase): ...@@ -84,18 +68,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4 assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading"
)
def test_single_file_arguments(self): def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id) model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
......
...@@ -13,50 +13,24 @@ ...@@ -13,50 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
from diffusers import ( from diffusers import (
AutoencoderKLWan, AutoencoderKLWan,
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_torch_accelerator,
torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@require_torch_accelerator class TestAutoencoderKLWanSingleFile(SingleFileModelTesterMixin):
class AutoencoderKLWanSingleFileTests(unittest.TestCase):
model_class = AutoencoderKLWan model_class = AutoencoderKLWan
ckpt_path = ( ckpt_path = (
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors" "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
) )
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
subfolder = "vae"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest
import torch import torch
...@@ -23,72 +21,26 @@ from diffusers import ( ...@@ -23,72 +21,26 @@ from diffusers import (
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_big_accelerator, require_big_accelerator,
require_torch_accelerator,
torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@require_torch_accelerator class TestWanTransformer3DModelText2VideoSingleFile(SingleFileModelTesterMixin):
class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
model_class = WanTransformer3DModel model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors" ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
subfolder = "transformer"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
@require_big_accelerator @require_big_accelerator
@require_torch_accelerator class TestWanTransformer3DModelImage2VideoSingleFile(SingleFileModelTesterMixin):
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
model_class = WanTransformer3DModel model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors" ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
torch_dtype = torch.float8_e4m3fn torch_dtype = torch.float8_e4m3fn
subfolder = "transformer"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
import gc
import unittest
from diffusers import ( from diffusers import (
SanaTransformer2DModel, SanaTransformer2DModel,
) )
from ..testing_utils import ( from ..testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
require_torch_accelerator,
torch_device,
) )
from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism() enable_full_determinism()
@require_torch_accelerator class TestSanaTransformer2DModelSingleFile(SingleFileModelTesterMixin):
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
model_class = SanaTransformer2DModel model_class = SanaTransformer2DModel
ckpt_path = ( ckpt_path = (
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
...@@ -27,34 +21,4 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase): ...@@ -27,34 +21,4 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
] ]
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
subfolder = "transformer"
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between single file loading and pretrained loading"
)
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)
import gc import gc
import tempfile import tempfile
import unittest
import torch import torch
...@@ -29,7 +28,7 @@ enable_full_determinism() ...@@ -29,7 +28,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
...@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD ...@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
) )
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import tempfile import tempfile
import unittest
import pytest
import torch import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
...@@ -29,19 +29,17 @@ enable_full_determinism() ...@@ -29,19 +29,17 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionControlNetInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt" ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting" repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -115,7 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC ...@@ -115,7 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
@unittest.skip("runwayml original config repo does not exist") @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config(self): def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16") controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
...@@ -125,7 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC ...@@ -125,7 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
@unittest.skip("runwayml original config repo does not exist") @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config_local_files_only(self): def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained( controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
......
import gc import gc
import tempfile import tempfile
import unittest
import torch import torch
...@@ -29,7 +28,7 @@ enable_full_determinism() ...@@ -29,7 +28,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
...@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD ...@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
) )
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import unittest
import torch import torch
...@@ -23,7 +22,7 @@ enable_full_determinism() ...@@ -23,7 +22,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionImg2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
...@@ -33,13 +32,11 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin ...@@ -33,13 +32,11 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
) )
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -66,19 +63,17 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin ...@@ -66,19 +63,17 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusion21Img2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1" repo_id = "stabilityai/stable-diffusion-2-1"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import unittest
import pytest
import torch import torch
from diffusers import ( from diffusers import (
...@@ -23,19 +23,17 @@ enable_full_determinism() ...@@ -23,19 +23,17 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt" ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "botp/stable-diffusion-v1-5-inpainting" repo_id = "botp/stable-diffusion-v1-5-inpainting"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -70,18 +68,18 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin ...@@ -70,18 +68,18 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
assert pipe.unet.config.in_channels == 4 assert pipe.unet.config.in_channels == 4
@unittest.skip("runwayml original config has been removed") @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config(self): def test_single_file_components_with_original_config(self):
return return
@unittest.skip("runwayml original config has been removed") @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config_local_files_only(self): def test_single_file_components_with_original_config_local_files_only(self):
return return
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusion21InpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors" "https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
...@@ -89,13 +87,11 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS ...@@ -89,13 +87,11 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml" original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
repo_id = "stabilityai/stable-diffusion-2-inpainting" repo_id = "stabilityai/stable-diffusion-2-inpainting"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import tempfile import tempfile
import unittest
import torch import torch
...@@ -28,7 +27,7 @@ enable_full_determinism() ...@@ -28,7 +27,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline pipeline_class = StableDiffusionPipeline
ckpt_path = ( ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
...@@ -38,13 +37,11 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile ...@@ -38,13 +37,11 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
) )
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5" repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -90,19 +87,17 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile ...@@ -90,19 +87,17 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
@slow @slow
class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusion21PipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1" repo_id = "stabilityai/stable-diffusion-2-1"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
...@@ -125,7 +120,7 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi ...@@ -125,7 +120,7 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
@nightly @nightly
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionInstructPix2PixPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInstructPix2PixPipeline pipeline_class = StableDiffusionInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors" ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
original_config = ( original_config = (
...@@ -134,13 +129,11 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas ...@@ -134,13 +129,11 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas
repo_id = "timbrooks/instruct-pix2pix" repo_id = "timbrooks/instruct-pix2pix"
single_file_kwargs = {"extract_ema": True} single_file_kwargs = {"extract_ema": True}
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import unittest
import pytest import pytest
import torch import torch
...@@ -25,19 +24,17 @@ enable_full_determinism() ...@@ -25,19 +24,17 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): class TestStableDiffusionUpscalePipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
repo_id = "stabilityai/stable-diffusion-x4-upscaler" repo_id = "stabilityai/stable-diffusion-x4-upscaler"
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import tempfile import tempfile
import unittest
import torch import torch
...@@ -32,7 +31,7 @@ enable_full_determinism() ...@@ -32,7 +31,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin): class TestStableDiffusionXLAdapterPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0" repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
...@@ -40,13 +39,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX ...@@ -40,13 +39,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
) )
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
import gc import gc
import tempfile import tempfile
import unittest
import torch import torch
...@@ -28,7 +27,7 @@ enable_full_determinism() ...@@ -28,7 +27,7 @@ enable_full_determinism()
@slow @slow
@require_torch_accelerator @require_torch_accelerator
class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin): class TestStableDiffusionXLControlNetPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0" repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
...@@ -36,13 +35,11 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, ...@@ -36,13 +35,11 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
) )
def setUp(self): def setup_method(self):
super().setUp()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment