Unverified Commit 75ada250 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Harmonize HF environment variables + deprecate use_auth_token (#6066)

* Harmonize HF environment variables + deprecate use_auth_token

* fix import

* fix
parent 2243a594
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_CACHE
from .controlnet import ( from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetInpaintPipeline,
...@@ -195,6 +196,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -195,6 +196,7 @@ class AutoPipelineForText2Image(ConfigMixin):
) )
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs): def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r""" r"""
Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight. Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
...@@ -246,7 +248,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -246,7 +248,7 @@ class AutoPipelineForText2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -310,11 +312,11 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -310,11 +312,11 @@ class AutoPipelineForText2Image(ConfigMixin):
>>> image = pipeline(prompt).images[0] >>> image = pipeline(prompt).images[0]
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
...@@ -323,7 +325,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -323,7 +325,7 @@ class AutoPipelineForText2Image(ConfigMixin):
"force_download": force_download, "force_download": force_download,
"resume_download": resume_download, "resume_download": resume_download,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "token": token,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"revision": revision, "revision": revision,
} }
...@@ -466,6 +468,7 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -466,6 +468,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
) )
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs): def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r""" r"""
Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight. Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
...@@ -518,7 +521,7 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -518,7 +521,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -582,11 +585,11 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -582,11 +585,11 @@ class AutoPipelineForImage2Image(ConfigMixin):
>>> image = pipeline(prompt, image).images[0] >>> image = pipeline(prompt, image).images[0]
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
...@@ -595,7 +598,7 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -595,7 +598,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
"force_download": force_download, "force_download": force_download,
"resume_download": resume_download, "resume_download": resume_download,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "token": token,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"revision": revision, "revision": revision,
} }
...@@ -742,6 +745,7 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -742,6 +745,7 @@ class AutoPipelineForInpainting(ConfigMixin):
) )
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs): def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r""" r"""
Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight. Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight.
...@@ -793,7 +797,7 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -793,7 +797,7 @@ class AutoPipelineForInpainting(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -857,11 +861,11 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -857,11 +861,11 @@ class AutoPipelineForInpainting(ConfigMixin):
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0] >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
...@@ -870,7 +874,7 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -870,7 +874,7 @@ class AutoPipelineForInpainting(ConfigMixin):
"force_download": force_download, "force_download": force_download,
"resume_download": resume_download, "resume_download": resume_download,
"proxies": proxies, "proxies": proxies,
"use_auth_token": use_auth_token, "token": token,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"revision": revision, "revision": revision,
} }
......
...@@ -22,6 +22,7 @@ from typing import Optional, Union ...@@ -22,6 +22,7 @@ from typing import Optional, Union
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
...@@ -130,10 +131,11 @@ class OnnxRuntimeModel: ...@@ -130,10 +131,11 @@ class OnnxRuntimeModel:
self._save_pretrained(save_directory, **kwargs) self._save_pretrained(save_directory, **kwargs)
@classmethod @classmethod
@validate_hf_hub_args
def _from_pretrained( def _from_pretrained(
cls, cls,
model_id: Union[str, Path], model_id: Union[str, Path],
use_auth_token: Optional[Union[bool, str, None]] = None, token: Optional[Union[bool, str, None]] = None,
revision: Optional[Union[str, None]] = None, revision: Optional[Union[str, None]] = None,
force_download: bool = False, force_download: bool = False,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
...@@ -148,7 +150,7 @@ class OnnxRuntimeModel: ...@@ -148,7 +150,7 @@ class OnnxRuntimeModel:
Arguments: Arguments:
model_id (`str` or `Path`): model_id (`str` or `Path`):
Directory from which to load Directory from which to load
use_auth_token (`str` or `bool`): token (`str` or `bool`):
Is needed to load models from a private or gated repository Is needed to load models from a private or gated repository
revision (`str`): revision (`str`):
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
...@@ -179,7 +181,7 @@ class OnnxRuntimeModel: ...@@ -179,7 +181,7 @@ class OnnxRuntimeModel:
model_cache_path = hf_hub_download( model_cache_path = hf_hub_download(
repo_id=model_id, repo_id=model_id,
filename=model_file_name, filename=model_file_name,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
...@@ -190,11 +192,12 @@ class OnnxRuntimeModel: ...@@ -190,11 +192,12 @@ class OnnxRuntimeModel:
return cls(model=model, **kwargs) return cls(model=model, **kwargs)
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained( def from_pretrained(
cls, cls,
model_id: Union[str, Path], model_id: Union[str, Path],
force_download: bool = True, force_download: bool = True,
use_auth_token: Optional[str] = None, token: Optional[str] = None,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
**model_kwargs, **model_kwargs,
): ):
...@@ -207,6 +210,6 @@ class OnnxRuntimeModel: ...@@ -207,6 +210,6 @@ class OnnxRuntimeModel:
revision=revision, revision=revision,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
use_auth_token=use_auth_token, token=token,
**model_kwargs, **model_kwargs,
) )
...@@ -24,6 +24,7 @@ import numpy as np ...@@ -24,6 +24,7 @@ import numpy as np
import PIL.Image import PIL.Image
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from huggingface_hub import create_repo, snapshot_download from huggingface_hub import create_repo, snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from PIL import Image from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -32,7 +33,6 @@ from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin ...@@ -32,7 +33,6 @@ from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE,
BaseOutput, BaseOutput,
PushToHubMixin, PushToHubMixin,
http_user_agent, http_user_agent,
...@@ -227,6 +227,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -227,6 +227,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
) )
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights. Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights.
...@@ -264,7 +265,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -264,7 +265,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -314,11 +315,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -314,11 +315,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state >>> dpm_params["scheduler"] = dpmpp_state
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False) from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
...@@ -334,7 +335,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -334,7 +335,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
) )
# make sure we only download sub-folders and `diffusers` filenames # make sure we only download sub-folders and `diffusers` filenames
...@@ -365,7 +366,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -365,7 +366,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
......
...@@ -28,7 +28,14 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -28,7 +28,14 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from huggingface_hub import ModelCard, create_repo, hf_hub_download, model_info, snapshot_download from huggingface_hub import (
ModelCard,
create_repo,
hf_hub_download,
model_info,
snapshot_download,
)
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version from packaging import version
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -40,8 +47,6 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME ...@@ -40,8 +47,6 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
BaseOutput, BaseOutput,
...@@ -249,10 +254,11 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi ...@@ -249,10 +254,11 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
return usable_filenames, variant_filenames return usable_filenames, variant_filenames
def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames): @validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
info = model_info( info = model_info(
pretrained_model_name_or_path, pretrained_model_name_or_path,
use_auth_token=use_auth_token, token=token,
revision=None, revision=None,
) )
filenames = {sibling.rfilename for sibling in info.siblings} filenames = {sibling.rfilename for sibling in info.siblings}
...@@ -375,7 +381,6 @@ def _get_pipeline_class( ...@@ -375,7 +381,6 @@ def _get_pipeline_class(
custom_pipeline, custom_pipeline,
module_file=file_name, module_file=file_name,
class_name=class_name, class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=revision, revision=revision,
) )
...@@ -909,6 +914,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -909,6 +914,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return torch.float32 return torch.float32
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
...@@ -976,7 +982,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -976,7 +982,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -1056,12 +1062,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1056,12 +1062,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> pipeline.scheduler = scheduler >>> pipeline.scheduler = scheduler
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False) from_flax = kwargs.pop("from_flax", False)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
...@@ -1094,7 +1100,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1094,7 +1100,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
from_flax=from_flax, from_flax=from_flax,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -1299,7 +1305,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1299,7 +1305,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"force_download": force_download, "force_download": force_download,
"proxies": proxies, "proxies": proxies,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "token": token,
"revision": revision, "revision": revision,
"torch_dtype": torch_dtype, "torch_dtype": torch_dtype,
"custom_pipeline": custom_pipeline, "custom_pipeline": custom_pipeline,
...@@ -1529,6 +1535,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1529,6 +1535,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
cpu_offload(model, device, offload_buffers=offload_buffers) cpu_offload(model, device, offload_buffers=offload_buffers)
@classmethod @classmethod
@validate_hf_hub_args
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
r""" r"""
Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.
...@@ -1576,7 +1583,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1576,7 +1583,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`): local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -1619,12 +1626,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1619,12 +1626,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
</Tip> </Tip>
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) local_files_only = kwargs.pop("local_files_only", None)
use_auth_token = kwargs.pop("use_auth_token", None) token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False) from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None) custom_pipeline = kwargs.pop("custom_pipeline", None)
...@@ -1646,11 +1653,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1646,11 +1653,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_info_call_error: Optional[Exception] = None model_info_call_error: Optional[Exception] = None
if not local_files_only: if not local_files_only:
try: try:
info = model_info( info = model_info(pretrained_model_name, token=token, revision=revision)
pretrained_model_name,
use_auth_token=use_auth_token,
revision=revision,
)
except HTTPError as e: except HTTPError as e:
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True local_files_only = True
...@@ -1665,7 +1668,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1665,7 +1668,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
proxies=proxies, proxies=proxies,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
use_auth_token=use_auth_token, token=token,
) )
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
...@@ -1715,9 +1718,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1715,9 +1718,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if revision in DEPRECATED_REVISION_ARGS and version.parse( if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version version.parse(__version__).base_version
) >= version.parse("0.22.0"): ) >= version.parse("0.22.0"):
warn_deprecated_model_variant( warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
...@@ -1859,7 +1860,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1859,7 +1860,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
...@@ -1883,7 +1884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1883,7 +1884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"force_download": force_download, "force_download": force_download,
"proxies": proxies, "proxies": proxies,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "token": token,
"variant": variant, "variant": variant,
"use_safetensors": use_safetensors, "use_safetensors": use_safetensors,
} }
......
...@@ -18,6 +18,7 @@ from enum import Enum ...@@ -18,6 +18,7 @@ from enum import Enum
from typing import Optional, Union from typing import Optional, Union
import torch import torch
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin from ..utils import BaseOutput, PushToHubMixin
...@@ -81,6 +82,7 @@ class SchedulerMixin(PushToHubMixin): ...@@ -81,6 +82,7 @@ class SchedulerMixin(PushToHubMixin):
has_compatibles = True has_compatibles = True
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
...@@ -120,7 +122,7 @@ class SchedulerMixin(PushToHubMixin): ...@@ -120,7 +122,7 @@ class SchedulerMixin(PushToHubMixin):
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub. won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used. `diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
......
...@@ -20,6 +20,7 @@ from typing import Optional, Tuple, Union ...@@ -20,6 +20,7 @@ from typing import Optional, Tuple, Union
import flax import flax
import jax.numpy as jnp import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin from ..utils import BaseOutput, PushToHubMixin
...@@ -70,6 +71,7 @@ class FlaxSchedulerMixin(PushToHubMixin): ...@@ -70,6 +71,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
has_compatibles = True has_compatibles = True
@classmethod @classmethod
@validate_hf_hub_args
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
...@@ -110,7 +112,7 @@ class FlaxSchedulerMixin(PushToHubMixin): ...@@ -110,7 +112,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`): local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model). Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`). when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
......
...@@ -21,7 +21,6 @@ from .. import __version__ ...@@ -21,7 +21,6 @@ from .. import __version__
from .constants import ( from .constants import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HF_MODULES_CACHE, HF_MODULES_CACHE,
...@@ -38,7 +37,6 @@ from .doc_utils import replace_example_docstring ...@@ -38,7 +37,6 @@ from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
from .hub_utils import ( from .hub_utils import (
HF_HUB_OFFLINE,
PushToHubMixin, PushToHubMixin,
_add_variant, _add_variant,
_get_model_file, _get_model_file,
......
...@@ -14,15 +14,13 @@ ...@@ -14,15 +14,13 @@
import importlib import importlib
import os import os
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home from huggingface_hub.constants import HF_HOME
from packaging import version from packaging import version
from ..dependency_versions_check import dep_version_check from ..dependency_versions_check import dep_version_check
from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available
default_cache_path = HUGGINGFACE_HUB_CACHE
MIN_PEFT_VERSION = "0.6.0" MIN_PEFT_VERSION = "0.6.0"
MIN_TRANSFORMERS_VERSION = "4.34.0" MIN_TRANSFORMERS_VERSION = "4.34.0"
_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES _CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES
...@@ -35,9 +33,8 @@ ONNX_WEIGHTS_NAME = "model.onnx" ...@@ -35,9 +33,8 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
......
...@@ -25,7 +25,8 @@ from pathlib import Path ...@@ -25,7 +25,8 @@ from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from urllib import request from urllib import request
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info from huggingface_hub import cached_download, hf_hub_download, model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version from packaging import version
from .. import __version__ from .. import __version__
...@@ -194,6 +195,7 @@ def find_pipeline_class(loaded_module): ...@@ -194,6 +195,7 @@ def find_pipeline_class(loaded_module):
return pipeline_class return pipeline_class
@validate_hf_hub_args
def get_cached_module_file( def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str, module_file: str,
...@@ -201,7 +203,7 @@ def get_cached_module_file( ...@@ -201,7 +203,7 @@ def get_cached_module_file(
force_download: bool = False, force_download: bool = False,
resume_download: bool = False, resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
): ):
...@@ -232,7 +234,7 @@ def get_cached_module_file( ...@@ -232,7 +234,7 @@ def get_cached_module_file(
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`). when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -244,7 +246,7 @@ def get_cached_module_file( ...@@ -244,7 +246,7 @@ def get_cached_module_file(
<Tip> <Tip>
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip> </Tip>
...@@ -289,7 +291,7 @@ def get_cached_module_file( ...@@ -289,7 +291,7 @@ def get_cached_module_file(
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=False, token=False,
) )
submodule = "git" submodule = "git"
module_file = pretrained_model_name_or_path + ".py" module_file = pretrained_model_name_or_path + ".py"
...@@ -307,7 +309,7 @@ def get_cached_module_file( ...@@ -307,7 +309,7 @@ def get_cached_module_file(
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
) )
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
except EnvironmentError: except EnvironmentError:
...@@ -332,13 +334,6 @@ def get_cached_module_file( ...@@ -332,13 +334,6 @@ def get_cached_module_file(
else: else:
# Get the commit hash # Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here. # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
...@@ -359,13 +354,14 @@ def get_cached_module_file( ...@@ -359,13 +354,14 @@ def get_cached_module_file(
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
) )
return os.path.join(full_submodule, module_file) return os.path.join(full_submodule, module_file)
@validate_hf_hub_args
def get_class_from_dynamic_module( def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str, module_file: str,
...@@ -374,7 +370,7 @@ def get_class_from_dynamic_module( ...@@ -374,7 +370,7 @@ def get_class_from_dynamic_module(
force_download: bool = False, force_download: bool = False,
resume_download: bool = False, resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
**kwargs, **kwargs,
...@@ -414,7 +410,7 @@ def get_class_from_dynamic_module( ...@@ -414,7 +410,7 @@ def get_class_from_dynamic_module(
proxies (`Dict[str, str]`, *optional*): proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*): token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`). when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):
...@@ -426,7 +422,7 @@ def get_class_from_dynamic_module( ...@@ -426,7 +422,7 @@ def get_class_from_dynamic_module(
<Tip> <Tip>
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip> </Tip>
...@@ -449,7 +445,7 @@ def get_class_from_dynamic_module( ...@@ -449,7 +445,7 @@ def get_class_from_dynamic_module(
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
use_auth_token=use_auth_token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
) )
......
...@@ -25,20 +25,21 @@ from typing import Dict, Optional, Union ...@@ -25,20 +25,21 @@ from typing import Dict, Optional, Union
from uuid import uuid4 from uuid import uuid4
from huggingface_hub import ( from huggingface_hub import (
HfFolder,
ModelCard, ModelCard,
ModelCardData, ModelCardData,
create_repo, create_repo,
get_full_repo_name,
hf_hub_download, hf_hub_download,
upload_folder, upload_folder,
whoami,
) )
from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import ( from huggingface_hub.utils import (
EntryNotFoundError, EntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
is_jinja_available, is_jinja_available,
validate_hf_hub_args,
) )
from packaging import version from packaging import version
from requests import HTTPError from requests import HTTPError
...@@ -46,7 +47,6 @@ from requests import HTTPError ...@@ -46,7 +47,6 @@ from requests import HTTPError
from .. import __version__ from .. import __version__
from .constants import ( from .constants import (
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -69,9 +69,6 @@ logger = get_logger(__name__) ...@@ -69,9 +69,6 @@ logger = get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex SESSION_ID = uuid4().hex
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
...@@ -79,7 +76,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ...@@ -79,7 +76,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
Formats a user-agent string with basic info about a request. Formats a user-agent string with basic info about a request.
""" """
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
if DISABLE_TELEMETRY or HF_HUB_OFFLINE: if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE:
return ua + "; telemetry/off" return ua + "; telemetry/off"
if is_torch_available(): if is_torch_available():
ua += f"; torch/{_torch_version}" ua += f"; torch/{_torch_version}"
...@@ -98,16 +95,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ...@@ -98,16 +95,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua return ua
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def create_model_card(args, model_name): def create_model_card(args, model_name):
if not is_jinja_available(): if not is_jinja_available():
raise ValueError( raise ValueError(
...@@ -183,7 +170,7 @@ old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") ...@@ -183,7 +170,7 @@ old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
if new_cache_dir is None: if new_cache_dir is None:
new_cache_dir = DIFFUSERS_CACHE new_cache_dir = HF_HUB_CACHE
if old_cache_dir is None: if old_cache_dir is None:
old_cache_dir = old_diffusers_cache old_cache_dir = old_diffusers_cache
...@@ -203,7 +190,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] ...@@ -203,7 +190,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
# At this point, old_cache_dir contains symlinks to the new cache (it can still be used). # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt") cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
if not os.path.isfile(cache_version_file): if not os.path.isfile(cache_version_file):
cache_version = 0 cache_version = 0
else: else:
...@@ -233,12 +220,12 @@ if cache_version < 1: ...@@ -233,12 +220,12 @@ if cache_version < 1:
if cache_version < 1: if cache_version < 1:
try: try:
os.makedirs(DIFFUSERS_CACHE, exist_ok=True) os.makedirs(HF_HUB_CACHE, exist_ok=True)
with open(cache_version_file, "w") as f: with open(cache_version_file, "w") as f:
f.write("1") f.write("1")
except Exception: except Exception:
logger.warning( logger.warning(
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure " f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
"the directory exists and can be written to." "the directory exists and can be written to."
) )
...@@ -252,20 +239,21 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: ...@@ -252,20 +239,21 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name return weights_name
@validate_hf_hub_args
def _get_model_file( def _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path: Union[str, Path],
*, *,
weights_name, weights_name: str,
subfolder, subfolder: Optional[str],
cache_dir, cache_dir: Optional[str],
force_download, force_download: bool,
proxies, proxies: Optional[Dict],
resume_download, resume_download: bool,
local_files_only, local_files_only: bool,
use_auth_token, token: Optional[str],
user_agent, user_agent: Union[Dict, str, None],
revision, revision: Optional[str],
commit_hash=None, commit_hash: Optional[str] = None,
): ):
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
...@@ -300,7 +288,7 @@ def _get_model_file( ...@@ -300,7 +288,7 @@ def _get_model_file(
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision or commit_hash, revision=revision or commit_hash,
...@@ -325,7 +313,7 @@ def _get_model_file( ...@@ -325,7 +313,7 @@ def _get_model_file(
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, token=token,
user_agent=user_agent, user_agent=user_agent,
subfolder=subfolder, subfolder=subfolder,
revision=revision or commit_hash, revision=revision or commit_hash,
...@@ -336,7 +324,7 @@ def _get_model_file( ...@@ -336,7 +324,7 @@ def _get_model_file(
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " "token having permission to this repo with `token` or log in with `huggingface-cli "
"login`." "login`."
) )
except RevisionNotFoundError: except RevisionNotFoundError:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment