Unverified Commit 306a7bd0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[ONNX] Don't download ONNX model by default (#4338)

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* fix more

* finish

* finish

* finish
parent c7250f2b
...@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
_optional_components = [] _optional_components = []
_exclude_from_cpu_offload = [] _exclude_from_cpu_offload = []
_load_connected_pipes = False _load_connected_pipes = False
_is_onnx = False
def register_modules(self, **kwargs): def register_modules(self, **kwargs):
# import it here to avoid circular import # import it here to avoid circular import
...@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin): ...@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded. weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example class). The overwritten components are passed directly to the pipelines `__init__` method. See example
...@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*): variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`. loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `False`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
Returns: Returns:
`os.PathLike`: `os.PathLike`:
...@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
custom_revision = kwargs.pop("custom_revision", None) custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if use_safetensors and not is_safetensors_available(): if use_safetensors and not is_safetensors_available():
...@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name, use_auth_token, variant, revision, model_filenames pretrained_model_name, use_auth_token, variant, revision, model_filenames
) )
model_folder_names = {os.path.split(f)[0] for f in model_filenames} model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
# all filenames compatible with variant will be added # all filenames compatible with variant will be added
allow_patterns = list(model_filenames) allow_patterns = list(model_filenames)
...@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
): ):
ignore_patterns = ["*.bin", "*.msgpack"] ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if ( if (
...@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
else: else:
ignore_patterns = ["*.safetensors", "*.msgpack"] ignore_patterns = ["*.safetensors", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
......
...@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__( def __init__(
self, self,
......
...@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__( def __init__(
self, self,
......
...@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__( def __init__(
self, self,
......
...@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
vae_encoder: OnnxRuntimeModel vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel
......
...@@ -46,6 +46,8 @@ def preprocess(image): ...@@ -46,6 +46,8 @@ def preprocess(image):
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
_is_onnx = True
def __init__( def __init__(
self, self,
vae: OnnxRuntimeModel, vae: OnnxRuntimeModel,
......
...@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase): ...@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
assert len([f for f in files if ".bin" in f]) == 8 assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files) assert not any(".safetensors" in f for f in files)
def test_download_no_openvino_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-open-vino",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no openvino weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any("openvino_" in f for f in files)
def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no onnx weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# if `use_onnx` is specified make sure weights are downloaded
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)
def test_download_no_safety_checker(self): def test_download_no_safety_checker(self):
prompt = "hello" prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment