Unverified Commit 4e2c1f3a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add config docs (#429)

* advance

* finish

* finish
parent 5e6417e9
...@@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o ...@@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License. specific language governing permissions and limitations under the License.
--> -->
# Models # Configuration
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. passed to the respective `__init__` methods in a JSON-configuration file.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
## API TODO(PVP) - add example and better info here
Models should provide the `def forward` function and initialization of the model. ## ConfigMixin
All saving, loading, and utilities should be in the base ['ModelMixin'] class. [[autodoc]] ConfigMixin
- from_config
## Examples - save_config
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
\ No newline at end of file
...@@ -9,6 +9,7 @@ from .utils import ( ...@@ -9,6 +9,7 @@ from .utils import (
__version__ = "0.3.0.dev0" __version__ = "0.3.0.dev0"
from .configuration_utils import ConfigMixin
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .onnx_utils import OnnxRuntimeModel from .onnx_utils import OnnxRuntimeModel
......
...@@ -37,9 +37,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json") ...@@ -37,9 +37,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json")
class ConfigMixin: class ConfigMixin:
r""" r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
methods for loading/downloading/saving configurations. methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
- [`~ConfigMixin.from_config`]
- [`~ConfigMixin.save_config`]
Class attributes:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overriden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overriden by parent class).
""" """
config_name = None config_name = None
ignore_for_config = [] ignore_for_config = []
...@@ -74,8 +81,6 @@ class ConfigMixin: ...@@ -74,8 +81,6 @@ class ConfigMixin:
Args: Args:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist). Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
...@@ -90,6 +95,64 @@ class ConfigMixin: ...@@ -90,6 +95,64 @@ class ConfigMixin:
@classmethod @classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
r"""
Instantiate a Python class from a pre-defined JSON-file.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
`./my_model_directory/`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
<Tip>
Passing `use_auth_token=True`` is required when you want to use a private model.
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
...@@ -298,10 +361,10 @@ class FrozenDict(OrderedDict): ...@@ -298,10 +361,10 @@ class FrozenDict(OrderedDict):
def register_to_config(init): def register_to_config(init):
""" r"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
registered in the config, use the `ignore_for_config` class variable shouldn't be registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
""" """
......
...@@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module): ...@@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module):
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models. and saving models.
Class attributes:
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~modeling_utils.ModelMixin.save_pretrained`]. [`~modeling_utils.ModelMixin.save_pretrained`].
""" """
...@@ -200,10 +198,9 @@ class ModelMixin(torch.nn.Module): ...@@ -200,10 +198,9 @@ class ModelMixin(torch.nn.Module):
Can be either: Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
user or organization name, like `dbmdz/bert-base-german-cased`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], `./my_model_directory/`.
e.g., `./my_model_directory/`.
cache_dir (`Union[str, os.PathLike]`, *optional*): cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the Path to a directory in which a downloaded pretrained model configuration should be cached if the
...@@ -236,9 +233,6 @@ class ModelMixin(torch.nn.Module): ...@@ -236,9 +233,6 @@ class ModelMixin(torch.nn.Module):
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. Please refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
<Tip> <Tip>
Passing `use_auth_token=True`` is required when you want to use a private model. Passing `use_auth_token=True`` is required when you want to use a private model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment