Unverified commit c8d86e9f authored by Abhipsha Das, committed by GitHub

Remove code snippets containing `is_safetensors_available()` (#4521)



* [WIP] Remove code snippets containing `is_safetensors_available()`

* Modifying `import_utils.py`

* update pipeline tests for safetensor default

* fix test related to cached requests

* address import nits

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent b28cd3fb
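For context, the heart of this change is a single pattern applied across the codebase: `safetensors` is now a hard dependency of `diffusers`, so every `is_safetensors_available()` probe is deleted and the `use_safetensors` kwarg defaults to preferring safetensors weights, with a pickle fallback only when the caller leaves it unset. A minimal sketch of the before/after kwarg handling (illustrative, not the verbatim library code):

    use_safetensors = kwargs.pop("use_safetensors", None)

    allow_pickle = False
    if use_safetensors is None:
        # Before: use_safetensors = is_safetensors_available()
        # After: safetensors is always importable, so default to it and
        # allow falling back to .bin weights only for this "unset" case.
        use_safetensors = True
        allow_pickle = True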
@@ -2,14 +2,8 @@ import glob
 import os
 from typing import Dict, List, Union

+import safetensors.torch
 import torch

-from diffusers.utils import is_safetensors_available
-
-if is_safetensors_available():
-    import safetensors.torch
-
 from huggingface_hub import snapshot_download

 from diffusers import DiffusionPipeline, __version__
@@ -229,14 +223,14 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             update_theta_0 = getattr(module, "load_state_dict")
             theta_1 = (
                 safetensors.torch.load_file(checkpoint_path_1)
-                if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
+                if (checkpoint_path_1.endswith(".safetensors"))
                 else torch.load(checkpoint_path_1, map_location="cpu")
             )
             theta_2 = None

             if checkpoint_path_2:
                 theta_2 = (
                     safetensors.torch.load_file(checkpoint_path_2)
-                    if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
+                    if (checkpoint_path_2.endswith(".safetensors"))
                     else torch.load(checkpoint_path_2, map_location="cpu")
                 )
...
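A self-contained sketch of the extension-based dispatch the merger hunk above now performs unconditionally (the file names are hypothetical; safetensors is assumed installed, as it now always is):

    import safetensors.torch
    import torch

    def load_checkpoint(path: str) -> dict:
        # .safetensors files go through safetensors; everything else falls
        # back to torch.load on CPU, mirroring the pipeline code above.
        if path.endswith(".safetensors"):
            return safetensors.torch.load_file(path)
        return torch.load(path, map_location="cpu")

    theta_1 = load_checkpoint("model_a.safetensors")  # hypothetical path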
@@ -38,7 +38,7 @@ from diffusers import (
     PNDMScheduler,
     UNet2DConditionModel,
 )
-from diffusers.utils import is_omegaconf_available, is_safetensors_available
+from diffusers.utils import is_omegaconf_available
 from diffusers.utils.import_utils import BACKENDS_MAPPING
@@ -824,9 +824,6 @@ def load_pipeline_from_original_audioldm_ckpt(
     from omegaconf import OmegaConf

     if from_safetensors:
-        if not is_safetensors_available():
-            raise ValueError(BACKENDS_MAPPING["safetensors"][1])
-
         from safetensors import safe_open

         checkpoint = {}
...
 import argparse

-from diffusers.utils import is_safetensors_available
-
-if is_safetensors_available():
-    import safetensors.torch
-else:
-    raise ImportError("Please install `safetensors`.")
+import safetensors.torch

 from diffusers import AutoencoderTiny
...
@@ -27,7 +27,7 @@ import torch
 from huggingface_hub import hf_hub_download
 from packaging import version

-from ..utils import is_safetensors_available, logging
+from ..utils import logging
 from . import BaseDiffusersCLICommand
@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
         self.local_ckpt_dir = f"/tmp/{ckpt_id}"
         self.fp16 = fp16

-        if is_safetensors_available():
-            self.use_safetensors = use_safetensors
-        else:
-            raise ImportError(
-                "When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
-            )
+        self.use_safetensors = use_safetensors

         if not self.use_safetensors and not self.fp16:
             raise NotImplementedError(
...
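Roughly what the `fp16_safetensors` command automates, as a hedged sketch rather than the command's verbatim code (file names are hypothetical, and `save_file` requires contiguous tensors that do not share storage):

    import safetensors.torch
    import torch

    # Load pickle weights, cast to fp16, and re-serialize as safetensors.
    state_dict = torch.load("diffusion_pytorch_model.bin", map_location="cpu")
    state_dict = {k: v.half().contiguous() for k, v in state_dict.items()}
    safetensors.torch.save_file(state_dict, "diffusion_pytorch_model.fp16.safetensors")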
@@ -22,6 +22,7 @@ from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

 import requests
+import safetensors
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
@@ -34,16 +35,12 @@ from .utils import (
     deprecate,
     is_accelerate_available,
     is_omegaconf_available,
-    is_safetensors_available,
     is_transformers_available,
     logging,
 )
 from .utils.import_utils import BACKENDS_MAPPING

-if is_safetensors_available():
-    import safetensors
-
 if is_transformers_available():
     from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
@@ -261,14 +258,10 @@ class UNet2DConditionLoadersMixin:
         network_alphas = kwargs.pop("network_alphas", None)
         is_network_alphas_none = network_alphas is None

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -757,14 +750,9 @@ class TextualInversionLoaderMixin:
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -1014,14 +1002,9 @@ class LoraLoaderMixin:
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -1853,7 +1836,7 @@ class FromSingleFileMixin:
         torch_dtype = kwargs.pop("torch_dtype", None)
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         pipeline_name = cls.__name__
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
@@ -2050,7 +2033,7 @@ class FromOriginalVAEMixin:
         torch_dtype = kwargs.pop("torch_dtype", None)
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
@@ -2223,7 +2206,7 @@ class FromOriginalControlnetMixin:
         torch_dtype = kwargs.pop("torch_dtype", None)
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
...
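All the loader mixins above now share the same contract. A hedged usage sketch (the LoRA repo id is hypothetical, the repos must actually ship the requested format, and kwarg forwarding is assumed to match this version of the API):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # use_safetensors left unset: .safetensors is preferred, .bin is the fallback.
    pipe.load_lora_weights("some-user/some-lora")

    # use_safetensors=True: no pickle fallback; loading fails if only .bin exists.
    pipe.load_lora_weights("some-user/some-lora", use_safetensors=True)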
@@ -21,6 +21,7 @@ import re
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union

+import safetensors
 import torch
 from torch import Tensor, device, nn
@@ -36,7 +37,6 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_safetensors_available,
     is_torch_version,
     logging,
 )
@@ -56,9 +56,6 @@ if is_accelerate_available():
     from accelerate.utils import set_module_tensor_to_device
     from accelerate.utils.versions import is_torch_version

-if is_safetensors_available():
-    import safetensors
-
 def get_parameter_device(parameter: torch.nn.Module):
     try:
@@ -296,9 +293,6 @@ class ModelMixin(torch.nn.Module):
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
-        if safe_serialization and not is_safetensors_available():
-            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
-
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -454,14 +448,9 @@ class ModelMixin(torch.nn.Module):
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         if low_cpu_mem_usage and not is_accelerate_available():
...
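One user-visible consequence of the `ModelMixin` changes: `safe_serialization=True` can no longer raise an ImportError, and `from_pretrained` defaults to safetensors weights when the repo ships them. A hedged sketch (output directory is hypothetical):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
    )
    # Writes diffusion_pytorch_model.safetensors instead of the pickle .bin.
    unet.save_pretrained("./unet-out", safe_serialization=True)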
@@ -52,7 +52,6 @@ from ..utils import (
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,
-    is_safetensors_available,
     is_torch_version,
     is_transformers_available,
     logging,
@@ -899,7 +898,7 @@ class DiffusionPipeline(ConfigMixin):
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

         # 1. Download the checkpoints and configs
@@ -1311,14 +1310,9 @@ class DiffusionPipeline(ConfigMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

         allow_patterns = None
...
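For `DiffusionPipeline` the flipped default also shapes what gets downloaded: with the kwarg unset, safetensors variants of every component are fetched when available. A hedged sketch (tiny test repos used for illustration):

    from diffusers import DiffusionPipeline

    # Unset: prefer .safetensors for each component, fall back to .bin.
    pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

    # Opt out explicitly to fetch only the pickle weights.
    pipe = DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", use_safetensors=False
    )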
@@ -50,7 +50,7 @@ from ...schedulers import (
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
@@ -1225,9 +1225,6 @@ def download_from_original_stable_diffusion_ckpt(
     from omegaconf import OmegaConf

     if from_safetensors:
-        if not is_safetensors_available():
-            raise ValueError(BACKENDS_MAPPING["safetensors"][1])
-
         from safetensors.torch import load_file as safe_load

         checkpoint = safe_load(checkpoint_path, device="cpu")
@@ -1650,9 +1647,6 @@ def download_controlnet_from_original_ckpt(
     from omegaconf import OmegaConf

     if from_safetensors:
-        if not is_safetensors_available():
-            raise ValueError(BACKENDS_MAPPING["safetensors"][1])
-
         from safetensors import safe_open

         checkpoint = {}
...
@@ -64,7 +64,6 @@ from .import_utils import (
     is_note_seq_available,
     is_omegaconf_available,
     is_onnx_available,
-    is_safetensors_available,
     is_scipy_available,
     is_tensorboard_available,
     is_tf_available,
...
@@ -306,10 +306,6 @@ def is_torch_available():
     return _torch_available

-def is_safetensors_available():
-    return _safetensors_available
-
-
 def is_tf_available():
     return _tf_available
...
@@ -60,10 +60,6 @@ class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()

-        import diffusers
-
-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_accelerate_loading_error_message(self):
         with self.assertRaises(ValueError) as error_context:
             UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
@@ -100,14 +96,15 @@ class ModelUtilsTest(unittest.TestCase):
         if torch_device == "mps":
             return

-        import diffusers
-
-        diffusers.utils.import_utils._safetensors_available = False
+        use_safetensors = False

         with tempfile.TemporaryDirectory() as tmpdirname:
             with requests_mock.mock(real_http=True) as m:
                 UNet2DConditionModel.from_pretrained(
-                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
+                    "hf-internal-testing/tiny-stable-diffusion-torch",
+                    subfolder="unet",
+                    cache_dir=tmpdirname,
+                    use_safetensors=use_safetensors,
                 )

         download_requests = [r.method for r in m.request_history]
@@ -116,7 +113,10 @@ class ModelUtilsTest(unittest.TestCase):
             with requests_mock.mock(real_http=True) as m:
                 UNet2DConditionModel.from_pretrained(
-                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
+                    "hf-internal-testing/tiny-stable-diffusion-torch",
+                    subfolder="unet",
+                    cache_dir=tmpdirname,
+                    use_safetensors=use_safetensors,
                 )

         cache_requests = [r.method for r in m.request_history]
@@ -124,8 +124,6 @@ class ModelUtilsTest(unittest.TestCase):
             "HEAD" == cache_requests[0] and len(cache_requests) == 1
         ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_weight_overwrite(self):
         with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
             UNet2DConditionModel.from_pretrained(
...
@@ -472,15 +472,13 @@ class DownloadTests(unittest.TestCase):
                     assert False, "Parameters not the same!"

     def test_download_from_variant_folder(self):
-        for safe_avail in [False, True]:
-            import diffusers
-
-            diffusers.utils.import_utils._safetensors_available = safe_avail
-
-            other_format = ".bin" if safe_avail else ".safetensors"
+        for use_safetensors in [False, True]:
+            other_format = ".bin" if use_safetensors else ".safetensors"
             with tempfile.TemporaryDirectory() as tmpdirname:
                 tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
+                    "hf-internal-testing/stable-diffusion-all-variants",
+                    cache_dir=tmpdirname,
+                    use_safetensors=use_safetensors,
                 )
                 all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                 files = [item for sublist in all_root_files for item in sublist]
@@ -492,21 +490,18 @@ class DownloadTests(unittest.TestCase):
                 # no variants
                 assert not any(len(f.split(".")) == 3 for f in files)

-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_download_variant_all(self):
-        for safe_avail in [False, True]:
-            import diffusers
-
-            diffusers.utils.import_utils._safetensors_available = safe_avail
-
-            other_format = ".bin" if safe_avail else ".safetensors"
-            this_format = ".safetensors" if safe_avail else ".bin"
+        for use_safetensors in [False, True]:
+            other_format = ".bin" if use_safetensors else ".safetensors"
+            this_format = ".safetensors" if use_safetensors else ".bin"
             variant = "fp16"

             with tempfile.TemporaryDirectory() as tmpdirname:
                 tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
+                    "hf-internal-testing/stable-diffusion-all-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
                 )
                 all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                 files = [item for sublist in all_root_files for item in sublist]
@@ -520,21 +515,18 @@ class DownloadTests(unittest.TestCase):
                 assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
                 assert not any(f.endswith(other_format) for f in files)

-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_download_variant_partly(self):
-        for safe_avail in [False, True]:
-            import diffusers
-
-            diffusers.utils.import_utils._safetensors_available = safe_avail
-
-            other_format = ".bin" if safe_avail else ".safetensors"
-            this_format = ".safetensors" if safe_avail else ".bin"
+        for use_safetensors in [False, True]:
+            other_format = ".bin" if use_safetensors else ".safetensors"
+            this_format = ".safetensors" if use_safetensors else ".bin"
             variant = "no_ema"

             with tempfile.TemporaryDirectory() as tmpdirname:
                 tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
+                    "hf-internal-testing/stable-diffusion-all-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
                 )
                 all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                 files = [item for sublist in all_root_files for item in sublist]
@@ -551,13 +543,8 @@ class DownloadTests(unittest.TestCase):
                 assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
                 assert not any(f.endswith(other_format) for f in files)

-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_download_broken_variant(self):
-        for safe_avail in [False, True]:
-            import diffusers
-
-            diffusers.utils.import_utils._safetensors_available = safe_avail
-
+        for use_safetensors in [False, True]:
             # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
             for variant in [None, "no_ema"]:
                 with self.assertRaises(OSError) as error_context:
@@ -566,6 +553,7 @@ class DownloadTests(unittest.TestCase):
                             "hf-internal-testing/stable-diffusion-broken-variants",
                             cache_dir=tmpdirname,
                             variant=variant,
+                            use_safetensors=use_safetensors,
                         )

                 assert "Error no file name" in str(error_context.exception)
@@ -573,7 +561,10 @@ class DownloadTests(unittest.TestCase):
             # text encoder has fp16 variants so we can load it
             with tempfile.TemporaryDirectory() as tmpdirname:
                 tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    use_safetensors=use_safetensors,
+                    cache_dir=tmpdirname,
+                    variant="fp16",
                 )

                 all_root_files = [t[-1] for t in os.walk(tmpdirname)]
@@ -584,8 +575,6 @@ class DownloadTests(unittest.TestCase):
                 assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                 # only unet has "no_ema" variant

-            diffusers.utils.import_utils._safetensors_available = True
-
     def test_local_save_load_index(self):
         prompt = "hello"
         for variant in [None, "fp16"]:
@@ -961,10 +950,6 @@ class PipelineFastTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()

-        import diffusers
-
-        diffusers.utils.import_utils._safetensors_available = True
-
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -1319,14 +1304,13 @@ class PipelineFastTests(unittest.TestCase):
         assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))

     def test_no_safetensors_download_when_doing_pytorch(self):
-        # mock diffusers safetensors not available
-        import diffusers
-
-        diffusers.utils.import_utils._safetensors_available = False
+        use_safetensors = False

         with tempfile.TemporaryDirectory() as tmpdirname:
             _ = StableDiffusionPipeline.from_pretrained(
-                "hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
+                "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
+                cache_dir=tmpdirname,
+                use_safetensors=use_safetensors,
             )

         path = os.path.join(
@@ -1341,8 +1325,6 @@ class PipelineFastTests(unittest.TestCase):
             # pytorch does
             assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))

-        diffusers.utils.import_utils._safetensors_available = True
-
     def test_optional_components(self):
         unet = self.dummy_cond_unet()
         pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
...
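The test changes above all make the same substitution: instead of monkeypatching the module global `diffusers.utils.import_utils._safetensors_available` (and remembering to reset it in tearDown), each test passes an explicit `use_safetensors` argument, so the format under test is local to the call. A minimal sketch of the new pattern, against the same test repo used above:

    import os
    import tempfile

    from diffusers import StableDiffusionPipeline

    def downloaded_files(use_safetensors: bool) -> list:
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = StableDiffusionPipeline.download(
                "hf-internal-testing/stable-diffusion-all-variants",
                cache_dir=tmpdirname,
                use_safetensors=use_safetensors,
            )
            # Flatten the tree so file-extension assertions are easy to write.
            return [f for _, _, files in os.walk(path) for f in files]

    # With safetensors requested, no pickle weights should be fetched.
    assert not any(f.endswith(".bin") for f in downloaded_files(True))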