Unverified Commit e5902ed1 authored by Patrick von Platen, committed by GitHub

[Download] Smart downloading (#512)

* [Download] Smart downloading

* add test

* finish test

* update

* make style
parent a54cfe68
@@ -38,8 +38,6 @@ logger = logging.get_logger(__name__)
 class OnnxRuntimeModel:
-    base_model_prefix = "onnx_model"
-
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
...
@@ -30,7 +30,10 @@ from PIL import Image
 from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
-from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+from .modeling_utils import WEIGHTS_NAME
+from .onnx_utils import ONNX_WEIGHTS_NAME
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -285,6 +288,21 @@ class DiffusionPipeline(ConfigMixin):
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.get_config_dict(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+            # download all allow_patterns
             cached_folder = snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
@@ -293,6 +311,7 @@ class DiffusionPipeline(ConfigMixin):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
+                allow_patterns=allow_patterns,
             )
         else:
             cached_folder = pretrained_model_name_or_path
...
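For illustration only, here is a minimal, self-contained sketch (not part of the commit) of what this allow-pattern logic produces. The model_index.json contents are made up, and the constant values are inlined under the assumption that they match the diffusers names imported above:

import os

# Hypothetical contents of a pipeline's model_index.json (`cls.config_name`):
# keys starting with "_" are metadata, every other key names a component sub-folder.
config_dict = {
    "_class_name": "DummyPipeline",
    "_diffusers_version": "0.3.0",
    "unet": ["diffusers", "UNet2DModel"],
    "scheduler": ["diffusers", "DDPMScheduler"],
}

# Download each component sub-folder in full ...
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]

# ... plus the top-level `diffusers` file names (values assumed, not taken from the diff).
allow_patterns += [
    "diffusion_pytorch_model.bin",  # WEIGHTS_NAME
    "scheduler_config.json",        # SCHEDULER_CONFIG_NAME
    "config.json",                  # CONFIG_NAME
    "model.onnx",                   # ONNX_WEIGHTS_NAME
    "model_index.json",             # cls.config_name
]

print(allow_patterns)
# ['unet/*', 'scheduler/*', 'diffusion_pytorch_model.bin', 'scheduler_config.json',
#  'config.json', 'model.onnx', 'model_index.json']

Everything outside these patterns is skipped by huggingface_hub's snapshot_download, so a repo can carry arbitrary extra files without bloating pipeline downloads.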
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import os
 import random
 import tempfile
 import unittest
@@ -45,8 +46,11 @@ from diffusers import (
     UNet2DModel,
     VQModel,
 )
+from diffusers.modeling_utils import WEIGHTS_NAME
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import CONFIG_NAME
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -707,6 +711,27 @@ class PipelineTesterMixin(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_smart_download(self):
+        model_id = "hf-internal-testing/unet-pipeline-dummy"
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
+            local_repo_name = "--".join(["models"] + model_id.split("/"))
+            snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
+            snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
+
+            # inspect all downloaded files to make sure that everything is included
+            assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
+            assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
+
+            # let's make sure the super large numpy file:
+            # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
+            # is not downloaded, but all the expected ones
+            assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
+
     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
...
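As a usage sketch of what the new test exercises (the repo id is taken from the test; the cache layout below follows the huggingface_hub convention of models--<org>--<name>/snapshots/<revision>):

import os
import tempfile

from diffusers import DiffusionPipeline

model_id = "hf-internal-testing/unet-pipeline-dummy"
with tempfile.TemporaryDirectory() as tmpdirname:
    # Only files matching the pipeline's allow_patterns are fetched, so the
    # repo's unrelated big_array.npy never lands in the local cache.
    DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname)
    local_repo_name = "--".join(["models"] + model_id.split("/"))
    snapshot_root = os.path.join(tmpdirname, local_repo_name, "snapshots")
    print(os.listdir(snapshot_root))  # one directory per cached revision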