Unverified Commit 75ada250 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Harmonize HF environment variables + deprecate use_auth_token (#6066)

* Harmonize HF environment variables + deprecate use_auth_token

* fix import

* fix
parent 2243a594
...@@ -174,10 +174,4 @@ Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] functi ...@@ -174,10 +174,4 @@ Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] functi
controlnet.push_to_hub("my-controlnet-model-private", private=True) controlnet.push_to_hub("my-controlnet-model-private", private=True)
``` ```
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for.` Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
\ No newline at end of file
To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
```py
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
```
...@@ -512,7 +512,6 @@ device = torch.device('cpu' if not has_cuda else 'cuda') ...@@ -512,7 +512,6 @@ device = torch.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained( pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
safety_checker=None, safety_checker=None,
use_auth_token=True,
custom_pipeline="imagic_stable_diffusion", custom_pipeline="imagic_stable_diffusion",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
).to(device) ).to(device)
...@@ -552,7 +551,6 @@ device = th.device('cpu' if not has_cuda else 'cuda') ...@@ -552,7 +551,6 @@ device = th.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained( pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="seed_resize_stable_diffusion" custom_pipeline="seed_resize_stable_diffusion"
).to(device) ).to(device)
...@@ -588,7 +586,6 @@ generator = th.Generator("cuda").manual_seed(0) ...@@ -588,7 +586,6 @@ generator = th.Generator("cuda").manual_seed(0)
pipe = DiffusionPipeline.from_pretrained( pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/" custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device) ).to(device)
...@@ -607,7 +604,6 @@ image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=heigh ...@@ -607,7 +604,6 @@ image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=heigh
pipe_compare = DiffusionPipeline.from_pretrained( pipe_compare = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/" custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device) ).to(device)
......
...@@ -5,10 +5,11 @@ from typing import Dict, List, Union ...@@ -5,10 +5,11 @@ from typing import Dict, List, Union
import safetensors.torch import safetensors.torch
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from diffusers import DiffusionPipeline, __version__ from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME from diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
class CheckpointMergerPipeline(DiffusionPipeline): class CheckpointMergerPipeline(DiffusionPipeline):
...@@ -57,6 +58,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -57,6 +58,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
return (temp_dict, meta_keys) return (temp_dict, meta_keys)
@torch.no_grad() @torch.no_grad()
@validate_hf_hub_args
def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs): def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
""" """
Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
...@@ -69,7 +71,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -69,7 +71,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
**kwargs: **kwargs:
Supports all the default DiffusionPipeline.get_config_dict kwargs viz.. Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map. cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
...@@ -81,12 +83,12 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -81,12 +83,12 @@ class CheckpointMergerPipeline(DiffusionPipeline):
""" """
# Default kwargs from DiffusionPipeline # Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
...@@ -123,7 +125,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -123,7 +125,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
) )
config_dicts.append(config_dict) config_dicts.append(config_dict)
...@@ -159,7 +161,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -159,7 +161,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
user_agent=user_agent, user_agent=user_agent,
......
...@@ -28,6 +28,7 @@ import PIL.Image ...@@ -28,6 +28,7 @@ import PIL.Image
import tensorrt as trt import tensorrt as trt
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference from onnx import shape_inference
from polygraphy import cuda from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path from polygraphy.backend.common import bytes_from_path
...@@ -50,7 +51,7 @@ from diffusers.pipelines.stable_diffusion import ( ...@@ -50,7 +51,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker, StableDiffusionSafetyChecker,
) )
from diffusers.schedulers import DDIMScheduler from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging from diffusers.utils import logging
""" """
...@@ -778,12 +779,13 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -778,12 +779,13 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod @classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
cls.cached_folder = ( cls.cached_folder = (
...@@ -795,7 +797,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -795,7 +797,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
) )
) )
......
...@@ -28,6 +28,7 @@ import PIL.Image ...@@ -28,6 +28,7 @@ import PIL.Image
import tensorrt as trt import tensorrt as trt
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference from onnx import shape_inference
from polygraphy import cuda from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path from polygraphy.backend.common import bytes_from_path
...@@ -51,7 +52,7 @@ from diffusers.pipelines.stable_diffusion import ( ...@@ -51,7 +52,7 @@ from diffusers.pipelines.stable_diffusion import (
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.schedulers import DDIMScheduler from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging from diffusers.utils import logging
""" """
...@@ -779,12 +780,13 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -779,12 +780,13 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod @classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
cls.cached_folder = ( cls.cached_folder = (
...@@ -796,7 +798,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): ...@@ -796,7 +798,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
) )
) )
......
...@@ -27,6 +27,7 @@ import onnx_graphsurgeon as gs ...@@ -27,6 +27,7 @@ import onnx_graphsurgeon as gs
import tensorrt as trt import tensorrt as trt
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference from onnx import shape_inference
from polygraphy import cuda from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path from polygraphy.backend.common import bytes_from_path
...@@ -49,7 +50,7 @@ from diffusers.pipelines.stable_diffusion import ( ...@@ -49,7 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker, StableDiffusionSafetyChecker,
) )
from diffusers.schedulers import DDIMScheduler from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging from diffusers.utils import logging
""" """
...@@ -691,12 +692,13 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): ...@@ -691,12 +692,13 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.models["vae"] = make_VAE(self.vae, **models_args) self.models["vae"] = make_VAE(self.vae, **models_args)
@classmethod @classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
cls.cached_folder = ( cls.cached_folder = (
...@@ -708,7 +710,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): ...@@ -708,7 +710,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
) )
) )
......
...@@ -423,7 +423,7 @@ def import_model_class_from_model_name_or_path( ...@@ -423,7 +423,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True pretrained_model_name_or_path, subfolder=subfolder, revision=revision
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
......
...@@ -397,7 +397,7 @@ def import_model_class_from_model_name_or_path( ...@@ -397,7 +397,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True pretrained_model_name_or_path, subfolder=subfolder, revision=revision
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
......
...@@ -400,7 +400,7 @@ def import_model_class_from_model_name_or_path( ...@@ -400,7 +400,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True pretrained_model_name_or_path, subfolder=subfolder, revision=revision
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
......
...@@ -419,7 +419,7 @@ def import_model_class_from_model_name_or_path( ...@@ -419,7 +419,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True pretrained_model_name_or_path, subfolder=subfolder, revision=revision
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
......
...@@ -420,7 +420,7 @@ def import_model_class_from_model_name_or_path( ...@@ -420,7 +420,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
text_encoder_config = PretrainedConfig.from_pretrained( text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True pretrained_model_name_or_path, subfolder=subfolder, revision=revision
) )
model_class = text_encoder_config.architectures[0] model_class = text_encoder_config.architectures[0]
...@@ -975,7 +975,7 @@ def main(args): ...@@ -975,7 +975,7 @@ def main(args):
revision=args.revision, revision=args.revision,
) )
unet = UNet2DConditionModel.from_pretrained( unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
) )
if args.controlnet_model_name_or_path: if args.controlnet_model_name_or_path:
......
...@@ -19,6 +19,7 @@ Usage example: ...@@ -19,6 +19,7 @@ Usage example:
import glob import glob
import json import json
import warnings
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from importlib import import_module from importlib import import_module
...@@ -32,12 +33,12 @@ from . import BaseDiffusersCLICommand ...@@ -32,12 +33,12 @@ from . import BaseDiffusersCLICommand
def conversion_command_factory(args: Namespace): def conversion_command_factory(args: Namespace):
return FP16SafetensorsCommand( if args.use_auth_token:
args.ckpt_id, warnings.warn(
args.fp16, "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
args.use_safetensors, " handled automatically if user is logged in."
args.use_auth_token, )
) return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
class FP16SafetensorsCommand(BaseDiffusersCLICommand): class FP16SafetensorsCommand(BaseDiffusersCLICommand):
...@@ -62,7 +63,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand): ...@@ -62,7 +63,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
) )
conversion_parser.set_defaults(func=conversion_command_factory) conversion_parser.set_defaults(func=conversion_command_factory)
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool): def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
self.logger = logging.get_logger("diffusers-cli/fp16_safetensors") self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
self.ckpt_id = ckpt_id self.ckpt_id = ckpt_id
self.local_ckpt_dir = f"/tmp/{ckpt_id}" self.local_ckpt_dir = f"/tmp/{ckpt_id}"
...@@ -75,8 +76,6 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand): ...@@ -75,8 +76,6 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
"When `use_safetensors` and `fp16` both are False, then this command is of no use." "When `use_safetensors` and `fp16` both are False, then this command is of no use."
) )
self.use_auth_token = use_auth_token
def run(self): def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"): if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
raise ImportError( raise ImportError(
...@@ -87,7 +86,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand): ...@@ -87,7 +86,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
from huggingface_hub import create_commit from huggingface_hub import create_commit
from huggingface_hub._commit_api import CommitOperationAdd from huggingface_hub._commit_api import CommitOperationAdd
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token) model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
with open(model_index, "r") as f: with open(model_index, "r") as f:
pipeline_class_name = json.load(f)["_class_name"] pipeline_class_name = json.load(f)["_class_name"]
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name) pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
...@@ -96,7 +95,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand): ...@@ -96,7 +95,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
# Load the appropriate pipeline. We could have use `DiffusionPipeline` # Load the appropriate pipeline. We could have use `DiffusionPipeline`
# here, but just to avoid any rough edge cases. # here, but just to avoid any rough edge cases.
pipeline = pipeline_class.from_pretrained( pipeline = pipeline_class.from_pretrained(
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
) )
pipeline.save_pretrained( pipeline.save_pretrained(
self.local_ckpt_dir, self.local_ckpt_dir,
......
...@@ -27,12 +27,16 @@ from typing import Any, Dict, Tuple, Union ...@@ -27,12 +27,16 @@ from typing import Any, Dict, Tuple, Union
import numpy as np import numpy as np
from huggingface_hub import create_repo, hf_hub_download from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from requests import HTTPError from requests import HTTPError
from . import __version__ from . import __version__
from .utils import ( from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject, DummyObject,
deprecate, deprecate,
...@@ -275,6 +279,7 @@ class ConfigMixin: ...@@ -275,6 +279,7 @@ class ConfigMixin:
return cls.load_config(*args, **kwargs) return cls.load_config(*args, **kwargs)
@classmethod @classmethod
@validate_hf_hub_args
def load_config( def load_config(
cls, cls,
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
...@@ -311,7 +316,7 @@ class ConfigMixin: ...@@ -311,7 +316,7 @@ class ConfigMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -329,11 +334,11 @@ class ConfigMixin: ...@@ -329,11 +334,11 @@ class ConfigMixin:
A dictionary of all the parameters stored in a JSON configuration file. A dictionary of all the parameters stored in a JSON configuration file.
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None) _ = kwargs.pop("mirror", None)
...@@ -376,7 +381,7 @@ class ConfigMixin: ...@@ -376,7 +381,7 @@ class ConfigMixin:
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision,
...@@ -385,8 +390,7 @@ class ConfigMixin: ...@@ -385,8 +390,7 @@ class ConfigMixin:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" " token having permission to this repo with `token` or log in with `huggingface-cli login`."
" login`."
) )
except RevisionNotFoundError: except RevisionNotFoundError:
raise EnvironmentError( raise EnvironmentError(
......
...@@ -15,11 +15,10 @@ import os ...@@ -15,11 +15,10 @@ import os
from typing import Dict, Union from typing import Dict, Union
import torch import torch
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open from safetensors import safe_open
from ..utils import ( from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file, _get_model_file,
is_transformers_available, is_transformers_available,
logging, logging,
...@@ -43,6 +42,7 @@ logger = logging.get_logger(__name__) ...@@ -43,6 +42,7 @@ logger = logging.get_logger(__name__)
class IPAdapterMixin: class IPAdapterMixin:
"""Mixin for handling IP Adapters.""" """Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter( def load_ip_adapter(
self, self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
...@@ -77,7 +77,7 @@ class IPAdapterMixin: ...@@ -77,7 +77,7 @@ class IPAdapterMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -88,12 +88,12 @@ class IPAdapterMixin: ...@@ -88,12 +88,12 @@ class IPAdapterMixin:
""" """
# Load the main state dict first. # Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
user_agent = { user_agent = {
...@@ -110,7 +110,7 @@ class IPAdapterMixin: ...@@ -110,7 +110,7 @@ class IPAdapterMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
......
...@@ -18,14 +18,13 @@ from typing import Callable, Dict, List, Optional, Union ...@@ -18,14 +18,13 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors import safetensors
import torch import torch
from huggingface_hub import model_info from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version from packaging import version
from torch import nn from torch import nn
from .. import __version__ from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import ( from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
_get_model_file, _get_model_file,
convert_state_dict_to_diffusers, convert_state_dict_to_diffusers,
...@@ -132,6 +131,7 @@ class LoraLoaderMixin: ...@@ -132,6 +131,7 @@ class LoraLoaderMixin:
) )
@classmethod @classmethod
@validate_hf_hub_args
def lora_state_dict( def lora_state_dict(
cls, cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
...@@ -174,7 +174,7 @@ class LoraLoaderMixin: ...@@ -174,7 +174,7 @@ class LoraLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -195,12 +195,12 @@ class LoraLoaderMixin: ...@@ -195,12 +195,12 @@ class LoraLoaderMixin:
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both. # UNet and text encoder or both.
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
...@@ -239,7 +239,7 @@ class LoraLoaderMixin: ...@@ -239,7 +239,7 @@ class LoraLoaderMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -265,7 +265,7 @@ class LoraLoaderMixin: ...@@ -265,7 +265,7 @@ class LoraLoaderMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
......
...@@ -18,10 +18,9 @@ from pathlib import Path ...@@ -18,10 +18,9 @@ from pathlib import Path
import requests import requests
import torch import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import ( from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_omegaconf_available, is_omegaconf_available,
...@@ -52,6 +51,7 @@ class FromSingleFileMixin: ...@@ -52,6 +51,7 @@ class FromSingleFileMixin:
return cls.from_single_file(*args, **kwargs) return cls.from_single_file(*args, **kwargs)
@classmethod @classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs): def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r""" r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
...@@ -81,7 +81,7 @@ class FromSingleFileMixin: ...@@ -81,7 +81,7 @@ class FromSingleFileMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -154,12 +154,12 @@ class FromSingleFileMixin: ...@@ -154,12 +154,12 @@ class FromSingleFileMixin:
original_config_file = kwargs.pop("original_config_file", None) original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None) config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False) extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None) image_size = kwargs.pop("image_size", None)
...@@ -253,7 +253,7 @@ class FromSingleFileMixin: ...@@ -253,7 +253,7 @@ class FromSingleFileMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
force_download=force_download, force_download=force_download,
) )
...@@ -293,6 +293,7 @@ class FromOriginalVAEMixin: ...@@ -293,6 +293,7 @@ class FromOriginalVAEMixin:
""" """
@classmethod @classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs): def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r""" r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
...@@ -322,7 +323,7 @@ class FromOriginalVAEMixin: ...@@ -322,7 +323,7 @@ class FromOriginalVAEMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -379,12 +380,12 @@ class FromOriginalVAEMixin: ...@@ -379,12 +380,12 @@ class FromOriginalVAEMixin:
) )
config_file = kwargs.pop("config_file", None) config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
image_size = kwargs.pop("image_size", None) image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None) scaling_factor = kwargs.pop("scaling_factor", None)
...@@ -425,7 +426,7 @@ class FromOriginalVAEMixin: ...@@ -425,7 +426,7 @@ class FromOriginalVAEMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
force_download=force_download, force_download=force_download,
) )
...@@ -490,6 +491,7 @@ class FromOriginalControlnetMixin: ...@@ -490,6 +491,7 @@ class FromOriginalControlnetMixin:
""" """
@classmethod @classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs): def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r""" r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
...@@ -519,7 +521,7 @@ class FromOriginalControlnetMixin: ...@@ -519,7 +521,7 @@ class FromOriginalControlnetMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -555,12 +557,12 @@ class FromOriginalControlnetMixin: ...@@ -555,12 +557,12 @@ class FromOriginalControlnetMixin:
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
config_file = kwargs.pop("config_file", None) config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
num_in_channels = kwargs.pop("num_in_channels", None) num_in_channels = kwargs.pop("num_in_channels", None)
use_linear_projection = kwargs.pop("use_linear_projection", None) use_linear_projection = kwargs.pop("use_linear_projection", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
...@@ -603,7 +605,7 @@ class FromOriginalControlnetMixin: ...@@ -603,7 +605,7 @@ class FromOriginalControlnetMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
force_download=force_download, force_download=force_download,
) )
......
...@@ -15,16 +15,10 @@ from typing import Dict, List, Optional, Union ...@@ -15,16 +15,10 @@ from typing import Dict, List, Optional, Union
import safetensors import safetensors
import torch import torch
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn from torch import nn
from ..utils import ( from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_accelerate_available,
is_transformers_available,
logging,
)
if is_transformers_available(): if is_transformers_available():
...@@ -39,13 +33,14 @@ TEXT_INVERSION_NAME = "learned_embeds.bin" ...@@ -39,13 +33,14 @@ TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs): def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
...@@ -79,7 +74,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) ...@@ -79,7 +74,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -100,7 +95,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) ...@@ -100,7 +95,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -267,6 +262,7 @@ class TextualInversionLoaderMixin: ...@@ -267,6 +262,7 @@ class TextualInversionLoaderMixin:
return all_tokens, all_embeddings return all_tokens, all_embeddings
@validate_hf_hub_args
def load_textual_inversion( def load_textual_inversion(
self, self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
...@@ -320,7 +316,7 @@ class TextualInversionLoaderMixin: ...@@ -320,7 +316,7 @@ class TextualInversionLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
......
...@@ -19,13 +19,12 @@ from typing import Callable, Dict, List, Optional, Union ...@@ -19,13 +19,12 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors import safetensors
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn from torch import nn
from ..models.embeddings import ImageProjection, Resampler from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import ( from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
_get_model_file, _get_model_file,
delete_adapter_layers, delete_adapter_layers,
...@@ -62,6 +61,7 @@ class UNet2DConditionLoadersMixin: ...@@ -62,6 +61,7 @@ class UNet2DConditionLoadersMixin:
text_encoder_name = TEXT_ENCODER_NAME text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r""" r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
...@@ -95,7 +95,7 @@ class UNet2DConditionLoadersMixin: ...@@ -95,7 +95,7 @@ class UNet2DConditionLoadersMixin:
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
...@@ -130,12 +130,12 @@ class UNet2DConditionLoadersMixin: ...@@ -130,12 +130,12 @@ class UNet2DConditionLoadersMixin:
from ..models.attention_processor import CustomDiffusionAttnProcessor from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
...@@ -184,7 +184,7 @@ class UNet2DConditionLoadersMixin: ...@@ -184,7 +184,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -204,7 +204,7 @@ class UNet2DConditionLoadersMixin: ...@@ -204,7 +204,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
......
...@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze ...@@ -24,13 +24,17 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from requests import HTTPError from requests import HTTPError
from .. import __version__, is_torch_available from .. import __version__, is_torch_available
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -197,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -197,6 +201,7 @@ class FlaxModelMixin(PushToHubMixin):
raise NotImplementedError(f"init_weights method has to be implemented for {self}") raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
...@@ -288,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -288,13 +293,13 @@ class FlaxModelMixin(PushToHubMixin):
``` ```
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False) from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
...@@ -314,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -314,7 +319,7 @@ class FlaxModelMixin(PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
**kwargs, **kwargs,
...@@ -359,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -359,7 +364,7 @@ class FlaxModelMixin(PushToHubMixin):
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision, revision=revision,
...@@ -369,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -369,7 +374,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " "token having permission to this repo with `token` or log in with `huggingface-cli "
"login`." "login`."
) )
except RevisionNotFoundError: except RevisionNotFoundError:
......
...@@ -25,14 +25,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union ...@@ -25,14 +25,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors import safetensors
import torch import torch
from huggingface_hub import create_repo from huggingface_hub import create_repo
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn from torch import Tensor, nn
from .. import __version__ from .. import __version__
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION, MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -535,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -535,6 +534,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
) )
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration. Instantiate a pretrained PyTorch model from a pretrained model configuration.
...@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
from_flax = kwargs.pop("from_flax", False) from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
...@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
device_map=device_map, device_map=device_map,
...@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
...@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment