Unverified Commit 15782fd5 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Pipeline utils] feat: implement push_to_hub for standalone models, schedulers...


[Pipeline utils] feat: implement push_to_hub for standalone models, schedulers as well as pipelines (#4128)

* feat: implement push_to_hub for standalone models.

* address PR feedback.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* remove max_shard_size.

* add: support for scheduler push_to_hub

* enable push_to_hub support for flax schedulers.

* enable push_to_hub for pipelines.

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>

* reflect pr feedback.

* address another round of deedback.

* better handling of kwargs.

* add: tests

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>

* setting hub staging to False for now.

* incorporate staging test as a separate job.
Co-authored-by: default avatarydshieh <2521628+ydshieh@users.noreply.github.com>

* fix: tokenizer loading.

* fix: json dumping.

* move is_staging_test to a better location.

* better treatment to tokens.

* define repo_id to better handle concurrency

* style

* explicitly set token

* Empty-Commit

* move SUER, TOKEN to test

* collate org_repo_id

* delete repo

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarLucain <lucainp@gmail.com>
Co-authored-by: default avatarydshieh <2521628+ydshieh@users.noreply.github.com>
parent d93ca268
...@@ -113,3 +113,60 @@ jobs: ...@@ -113,3 +113,60 @@ jobs:
with: with:
name: pr_${{ matrix.config.report }}_test_reports name: pr_${{ matrix.config.report }}_test_reports
path: reports path: reports
run_staging_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Hub tests for models, schedulers, and pipelines
framework: hub_tests_pytorch
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_hub
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
HUGGINGFACE_CO_STAGING=true python -m pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
\ No newline at end of file
...@@ -9,4 +9,8 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.mo ...@@ -9,4 +9,8 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.mo
## FlaxModelMixin ## FlaxModelMixin
[[autodoc]] FlaxModelMixin [[autodoc]] FlaxModelMixin
\ No newline at end of file
## Pushing to the Hub
[[autodoc]] utils.PushToHubMixin
\ No newline at end of file
...@@ -26,7 +26,7 @@ from pathlib import PosixPath ...@@ -26,7 +26,7 @@ from pathlib import PosixPath
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
...@@ -144,6 +144,12 @@ class ConfigMixin: ...@@ -144,6 +144,12 @@ class ConfigMixin:
Args: Args:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file is saved (will be created if it does not exist). Directory where the configuration JSON file is saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
...@@ -156,6 +162,22 @@ class ConfigMixin: ...@@ -156,6 +162,22 @@ class ConfigMixin:
self.to_json_file(output_config_file) self.to_json_file(output_config_file)
logger.info(f"Configuration saved in {output_config_file}") logger.info(f"Configuration saved in {output_config_file}")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
@classmethod @classmethod
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
r""" r"""
......
...@@ -23,7 +23,7 @@ import msgpack.exceptions ...@@ -23,7 +23,7 @@ import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import hf_hub_download from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
...@@ -34,6 +34,7 @@ from ..utils import ( ...@@ -34,6 +34,7 @@ from ..utils import (
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME, WEIGHTS_NAME,
PushToHubMixin,
logging, logging,
) )
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
...@@ -42,7 +43,7 @@ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax ...@@ -42,7 +43,7 @@ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class FlaxModelMixin: class FlaxModelMixin(PushToHubMixin):
r""" r"""
Base class for all Flax models. Base class for all Flax models.
...@@ -497,6 +498,8 @@ class FlaxModelMixin: ...@@ -497,6 +498,8 @@ class FlaxModelMixin:
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
is_main_process: bool = True, is_main_process: bool = True,
push_to_hub: bool = False,
**kwargs,
): ):
""" """
Save a model and its configuration file to a directory so that it can be reloaded using the Save a model and its configuration file to a directory so that it can be reloaded using the
...@@ -511,6 +514,12 @@ class FlaxModelMixin: ...@@ -511,6 +514,12 @@ class FlaxModelMixin:
Whether the process calling this is the main process or not. Useful during distributed training and you Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions. process to avoid race conditions.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
...@@ -518,6 +527,14 @@ class FlaxModelMixin: ...@@ -518,6 +527,14 @@ class FlaxModelMixin:
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
model_to_save = self model_to_save = self
# Attach architecture to the config # Attach architecture to the config
...@@ -532,3 +549,12 @@ class FlaxModelMixin: ...@@ -532,3 +549,12 @@ class FlaxModelMixin:
f.write(model_bytes) f.write(model_bytes)
logger.info(f"Model weights saved in {output_model_file}") logger.info(f"Model weights saved in {output_model_file}")
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
...@@ -23,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union ...@@ -23,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors import safetensors
import torch import torch
from huggingface_hub import create_repo
from torch import Tensor, device, nn from torch import Tensor, device, nn
from .. import __version__ from .. import __version__
...@@ -40,6 +41,7 @@ from ..utils import ( ...@@ -40,6 +41,7 @@ from ..utils import (
is_torch_version, is_torch_version,
logging, logging,
) )
from ..utils.hub_utils import PushToHubMixin
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -147,7 +149,7 @@ def _load_state_dict_into_model(model_to_load, state_dict): ...@@ -147,7 +149,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs return error_msgs
class ModelMixin(torch.nn.Module): class ModelMixin(torch.nn.Module, PushToHubMixin):
r""" r"""
Base class for all models. Base class for all models.
...@@ -272,6 +274,8 @@ class ModelMixin(torch.nn.Module): ...@@ -272,6 +274,8 @@ class ModelMixin(torch.nn.Module):
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
): ):
""" """
Save a model and its configuration file to a directory so that it can be reloaded using the Save a model and its configuration file to a directory so that it can be reloaded using the
...@@ -292,6 +296,12 @@ class ModelMixin(torch.nn.Module): ...@@ -292,6 +296,12 @@ class ModelMixin(torch.nn.Module):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
...@@ -299,6 +309,15 @@ class ModelMixin(torch.nn.Module): ...@@ -299,6 +309,15 @@ class ModelMixin(torch.nn.Module):
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Only save the model itself if we are using distributed training
model_to_save = self model_to_save = self
# Attach architecture to the config # Attach architecture to the config
...@@ -322,6 +341,15 @@ class ModelMixin(torch.nn.Module): ...@@ -322,6 +341,15 @@ class ModelMixin(torch.nn.Module):
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
......
...@@ -23,14 +23,22 @@ import flax ...@@ -23,14 +23,22 @@ import flax
import numpy as np import numpy as np
import PIL import PIL
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from huggingface_hub import snapshot_download from huggingface_hub import create_repo, snapshot_download
from PIL import Image from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
BaseOutput,
PushToHubMixin,
http_user_agent,
is_transformers_available,
logging,
)
if is_transformers_available(): if is_transformers_available():
...@@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput): ...@@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray] images: Union[List[PIL.Image.Image], np.ndarray]
class FlaxDiffusionPipeline(ConfigMixin): class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
r""" r"""
Base class for Flax-based pipelines. Base class for Flax-based pipelines.
...@@ -139,7 +147,13 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -139,7 +147,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
# set models # set models
setattr(self, name, module) setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict],
push_to_hub: bool = False,
**kwargs,
):
# TODO: handle inference_state # TODO: handle inference_state
""" """
Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
...@@ -149,6 +163,12 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -149,6 +163,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist. Directory to which to save. Will be created if it doesn't exist.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
self.save_config(save_directory) self.save_config(save_directory)
...@@ -157,6 +177,14 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -157,6 +177,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None) model_index_dict.pop("_module", None)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
for pipeline_component_name in model_index_dict.keys(): for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name) sub_model = getattr(self, pipeline_component_name)
if sub_model is None: if sub_model is None:
...@@ -188,6 +216,15 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -188,6 +216,15 @@ class FlaxDiffusionPipeline(ConfigMixin):
else: else:
save_method(os.path.join(save_directory, pipeline_component_name)) save_method(os.path.join(save_directory, pipeline_component_name))
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
......
...@@ -28,7 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -28,7 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from huggingface_hub import ModelCard, hf_hub_download, model_info, snapshot_download from huggingface_hub import ModelCard, create_repo, hf_hub_download, model_info, snapshot_download
from packaging import version from packaging import version
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -66,7 +66,7 @@ if is_transformers_available(): ...@@ -66,7 +66,7 @@ if is_transformers_available():
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
if is_accelerate_available(): if is_accelerate_available():
...@@ -472,7 +472,7 @@ def load_sub_model( ...@@ -472,7 +472,7 @@ def load_sub_model(
return loaded_sub_model return loaded_sub_model
class DiffusionPipeline(ConfigMixin): class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r""" r"""
Base class for all pipelines. Base class for all pipelines.
...@@ -558,6 +558,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -558,6 +558,8 @@ class DiffusionPipeline(ConfigMixin):
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
safe_serialization: bool = False, safe_serialization: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
): ):
""" """
Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
...@@ -571,6 +573,12 @@ class DiffusionPipeline(ConfigMixin): ...@@ -571,6 +573,12 @@ class DiffusionPipeline(ConfigMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
model_index_dict = dict(self.config) model_index_dict = dict(self.config)
model_index_dict.pop("_class_name", None) model_index_dict.pop("_class_name", None)
...@@ -578,6 +586,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -578,6 +586,14 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_module", None) model_index_dict.pop("_module", None)
model_index_dict.pop("_name_or_path", None) model_index_dict.pop("_name_or_path", None)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", False)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
expected_modules, optional_kwargs = self._get_signature_keys(self) expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value): def is_saveable_module(name, value):
...@@ -641,6 +657,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -641,6 +657,15 @@ class DiffusionPipeline(ConfigMixin):
# finally save the config # finally save the config
self.save_config(save_directory) self.save_config(save_directory)
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
def to( def to(
self, self,
torch_device: Optional[Union[str, torch.device]] = None, torch_device: Optional[Union[str, torch.device]] = None,
......
...@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, Union ...@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, Union
import torch import torch
from ..utils import BaseOutput from ..utils import BaseOutput, PushToHubMixin
SCHEDULER_CONFIG_NAME = "scheduler_config.json" SCHEDULER_CONFIG_NAME = "scheduler_config.json"
...@@ -60,7 +60,7 @@ class SchedulerOutput(BaseOutput): ...@@ -60,7 +60,7 @@ class SchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor prev_sample: torch.FloatTensor
class SchedulerMixin: class SchedulerMixin(PushToHubMixin):
""" """
Base class for all schedulers. Base class for all schedulers.
...@@ -153,7 +153,13 @@ class SchedulerMixin: ...@@ -153,7 +153,13 @@ class SchedulerMixin:
Args: Args:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory to save a configuration JSON file to. Will be created if it doesn't exist. Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
......
...@@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Tuple, Union ...@@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import flax import flax
import jax.numpy as jnp import jax.numpy as jnp
from ..utils import BaseOutput from ..utils import BaseOutput, PushToHubMixin
SCHEDULER_CONFIG_NAME = "scheduler_config.json" SCHEDULER_CONFIG_NAME = "scheduler_config.json"
...@@ -53,7 +53,7 @@ class FlaxSchedulerOutput(BaseOutput): ...@@ -53,7 +53,7 @@ class FlaxSchedulerOutput(BaseOutput):
prev_sample: jnp.ndarray prev_sample: jnp.ndarray
class FlaxSchedulerMixin: class FlaxSchedulerMixin(PushToHubMixin):
""" """
Mixin containing common functions for the schedulers. Mixin containing common functions for the schedulers.
...@@ -156,6 +156,12 @@ class FlaxSchedulerMixin: ...@@ -156,6 +156,12 @@ class FlaxSchedulerMixin:
Args: Args:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist). Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
......
...@@ -37,6 +37,7 @@ from .doc_utils import replace_example_docstring ...@@ -37,6 +37,7 @@ from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import ( from .hub_utils import (
HF_HUB_OFFLINE, HF_HUB_OFFLINE,
PushToHubMixin,
_add_variant, _add_variant,
_get_model_file, _get_model_file,
extract_commit_hash, extract_commit_hash,
......
...@@ -17,13 +17,22 @@ ...@@ -17,13 +17,22 @@
import os import os
import re import re
import sys import sys
import tempfile
import traceback import traceback
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from uuid import uuid4 from uuid import uuid4
from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami from huggingface_hub import (
HfFolder,
ModelCard,
ModelCardData,
create_repo,
hf_hub_download,
upload_folder,
whoami,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import ( from huggingface_hub.utils import (
EntryNotFoundError, EntryNotFoundError,
...@@ -359,3 +368,96 @@ def _get_model_file( ...@@ -359,3 +368,96 @@ def _get_model_file(
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}" f"containing a file named {weights_name}"
) )
class PushToHubMixin:
"""
A Mixin containing the functionality to push a model/scheduler to the Hugging Face Hub.
"""
def _upload_folder(
self,
working_dir: Union[str, os.PathLike],
repo_id: str,
token: Optional[str] = None,
commit_message: Optional[str] = None,
create_pr: bool = False,
):
"""
Uploads all files in `working_dir` to `repo_id`.
"""
if commit_message is None:
if "Model" in self.__class__.__name__:
commit_message = "Upload model"
elif "Scheduler" in self.__class__.__name__:
commit_message = "Upload scheduler"
else:
commit_message = f"Upload {self.__class__.__name__}"
logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
return upload_folder(
repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
)
def push_to_hub(
self,
repo_id: str,
commit_message: Optional[str] = None,
private: Optional[bool] = None,
token: Optional[str] = None,
create_pr: bool = False,
safe_serialization: bool = True,
variant: Optional[str] = None,
) -> str:
"""
Upload the {object_files} to the 🤗 Hugging Face Hub.
Parameters:
repo_id (`str`):
The name of the repository you want to push your {object} to. It should contain your organization name
when pushing to a given organization. `repo_id` can also be a path to a local directory.
commit_message (`str`, *optional*):
Message to commit while pushing. Will default to `"Upload {object}"`.
private (`bool`, *optional*):
Whether or not the repository created should be private.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. The token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether or not to convert the model weights in safetensors format for safer serialization.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
Examples:
```python
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")
# Push the `unet` to your namespace with the name "my-finetuned-unet".
unet.push_to_hub("my-finetuned-unet")
# Push the {object} to an organization with the name "my-finetuned-unet".
unet.push_to_hub("your-org/my-finetuned-unet")
```
"""
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
# Save all files.
save_kwargs = {"safe_serialization": safe_serialization}
if "Scheduler" not in self.__class__.__name__:
save_kwargs.update({"variant": variant})
with tempfile.TemporaryDirectory() as tmpdir:
self.save_pretrained(tmpdir, **save_kwargs)
return self._upload_folder(
tmpdir,
repo_id,
token=token,
commit_message=commit_message,
create_pr=create_pr,
)
...@@ -18,18 +18,27 @@ import tempfile ...@@ -18,18 +18,27 @@ import tempfile
import traceback import traceback
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import uuid
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import requests_mock import requests_mock
import torch import torch
from huggingface_hub import delete_repo
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import logging, torch_device from diffusers.utils import logging, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess from diffusers.utils.testing_utils import (
CaptureLogger,
require_torch_2,
require_torch_gpu,
run_test_in_subprocess,
)
from ..others.test_utils import TOKEN, USER, is_staging_test
# Will be run via run_test_in_subprocess # Will be run via run_test_in_subprocess
...@@ -563,3 +572,72 @@ class ModelTesterMixin: ...@@ -563,3 +572,72 @@ class ModelTesterMixin:
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`" " from `_deprecated_kwargs = [<deprecated_argument>]`"
) )
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
repo_id = f"test-model-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def test_push_to_hub(self):
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
model.push_to_hub(self.repo_id, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
def test_push_to_hub_in_organization(self):
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
model.push_to_hub(self.org_repo_id, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
...@@ -13,12 +13,24 @@ ...@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import unittest import unittest
from distutils.util import strtobool
import pytest
from diffusers import __version__ from diffusers import __version__
from diffusers.utils import deprecate from diffusers.utils import deprecate
# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
# Not critical, only usable on the sandboxed CI instance.
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
class DeprecateTester(unittest.TestCase): class DeprecateTester(unittest.TestCase):
higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:]) higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
lower_version = "0.0.1" lower_version = "0.0.1"
...@@ -168,3 +180,34 @@ class DeprecateTester(unittest.TestCase): ...@@ -168,3 +180,34 @@ class DeprecateTester(unittest.TestCase):
deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False) deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
assert str(warning.warning) == "This message is better!!!" assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename assert "diffusers/tests/others/test_utils.py" in warning.filename
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
def is_staging_test(test_case):
"""
Decorator marking a test as a staging test.
Those tests will run using the staging environment of huggingface.co instead of the real model hub.
"""
if not _run_staging:
return unittest.skip("test is staging test")(test_case)
else:
return pytest.mark.is_staging_test()(test_case)
...@@ -2,23 +2,30 @@ import contextlib ...@@ -2,23 +2,30 @@ import contextlib
import gc import gc
import inspect import inspect
import io import io
import json
import os
import re import re
import tempfile import tempfile
import unittest import unittest
import uuid
from typing import Callable, Union from typing import Callable, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from huggingface_hub import delete_repo
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers import diffusers
from diffusers import DiffusionPipeline from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
from ..others.test_utils import TOKEN, USER, is_staging_test
def to_np(tensor): def to_np(tensor):
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor):
...@@ -795,6 +802,126 @@ class PipelineTesterMixin: ...@@ -795,6 +802,126 @@ class PipelineTesterMixin:
assert out_cfg.shape == out_no_cfg.shape assert out_cfg.shape == out_no_cfg.shape
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
repo_id = f"test-pipeline-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def get_pipeline_components(self):
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
with tempfile.TemporaryDirectory() as tmpdir:
dummy_vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "!": 2}
vocab_path = os.path.join(tmpdir, "vocab.json")
with open(vocab_path, "w") as f:
json.dump(dummy_vocab, f)
merges = "Ġ t\nĠt h"
merges_path = os.path.join(tmpdir, "merges.txt")
with open(merges_path, "w") as f:
f.writelines(merges)
tokenizer = CLIPTokenizer(vocab_file=vocab_path, merges_file=merges_path)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def test_push_to_hub(self):
components = self.get_pipeline_components()
pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub(self.repo_id, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", subfolder="unet")
unet = components["unet"]
for p1, p2 in zip(unet.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", subfolder="unet")
for p1, p2 in zip(unet.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
def test_push_to_hub_in_organization(self):
components = self.get_pipeline_components()
pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub(self.org_repo_id, token=TOKEN)
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet")
unet = components["unet"]
for p1, p2 in zip(unet.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet")
for p1, p2 in zip(unet.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image. # reference image.
......
...@@ -17,10 +17,12 @@ import json ...@@ -17,10 +17,12 @@ import json
import os import os
import tempfile import tempfile
import unittest import unittest
import uuid
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from huggingface_hub import delete_repo
import diffusers import diffusers
from diffusers import ( from diffusers import (
...@@ -41,6 +43,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin ...@@ -41,6 +43,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import torch_device from diffusers.utils import torch_device
from diffusers.utils.testing_utils import CaptureLogger from diffusers.utils.testing_utils import CaptureLogger
from ..others.test_utils import TOKEN, USER, is_staging_test
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
...@@ -720,3 +724,64 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -720,3 +724,64 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.does_not_exist scheduler.does_not_exist
assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'" assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
@is_staging_test
class SchedulerPushToHubTester(unittest.TestCase):
identifier = uuid.uuid4()
repo_id = f"test-scheduler-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def test_push_to_hub(self):
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
scheduler.push_to_hub(self.repo_id, token=TOKEN)
scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
assert type(scheduler) == type(scheduler_loaded)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_config
with tempfile.TemporaryDirectory() as tmp_dir:
scheduler.save_config(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
assert type(scheduler) == type(scheduler_loaded)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
def test_push_to_hub_in_organization(self):
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
scheduler.push_to_hub(self.org_repo_id, token=TOKEN)
scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
assert type(scheduler) == type(scheduler_loaded)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_config
with tempfile.TemporaryDirectory() as tmp_dir:
scheduler.save_config(tmp_dir, repo_id=self.org_repo_id, push_to_hub=True, token=TOKEN)
scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
assert type(scheduler) == type(scheduler_loaded)
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment