Unverified Commit e5810e68 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Variant] Add "variant" as input kwarg so to have better UX when downloading...


[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights (#2305)

* [Variant] Add variant loading mechanism

* clean

* improve further

* up

* add tests

* add some first tests

* up

* up

* use path splittetx

* add deprecate

* deprecation warnings

* improve docs

* up

* up

* up

* fix tests

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* correct code format

* fix warning

* finish

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Update docs/source/en/using-diffusers/loading.mdx
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* correct loading docs

* finish

---------
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>
parent e3ddbe25
...@@ -23,31 +23,50 @@ In the following we explain in-detail how to easily load: ...@@ -23,31 +23,50 @@ In the following we explain in-detail how to easily load:
## Loading pipelines ## Loading pipelines
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256). The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [Runway's Stable Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
```python ```python
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256" repo_id = "runwayml/stable-diffusion-v1-5"
ldm = DiffusionPipeline.from_pretrained(repo_id) pipe = DiffusionPipeline.from_pretrained(repo_id)
``` ```
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`LDMTextToImagePipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `ldm`. Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`StableDiffusionPipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `pipe`.
The pipeline instance can then be called using [`LDMTextToImagePipeline.__call__`] (i.e., `ldm("image of a astronaut riding a horse")`) for text-to-image generation. The pipeline instance can then be called using [`StableDiffusionPipeline.__call__`] (i.e., `pipe("image of a astronaut riding a horse")`) for text-to-image generation.
Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing: Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing:
```python ```python
from diffusers import LDMTextToImagePipeline from diffusers import StableDiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(repo_id)
```
<Tip>
Many checkpoints, such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for multiple tasks, *e.g.* *text-to-image* or *image-to-image*.
If you want to use those checkpoints for a task that is different from the default one, you have to load it directly from the corresponding task-specific pipeline class:
```python
from diffusers import StableDiffusionImg2ImgPipeline
repo_id = "CompVis/ldm-text2im-large-256" repo_id = "runwayml/stable-diffusion-v1-5"
ldm = LDMTextToImagePipeline.from_pretrained(repo_id) pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
``` ```
Diffusion pipelines like `LDMTextToImagePipeline` often consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vqvae"` and "bert", tokenizers or schedulers. These components can interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`LDMTextToImagePipeline`] or [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work). </Tip>
Diffusion pipelines like `StableDiffusionPipeline` or `StableDiffusionImg2ImgPipeline` consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vae"` and `"text_encoder"`, tokenizers or schedulers.
These components often interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later. The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later.
### Loading pipelines that require access request <!---
THE FOLLOWING CAN BE UNCOMMENTED ONCE WE HAVE NEW MODELS WITH ACCESS REQUIREMENT
# Loading pipelines that require access request
Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images. Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images.
In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded. In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded.
...@@ -94,6 +113,7 @@ stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="<y ...@@ -94,6 +113,7 @@ stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="<y
``` ```
The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section. The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section.
-->
### Loading pipelines locally ### Loading pipelines locally
...@@ -101,9 +121,9 @@ If you prefer to have complete control over the pipeline and its corresponding f ...@@ -101,9 +121,9 @@ If you prefer to have complete control over the pipeline and its corresponding f
we recommend loading pipelines locally. we recommend loading pipelines locally.
To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for
[CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256). [Runway's Stable Diffusion Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main): First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main):
``` ```
git lfs install git lfs install
...@@ -178,105 +198,324 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components) ...@@ -178,105 +198,324 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
Note how the above code snippet makes use of [`DiffusionPipeline.components`]. Note how the above code snippet makes use of [`DiffusionPipeline.components`].
### Loading variants
Diffusion Pipeline checkpoints can offer variants of the "main" diffusion pipeline checkpoint.
Such checkpoint variants are usually variations of the checkpoint that have advantages for specific use-cases and that are so similar to the "main" checkpoint that they **should not** be put in a new checkpoint.
A variation of a checkpoint has to have **exactly** the same serialization format and **exactly** the same model structure, including all weights having the same tensor shapes.
Examples of variations are different floating point types and non-ema weights. I.e. "fp16", "bf16", and "no_ema" are common variations.
#### Let's first talk about whats **not** checkpoint variant,
Checkpoint variants do **not** include different serialization formats (such as [safetensors](https://huggingface.co/docs/diffusers/main/en/using-diffusers/using_safetensors)) as weights in different serialization formats are
identical to the weights of the "main" checkpoint, just loaded in a different framework.
Also variants do not correspond to different model structures, *e.g.* [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is not a variant of [stable-diffusion-2-0](https://huggingface.co/stabilityai/stable-diffusion-2) since the model structure is different (Stable Diffusion 1-5 uses a different `CLIPTextModel` compared to Stable Diffusion 2.0).
Pipeline checkpoints that are identical in model structure, but have been trained on different datasets, trained with vastly different training setups and thus correspond to different official releases (such as [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)) should probably be stored in individual repositories instead of as variations of eachother.
#### So what are checkpoint variants then?
Checkpoint variants usually consist of the checkpoint stored in "*low-precision, low-storage*" dtype so that less bandwith is required to download them, or of *non-exponential-averaged* weights that shall be used when continuing fine-tuning from the checkpoint.
Both use cases have clear advantages when their weights are considered variants: they share the same serialization format as the reference weights, and they correspond to a specialization of the "main" checkpoint which does not warrant a new model repository.
A checkpoint stored in [torch's half-precision / float16 format](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) requires only half the bandwith and storage when downloading the checkpoint,
**but** cannot be used when continuing training or when running the checkpoint on CPU.
Similarly the *non-exponential-averaged* (or non-EMA) version of the checkpoint should be used when continuing fine-tuning of the model checkpoint, **but** should not be used when using the checkpoint for inference.
#### How to save and load variants
Saving a diffusion pipeline as a variant can be done by providing [`DiffusionPipeline.save_pretrained`] with the `variant` argument.
The `variant` extends the weight name by the provided variation, by changing the default weight name from `diffusion_pytorch_model.bin` to `diffusion_pytorch_model.{variant}.bin` or from `diffusion_pytorch_model.safetensors` to `diffusion_pytorch_model.{variant}.safetensors`. By doing so, one creates a variant of the pipeline checkpoint that can be loaded **instead** of the "main" pipeline checkpoint.
Let's have a look at how we could create a float16 variant of a pipeline. First, we load
the "main" variant of a checkpoint (stored in `float32` precision) into mixed precision format, using `torch_dtype=torch.float16`.
```py
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
Now all model components of the pipeline are stored in half-precision dtype. We can now save the
pipeline under a `"fp16"` variant as follows:
```py
pipe.save_pretrained("./stable-diffusion-v1-5", variant="fp16")
```
If we don't save into an existing `stable-diffusion-v1-5` folder the new folder would look as follows:
```
stable-diffusion-v1-5
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   └── pytorch_model.fp16.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── pytorch_model.fp16.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.fp16.bin
└── vae
├── config.json
└── diffusion_pytorch_model.fp16.bin
```
As one can see, all model files now have a `.fp16.bin` extension instead of just `.bin`.
The variant now has to be loaded by also passing a `variant="fp16"` to [`DiffusionPipeline.from_pretrained`], e.g.:
```py
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16)
```
works just fine, while:
```py
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
```
throws an Exception:
```
OSError: Error no file named diffusion_pytorch_model.bin found in directory ./stable-diffusion-v1-45/vae since we **only** stored the model
```
This is expected as we don't have any "non-variant" checkpoint files saved locally.
However, the whole idea of pipeline variants is that they can co-exist with the "main" variant,
so one would typically also save the "main" variant in the same folder. Let's do this:
```py
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.save_pretrained("./stable-diffusion-v1-5")
```
and upload the pipeline to the Hub under [diffusers/stable-diffusion-variants](https://huggingface.co/diffusers/stable-diffusion-variants).
The file structure [on the Hub](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main) now looks as follows:
```
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   ├── pytorch_model.bin
│   └── pytorch_model.fp16.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   ├── pytorch_model.bin
│   └── pytorch_model.fp16.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model.bin
│   ├── diffusion_pytorch_model.fp16.bin
└── vae
├── config.json
├── diffusion_pytorch_model.bin
└── diffusion_pytorch_model.fp16.bin
```
We can now both download the "main" and the "fp16" variant from the Hub. Both:
```py
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants")
```
and
```py
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="fp16")
```
works.
<Tip>
Note that Diffusers never downloads more checkpoints than needed. E.g. when downloading
the "main" variant, none of the "fp16.bin" files are downloaded and cached.
Only when the user specifies `variant="fp16"` are those files downloaded and cached.
</Tip>
Finally, there are cases where only some of the checkpoint files of the pipeline are of a certain
variation. E.g. it's usually only the UNet checkpoint that has both a *exponential-mean-averaged* (EMA) and a *non-exponential-mean-averaged* (non-EMA) version. All other model components, e.g. the text encoder, safety checker or variational auto-encoder usually don't have such a variation.
In such a case, one would upload just the UNet's checkpoint file with a `non_ema` version format (as done [here](https://huggingface.co/diffusers/stable-diffusion-variants/blob/main/unet/diffusion_pytorch_model.non_ema.bin)) and upon calling:
```python
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="non_ema")
```
the model will use only the "non_ema" checkpoint variant if it is available - otherwise it'll load the
"main" variation. In the above example, `variant="non_ema"` would therefore download the following file structure:
```
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   ├── pytorch_model.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   ├── pytorch_model.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.non_ema.bin
└── vae
├── config.json
├── diffusion_pytorch_model.bin
```
In a nutshell, using `variant="{variant}"` will download all files that match the `{variant}` and if for a model component such a file variant is not present it will download the "main" variant. If neither a "main" or `{variant}` variant is available, an error will the thrown.
### How does loading work? ### How does loading work?
As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things: As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files. - Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files.
- Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file. - Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file.
The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`LDMTextToImagePipeline`] for [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256) The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`StableDiffusionPipeline`] for [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)
This can be understood better by looking at an example. Let's print out pipeline class instance `pipeline` we just defined: This can be better understood by looking at an example. Let's load a pipeline class instance `pipe` and print it:
```python ```python
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256" repo_id = "runwayml/stable-diffusion-v1-5"
ldm = DiffusionPipeline.from_pretrained(repo_id) pipe = DiffusionPipeline.from_pretrained(repo_id)
print(ldm) print(pipe)
``` ```
*Output*: *Output*:
``` ```
LDMTextToImagePipeline { StableDiffusionPipeline {
"bert": [ "feature_extractor": [
"latent_diffusion", "transformers",
"LDMBertModel" "CLIPFeatureExtractor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
], ],
"scheduler": [ "scheduler": [
"diffusers", "diffusers",
"DDIMScheduler" "PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
], ],
"tokenizer": [ "tokenizer": [
"transformers", "transformers",
"BertTokenizer" "CLIPTokenizer"
], ],
"unet": [ "unet": [
"diffusers", "diffusers",
"UNet2DConditionModel" "UNet2DConditionModel"
], ],
"vqvae": [ "vae": [
"diffusers", "diffusers",
"AutoencoderKL" "AutoencoderKL"
] ]
} }
``` ```
First, we see that the official pipeline is the [`LDMTextToImagePipeline`], and second we see that the `LDMTextToImagePipeline` consists of 5 components: First, we see that the official pipeline is the [`StableDiffusionPipeline`], and second we see that the `StableDiffusionPipeline` consists of 7 components:
- `"bert"` of class `LDMBertModel` as defined [in the pipeline](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L664) - `"feature_extractor"` of class `CLIPFeatureExtractor` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPFeatureExtractor).
- `"scheduler"` of class [`DDIMScheduler`] - `"safety_checker"` as defined [here](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32).
- `"tokenizer"` of class `BertTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) - `"scheduler"` of class [`PNDMScheduler`].
- `"unet"` of class [`UNet2DConditionModel`] - `"text_encoder"` of class `CLIPTextModel` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel).
- `"vqvae"` of class [`AutoencoderKL`] - `"tokenizer"` of class `CLIPTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
- `"unet"` of class [`UNet2DConditionModel`].
- `"vae"` of class [`AutoencoderKL`].
Let's now compare the pipeline instance to the folder structure of the model repository `CompVis/ldm-text2im-large-256`. Looking at the folder structure of [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main) on the Hub, we can see it matches 1-to-1 the printed out instance of `LDMTextToImagePipeline` above: Let's now compare the pipeline instance to the folder structure of the model repository `runwayml/stable-diffusion-v1-5`. Looking at the folder structure of [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) on the Hub and excluding model and saving format variants, we can see it matches 1-to-1 the printed out instance of `StableDiffusionPipeline` above:
``` ```
. .
├── bert ├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json │   ├── config.json
│   └── pytorch_model.bin │   └── pytorch_model.bin
├── model_index.json
├── scheduler ├── scheduler
│   └── scheduler_config.json │   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── pytorch_model.bin
├── tokenizer ├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json │   ├── special_tokens_map.json
│   ├── tokenizer_config.json │   ├── tokenizer_config.json
│   └── vocab.txt │   └── vocab.json
├── unet ├── unet
│   ├── config.json │   ├── config.json
│   ── diffusion_pytorch_model.bin │   ── diffusion_pytorch_model.bin
└── vqvae └── vae
├── config.json ├── config.json
── diffusion_pytorch_model.bin ── diffusion_pytorch_model.bin
``` ```
As we can see each attribute of the instance of `LDMTextToImagePipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"bert"`, `"scheduler"`, `"tokenizer"`, `"unet"`, `"vqvae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both: Each attribute of the instance of `StableDiffusionPipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"feature_extractor"`, `"safety_checker"`, `"scheduler"`, `"text_encoder"`, `"tokenizer"`, `"unet"`, `"vae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
- which pipeline class should be loaded, and - which pipeline class should be loaded, and
- what sub-classes from which library are stored in which subfolders - what sub-classes from which library are stored in which subfolders
In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefore defined as follows: In the case of `runwayml/stable-diffusion-v1-5` the `model_index.json` is therefore defined as follows:
``` ```
{ {
"_class_name": "LDMTextToImagePipeline", "_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.0.4", "_diffusers_version": "0.6.0",
"bert": [ "feature_extractor": [
"latent_diffusion", "transformers",
"LDMBertModel" "CLIPFeatureExtractor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
], ],
"scheduler": [ "scheduler": [
"diffusers", "diffusers",
"DDIMScheduler" "PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
], ],
"tokenizer": [ "tokenizer": [
"transformers", "transformers",
"BertTokenizer" "CLIPTokenizer"
], ],
"unet": [ "unet": [
"diffusers", "diffusers",
"UNet2DConditionModel" "UNet2DConditionModel"
], ],
"vqvae": [ "vae": [
"diffusers", "diffusers",
"AutoencoderKL" "AutoencoderKL"
] ]
...@@ -292,10 +531,36 @@ In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefo ...@@ -292,10 +531,36 @@ In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefo
"class" "class"
] ]
``` ```
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42) - The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
- The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded - The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded
- The `"class"` field corresponds to the name of the class, *e.g.* [`BertTokenizer`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) or [`UNet2DConditionModel`] - The `"class"` field corresponds to the name of the class, *e.g.* [`CLIPTokenizer`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer) or [`UNet2DConditionModel`]
<!--
TODO(Patrick) - Make sure to uncomment this part as soon as things are deprecated.
#### Using `revision` to load pipeline variants is deprecated
Previously the `revision` argument of [`DiffusionPipeline.from_pretrained`] was heavily used to
load model variants, e.g.:
```python
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16")
```
However, this behavior is now deprecated since the "revision" argument should (just as it's done in GitHub) better be used to load model checkpoints from a specific commit or branch in development.
The above example is therefore deprecated and won't be supported anymore for `diffusers >= 1.0.0`.
<Tip warning={true}>
If you load diffusers pipelines or models with `revision="fp16"` or `revision="non_ema"`,
please make sure to update to code and use `variant="fp16"` or `variation="non_ema"` respectively
instead.
</Tip>
-->
## Loading models ## Loading models
...@@ -310,19 +575,19 @@ Let's look at an example: ...@@ -310,19 +575,19 @@ Let's look at an example:
```python ```python
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
repo_id = "CompVis/ldm-text2im-large-256" repo_id = "runwayml/stable-diffusion-v1-5"
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet") model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
``` ```
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/unet). Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet).
As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]: As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]:
```python ```python
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256" repo_id = "runwayml/stable-diffusion-v1-5"
ldm = DiffusionPipeline.from_pretrained(repo_id, unet=model) pipe = DiffusionPipeline.from_pretrained(repo_id, unet=model)
``` ```
If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't
...@@ -335,6 +600,18 @@ repo_id = "google/ddpm-cifar10-32" ...@@ -335,6 +600,18 @@ repo_id = "google/ddpm-cifar10-32"
model = UNet2DModel.from_pretrained(repo_id) model = UNet2DModel.from_pretrained(repo_id)
``` ```
As motivated in [How to save and load variants?](#how-to-save-and-load-variants), models can load and
save variants. To load a model variant, one should pass the `variant` function argument to [`ModelMixin.from_pretrained`]. Analogous, to save a model variant, one should pass the `variant` function argument to [`ModelMixin.save_pretrained`]:
```python
from diffusers import UNet2DConditionModel
model = UNet2DConditionModel.from_pretrained(
"diffusers/stable-diffusion-variants", subfolder="unet", variant="non_ema"
)
model.save_pretrained("./local-unet", variant="non_ema")
```
## Loading schedulers ## Loading schedulers
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
......
...@@ -16,18 +16,21 @@ ...@@ -16,18 +16,21 @@
import inspect import inspect
import os import os
import warnings
from functools import partial from functools import partial
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError from requests import HTTPError
from torch import Tensor, device from torch import Tensor, device
from .. import __version__ from .. import __version__
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE, HF_HUB_OFFLINE,
...@@ -89,12 +92,12 @@ def get_parameter_dtype(parameter: torch.nn.Module): ...@@ -89,12 +92,12 @@ def get_parameter_dtype(parameter: torch.nn.Module):
return first_tuple[1].dtype return first_tuple[1].dtype
def load_state_dict(checkpoint_file: Union[str, os.PathLike]): def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
""" """
Reads a checkpoint file, returning properly formatted errors if they arise. Reads a checkpoint file, returning properly formatted errors if they arise.
""" """
try: try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME: if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu") return torch.load(checkpoint_file, map_location="cpu")
else: else:
return safetensors.torch.load_file(checkpoint_file, device="cpu") return safetensors.torch.load_file(checkpoint_file, device="cpu")
...@@ -141,6 +144,15 @@ def _load_state_dict_into_model(model_to_load, state_dict): ...@@ -141,6 +144,15 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs return error_msgs
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
class ModelMixin(torch.nn.Module): class ModelMixin(torch.nn.Module):
r""" r"""
Base class for all models. Base class for all models.
...@@ -250,6 +262,7 @@ class ModelMixin(torch.nn.Module): ...@@ -250,6 +262,7 @@ class ModelMixin(torch.nn.Module):
is_main_process: bool = True, is_main_process: bool = True,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = False,
variant: Optional[str] = None,
): ):
""" """
Save a model and its configuration file to a directory, so that it can be re-loaded using the Save a model and its configuration file to a directory, so that it can be re-loaded using the
...@@ -268,6 +281,8 @@ class ModelMixin(torch.nn.Module): ...@@ -268,6 +281,8 @@ class ModelMixin(torch.nn.Module):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
""" """
if safe_serialization and not is_safetensors_available(): if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
...@@ -292,6 +307,7 @@ class ModelMixin(torch.nn.Module): ...@@ -292,6 +307,7 @@ class ModelMixin(torch.nn.Module):
state_dict = model_to_save.state_dict() state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
# Save the model # Save the model
save_function(state_dict, os.path.join(save_directory, weights_name)) save_function(state_dict, os.path.join(save_directory, weights_name))
...@@ -371,6 +387,9 @@ class ModelMixin(torch.nn.Module): ...@@ -371,6 +387,9 @@ class ModelMixin(torch.nn.Module):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error. setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip> <Tip>
...@@ -401,6 +420,7 @@ class ModelMixin(torch.nn.Module): ...@@ -401,6 +420,7 @@ class ModelMixin(torch.nn.Module):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
if low_cpu_mem_usage and not is_accelerate_available(): if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False low_cpu_mem_usage = False
...@@ -488,7 +508,7 @@ class ModelMixin(torch.nn.Module): ...@@ -488,7 +508,7 @@ class ModelMixin(torch.nn.Module):
try: try:
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
...@@ -504,7 +524,7 @@ class ModelMixin(torch.nn.Module): ...@@ -504,7 +524,7 @@ class ModelMixin(torch.nn.Module):
if model_file is None: if model_file is None:
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME, weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
...@@ -538,7 +558,7 @@ class ModelMixin(torch.nn.Module): ...@@ -538,7 +558,7 @@ class ModelMixin(torch.nn.Module):
# if device_map is None, load the state dict and move the params from meta device to the cpu # if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None: if device_map is None:
param_device = "cpu" param_device = "cpu"
state_dict = load_state_dict(model_file) state_dict = load_state_dict(model_file, variant=variant)
# move the params from meta device to cpu # move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0: if len(missing_keys) > 0:
...@@ -587,7 +607,7 @@ class ModelMixin(torch.nn.Module): ...@@ -587,7 +607,7 @@ class ModelMixin(torch.nn.Module):
) )
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file) state_dict = load_state_dict(model_file, variant=variant)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model, model,
...@@ -800,8 +820,38 @@ def _get_model_file( ...@@ -800,8 +820,38 @@ def _get_model_file(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
) )
else: else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
FutureWarning,
)
try: try:
# Load from URL or cache if already cached # 2. Load model file as usual
model_file = hf_hub_download( model_file = hf_hub_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=weights_name, filename=weights_name,
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
import importlib import importlib
import inspect import inspect
import os import os
import re
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
...@@ -31,15 +33,16 @@ from tqdm.auto import tqdm ...@@ -31,15 +33,16 @@ from tqdm.auto import tqdm
import diffusers import diffusers
from .. import __version__
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE, HF_HUB_OFFLINE,
ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
BaseOutput, BaseOutput,
deprecate, deprecate,
...@@ -56,6 +59,11 @@ from ..utils import ( ...@@ -56,6 +59,11 @@ from ..utils import (
if is_transformers_available(): if is_transformers_available():
import transformers import transformers
from transformers import PreTrainedModel from transformers import PreTrainedModel
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
INDEX_FILE = "diffusion_pytorch_model.bin" INDEX_FILE = "diffusion_pytorch_model.bin"
...@@ -120,15 +128,16 @@ class AudioPipelineOutput(BaseOutput): ...@@ -120,15 +128,16 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray audios: np.ndarray
def is_safetensors_compatible(info) -> bool: def is_safetensors_compatible(filenames, variant=None) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames: for pt_filename in pt_filenames:
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
prefix, raw = os.path.split(pt_filename) prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin": if raw == f"pytorch_model{_variant}.bin":
# transformers specific # transformers specific
sf_filename = os.path.join(prefix, "model.safetensors") sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
else: else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors" sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames: if is_safetensors_compatible and sf_filename not in filenames:
...@@ -137,6 +146,41 @@ def is_safetensors_compatible(info) -> bool: ...@@ -137,6 +146,41 @@ def is_safetensors_compatible(info) -> bool:
return is_safetensors_compatible return is_safetensors_compatible
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
filenames = set(sibling.rfilename for sibling in info.siblings)
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
variant_file_regex = (
re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
if variant is not None
else None
)
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
if variant is not None:
variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None)
else:
variant_filenames = set()
non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None)
usable_filenames = set(variant_filenames)
for f in non_variant_filenames:
variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
if variant_filename not in usable_filenames:
usable_filenames.add(f)
return usable_filenames, variant_filenames
class DiffusionPipeline(ConfigMixin): class DiffusionPipeline(ConfigMixin):
r""" r"""
Base class for all models. Base class for all models.
...@@ -194,6 +238,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -194,6 +238,7 @@ class DiffusionPipeline(ConfigMixin):
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
safe_serialization: bool = False, safe_serialization: bool = False,
variant: Optional[str] = None,
): ):
""" """
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
...@@ -205,6 +250,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -205,6 +250,8 @@ class DiffusionPipeline(ConfigMixin):
Directory to which to save. Will be created if it doesn't exist. Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
""" """
self.save_config(save_directory) self.save_config(save_directory)
...@@ -246,12 +293,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -246,12 +293,15 @@ class DiffusionPipeline(ConfigMixin):
# Call the save method with the argument safe_serialization only if it's supported # Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method) save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe: if save_method_accept_safe:
save_method( save_kwargs["safe_serialization"] = safe_serialization
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization if save_method_accept_variant:
) save_kwargs["variant"] = variant
else:
save_method(os.path.join(save_directory, pipeline_component_name)) save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
def to(self, torch_device: Optional[Union[str, torch.device]] = None): def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None: if torch_device is None:
...@@ -403,6 +453,9 @@ class DiffusionPipeline(ConfigMixin): ...@@ -403,6 +453,9 @@ class DiffusionPipeline(ConfigMixin):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information. `__init__` method. See example below for more information.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip> <Tip>
...@@ -454,6 +507,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -454,6 +507,7 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False) return_cached_folder = kwargs.pop("return_cached_folder", False)
variant = kwargs.pop("variant", None)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
...@@ -468,28 +522,87 @@ class DiffusionPipeline(ConfigMixin): ...@@ -468,28 +522,87 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
) )
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")] # retrieve all folder_names that contain relevant files
allow_patterns = [os.path.join(k, "*") for k in folder_names] folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
if not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.10.0"):
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=None,
)
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
comp_model_filenames = [
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
]
if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.jsons with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [ allow_patterns += [
WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME, SCHEDULER_CONFIG_NAME,
CONFIG_NAME, CONFIG_NAME,
ONNX_WEIGHTS_NAME,
cls.config_name, cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
] ]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
if from_flax: if from_flax:
ignore_patterns = ["*.bin", "*.safetensors"] ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
allow_patterns += [ elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
FLAX_WEIGHTS_NAME, ignore_patterns = ["*.bin", "*.msgpack"]
]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
if custom_pipeline is not None: else:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warn(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
# allow everything since it has to be downloaded anyways
ignore_patterns = allow_patterns = None
if cls != DiffusionPipeline: if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__ requested_pipeline_class = cls.__name__
...@@ -501,21 +614,6 @@ class DiffusionPipeline(ConfigMixin): ...@@ -501,21 +614,6 @@ class DiffusionPipeline(ConfigMixin):
user_agent = http_user_agent(user_agent) user_agent = http_user_agent(user_agent)
if is_safetensors_available() and not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
else:
# as a safety mechanism we also don't download safetensors if
# not all safetensors files are there
ignore_patterns.append("*.safetensors")
else:
ignore_patterns.append("*.safetensors")
# download all allow_patterns # download all allow_patterns
cached_folder = snapshot_download( cached_folder = snapshot_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
...@@ -533,6 +631,16 @@ class DiffusionPipeline(ConfigMixin): ...@@ -533,6 +631,16 @@ class DiffusionPipeline(ConfigMixin):
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder) config_dict = cls.load_config(cached_folder)
# retrieve which subfolders should load variants
model_variants = {}
if variant is not None:
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
if variant_exists:
model_variants[folder] = variant
# 2. Load the pipeline class, if using custom module then load it from the hub # 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if custom_pipeline is not None: if custom_pipeline is not None:
...@@ -717,10 +825,11 @@ class DiffusionPipeline(ConfigMixin): ...@@ -717,10 +825,11 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["sess_options"] = sess_options loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
is_transformers_model = ( is_transformers_model = (
is_transformers_available() is_transformers_available()
and issubclass(class_obj, PreTrainedModel) and issubclass(class_obj, PreTrainedModel)
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") and transformers_version >= version.parse("4.20.0")
) )
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
...@@ -728,9 +837,23 @@ class DiffusionPipeline(ConfigMixin): ...@@ -728,9 +837,23 @@ class DiffusionPipeline(ConfigMixin):
# This makes sure that the weights won't be initialized which significantly speeds up loading. # This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model: if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map loading_kwargs["device_map"] = device_map
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax: if from_flax:
loading_kwargs["from_flax"] = True loading_kwargs["from_flax"] = True
# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model): if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
......
...@@ -20,6 +20,7 @@ from packaging import version ...@@ -20,6 +20,7 @@ from packaging import version
from .. import __version__ from .. import __version__
from .constants import ( from .constants import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
......
...@@ -30,3 +30,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" ...@@ -30,3 +30,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
...@@ -16,10 +16,12 @@ ...@@ -16,10 +16,12 @@
import inspect import inspect
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from requests.exceptions import HTTPError
from diffusers.models import ModelMixin, UNet2DConditionModel from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
...@@ -34,6 +36,30 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -34,6 +36,30 @@ class ModelUtilsTest(unittest.TestCase):
# make sure that error message states what keys are missing # make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception) assert "conv_out.bias" in str(error_context.exception)
def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
orig_model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
)
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.request", return_value=response_mock):
# Download this model to make sure it's in the cache.
model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
)
for p1, p2 in zip(orig_model.parameters(), model.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
class ModelTesterMixin: class ModelTesterMixin:
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
...@@ -66,6 +92,44 @@ class ModelTesterMixin: ...@@ -66,6 +92,44 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item() max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_variant(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with self.assertRaises(OSError) as error_context:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
new_model.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_dtype(self): def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
...@@ -21,6 +21,7 @@ import shutil ...@@ -21,6 +21,7 @@ import shutil
import sys import sys
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
import numpy as np import numpy as np
import PIL import PIL
...@@ -28,6 +29,7 @@ import safetensors.torch ...@@ -28,6 +29,7 @@ import safetensors.torch
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from PIL import Image from PIL import Image
from requests.exceptions import HTTPError
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import ( from diffusers import (
...@@ -166,6 +168,155 @@ class DownloadTests(unittest.TestCase): ...@@ -166,6 +168,155 @@ class DownloadTests(unittest.TestCase):
assert np.max(np.abs(out - out_2)) < 1e-3 assert np.max(np.abs(out - out_2)) < 1e-3
def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
orig_pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.request", return_value=response_mock):
# Download this model to make sure it's in the cache.
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, local_files_only=True
)
comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
for m1, m2 in zip(orig_comps.values(), comps.values()):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
def test_download_from_variant_folder(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
assert not any(f.endswith(other_format) for f in files)
# no variants
assert not any(len(f.split(".")) == 3 for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_all(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# unet, vae, text_encoder, safety_checker
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 4
# all checkpoints should have variant ending
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_partly(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_broken_variant(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname,
variant=variant,
)
assert "Error no file name" in str(error_context.exception)
# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
)
assert pipe is not None
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
diffusers.utils.import_utils._safetensors_available = True
class CustomPipelineTests(unittest.TestCase): class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self): def test_load_custom_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment