Unverified Commit 04717fd8 authored by Dhruv Nair, committed by GitHub

Add Stable Diffusion 3 (#8483)



* up

* add sd3

* update

* update

* add tests

* fix copies

* fix docs

* update

* add dreambooth lora

* add LoRA

* update

* update

* update

* update

* import fix

* update

* Update src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* import fix 2

* update

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_kl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* update

* update

* update

* fix ckpt id

* fix more ids

* update

* missing doc

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* fix

* update

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

* Update src/diffusers/models/autoencoders/autoencoder_kl.py

* note on gated access.

* requirements

* licensing

---------
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 6fd458e9
docs/source/en/_toctree.yml
@@ -107,7 +107,8 @@
 title: Create a dataset for training
 - local: training/adapt_a_model
 title: Adapt a model to a new task
-- sections:
+- isExpanded: false
+  sections:
 - local: training/unconditional_training
 title: Unconditional image generation
 - local: training/text2image
@@ -125,8 +126,8 @@
 - local: training/instructpix2pix
 title: InstructPix2Pix
 title: Models
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: training/text_inversion
 title: Textual Inversion
 - local: training/dreambooth
@@ -140,7 +141,6 @@
 - local: training/ddpo
 title: Reinforcement learning training with DDPO
 title: Methods
-isExpanded: false
 title: Training
 - sections:
 - local: optimization/fp16
@@ -187,7 +187,8 @@
 title: Evaluating Diffusion Models
 title: Conceptual Guides
 - sections:
-- sections:
+- isExpanded: false
+  sections:
 - local: api/configuration
 title: Configuration
 - local: api/logging
@@ -195,8 +196,8 @@
 - local: api/outputs
 title: Outputs
 title: Main Classes
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: api/loaders/ip_adapter
 title: IP-Adapter
 - local: api/loaders/lora
@@ -210,8 +211,8 @@
 - local: api/loaders/peft
 title: PEFT
 title: Loaders
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: api/models/overview
 title: Overview
 - local: api/models/unet
@@ -246,13 +247,15 @@
 title: HunyuanDiT2DModel
 - local: api/models/transformer_temporal
 title: TransformerTemporalModel
+- local: api/models/sd3_transformer2d
+  title: SD3Transformer2DModel
 - local: api/models/prior_transformer
 title: PriorTransformer
 - local: api/models/controlnet
 title: ControlNetModel
 title: Models
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: api/pipelines/overview
 title: Overview
 - local: api/pipelines/amused
@@ -350,6 +353,8 @@
 title: Safe Stable Diffusion
 - local: api/pipelines/stable_diffusion/stable_diffusion_2
 title: Stable Diffusion 2
+- local: api/pipelines/stable_diffusion/stable_diffusion_3
+  title: Stable Diffusion 3
 - local: api/pipelines/stable_diffusion/stable_diffusion_xl
 title: Stable Diffusion XL
 - local: api/pipelines/stable_diffusion/sdxl_turbo
@@ -382,8 +387,8 @@
 - local: api/pipelines/wuerstchen
 title: Wuerstchen
 title: Pipelines
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: api/schedulers/overview
 title: Overview
 - local: api/schedulers/cm_stochastic_iterative
@@ -414,6 +419,8 @@
 title: EulerAncestralDiscreteScheduler
 - local: api/schedulers/euler
 title: EulerDiscreteScheduler
+- local: api/schedulers/flow_match_euler_discrete
+  title: FlowMatchEulerDiscreteScheduler
 - local: api/schedulers/heun
 title: HeunDiscreteScheduler
 - local: api/schedulers/ipndm
@@ -443,8 +450,8 @@
 - local: api/schedulers/vq_diffusion
 title: VQDiffusionScheduler
 title: Schedulers
-isExpanded: false
-- sections:
+- isExpanded: false
+  sections:
 - local: api/internal_classes_overview
 title: Overview
 - local: api/attnprocessor
@@ -460,5 +467,4 @@
 - local: api/video_processor
 title: Video Processor
 title: Internal classes
-isExpanded: false
 title: API
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SD3 Transformer Model
The Transformer model introduced in [Stable Diffusion 3](https://hf.co/papers/2403.03206). Its novelty lies in the MMDiT transformer block, which uses separate weights for the image and text modalities while allowing information to flow between image and text tokens.
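A minimal sketch of loading the transformer on its own, assuming the gated `stabilityai/stable-diffusion-3-medium-diffusers` repository stores its weights under a `transformer` subfolder:
```python
import torch

from diffusers import SD3Transformer2DModel

# Load only the SD3 MMDiT transformer (gated repo; run `huggingface-cli login` first).
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
```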
## SD3Transformer2DModel
[[autodoc]] SD3Transformer2DModel
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion 3
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
The abstract from the paper is:
*Diffusion models create data from noise by inverting the forward paths of data towards noise and have emerged as a powerful generative modeling technique for high-dimensional, perceptual data such as images and videos. Rectified flow is a recent generative model formulation that connects data and noise in a straight line. Despite its better theoretical properties and conceptual simplicity, it is not yet decisively established as standard practice. In this work, we improve existing noise sampling techniques for training rectified flow models by biasing them towards perceptually relevant scales. Through a large-scale study, we demonstrate the superior performance of this approach compared to established diffusion formulations for high-resolution text-to-image synthesis. Additionally, we present a novel transformer-based architecture for text-to-image generation that uses separate weights for the two modalities and enables a bidirectional flow of information between image and text tokens, improving text comprehension, typography, and human preference ratings. We demonstrate that this architecture follows predictable scaling trends and correlates lower validation loss to improved text-to-image synthesis as measured by various metrics and human evaluations.*
## Usage Example
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._
Use the command below to log in:
```bash
huggingface-cli login
```
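If you prefer to authenticate from Python rather than the shell, a small sketch using `huggingface_hub` (assuming you already have an access token) is:
```python
from huggingface_hub import login

# Prompts for a Hugging Face access token (or pass token="hf_...") and stores it locally.
login()
```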
<Tip>
The SD3 pipeline uses three text encoders to generate an image. Model offloading is necessary in order for it to run on most commodity hardware. Please use the `torch.float16` data type for additional memory savings.
</Tip>
```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe(
prompt="a photo of a cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
height=1024,
width=1024,
guidance_scale=7.0,
).images[0]
image.save("sd3_hello_world.png")
```
## Memory Optimizations for SD3
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
### Running Inference with Model Offloading
The most basic memory optimization available in Diffusers lets you offload the model's components to the CPU during inference to save memory, at the cost of a slight increase in inference latency. Model offloading moves a component onto the GPU only when it needs to be executed, keeping the remaining components on the CPU.
```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
image = pipe(
prompt="a photo of a cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
height=1024,
width=1024,
guidance_scale=7.0,
).images[0]
image.save("sd3_hello_world.png")
```
### Dropping the T5 Text Encoder during Inference
Removing the memory-intensive 4.7B parameter T5-XXL text encoder during inference can significantly decrease the memory requirements for SD3 with only a slight loss in performance.
```python
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
text_encoder_3=None,
tokenizer_3=None,
torch_dtype=torch.float16
)
pipe.to("cuda")
image = pipe(
prompt="a photo of a cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
height=1024,
width=1024,
guidance_scale=7.0,
).images[0]
image.save("sd3_hello_world-no-T5.png")
```
### Using a Quantized Version of the T5 Text Encoder
We can leverage the `bitsandbytes` library to load and quantize the T5-XXL text encoder to 8-bit precision. This allows you to keep using all three text encoders while only slightly impacting performance.
First install the `bitsandbytes` library.
```shell
pip install bitsandbytes
```
Then load the T5-XXL model using the `BitsAndBytesConfig`.
```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
text_encoder = T5EncoderModel.from_pretrained(
model_id,
subfolder="text_encoder_3",
quantization_config=quantization_config,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
model_id,
text_encoder_3=text_encoder,
device_map="balanced",
torch_dtype=torch.float16
)
image = pipe(
prompt="a photo of a cat holding a sign that says hello world",
negative_prompt="",
num_inference_steps=28,
height=1024,
width=1024,
guidance_scale=7.0,
).images[0]
image.save("sd3_hello_world-8bit-T5.png")
```
You can find the end-to-end script [here](https://gist.github.com/sayakpaul/82acb5976509851f2db1a83456e504f1).
## Performance Optimizations for SD3
### Using Torch Compile to Speed Up Inference
Using compiled components in the SD3 pipeline can speed up inference by as much as 4X. The following code snippet demonstrates how to compile the Transformer and VAE components of the SD3 pipeline.
```python
import torch
from diffusers import StableDiffusion3Pipeline
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
torch_dtype=torch.float16
).to("cuda")
pipe.set_progress_bar_config(disable=True)
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
# Warm Up
prompt = "a photo of a cat holding a sign that says hello world",
for _ in range(3):
_ = pipe(prompt=prompt, generator=torch.manual_seed(1))
# Run Inference
image = pipe(prompt=prompt, generator=torch.manual_seed(1)).images[0]
image.save("sd3_hello_world.png")
```
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
## Loading the original checkpoints via `from_single_file`
The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. This method lets you load the original single-file checkpoints released for the model.
## Loading the original checkpoints for the `SD3Transformer2DModel`
```python
from diffusers import SD3Transformer2DModel
model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors")
```
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
```python
import torch

from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
```
<Tip>
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
</Tip>
## StableDiffusion3Pipeline
[[autodoc]] StableDiffusion3Pipeline
- all
- __call__
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FlowMatchEulerDiscreteScheduler
`FlowMatchEulerDiscreteScheduler` is based on the flow-matching sampling introduced in [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
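The SD3 pipeline already ships with this scheduler, but as a minimal sketch (assuming access to the gated SD3 weights) you can re-create it from an existing pipeline's scheduler config:
```python
import torch

from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# Rebuild the flow-matching Euler scheduler from the pipeline's existing scheduler config.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
```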
## FlowMatchEulerDiscreteScheduler
[[autodoc]] FlowMatchEulerDiscreteScheduler
# DreamBooth training example for Stable Diffusion 3 (SD3)
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
The `train_dreambooth_sd3.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). We also provide a LoRA implementation in the `train_dreambooth_lora_sd3.py` script.
> [!NOTE]
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and they install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the `examples/dreambooth` folder and run:
```bash
pip install -r requirements_sd3.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, enabling torch compile mode can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure `peft>=0.6.0` is installed in your environment.
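If it is not already present, you can install or upgrade it with:
```bash
pip install -U "peft>=0.6.0"
```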
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
Since we are already logged in, this will also allow us to push the trained parameters to the Hugging Face Hub.
Now, we can launch training using:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sd3"
accelerate launch train_dreambooth_sd3.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of trainable parameters.
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sd3-lora"
accelerate launch train_dreambooth_lora_sd3.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
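Once training finishes, a minimal inference sketch (assuming the LoRA weights were saved to the `trained-sd3-lora` output directory used above) is:
```python
import torch

from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Load the DreamBooth LoRA weights trained above into the SD3 transformer.
pipe.load_lora_weights("trained-sd3-lora")

image = pipe(
    "A photo of sks dog in a bucket",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sks_dog.png")
```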
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft==0.11.1
src/diffusers/__init__.py
@@ -91,6 +91,7 @@ else:
 "MultiAdapter",
 "PixArtTransformer2DModel",
 "PriorTransformer",
+"SD3Transformer2DModel",
 "StableCascadeUNet",
 "T2IAdapter",
 "T5FilmDecoder",
@@ -156,6 +157,7 @@ else:
 "EDMEulerScheduler",
 "EulerAncestralDiscreteScheduler",
 "EulerDiscreteScheduler",
+"FlowMatchEulerDiscreteScheduler",
 "HeunDiscreteScheduler",
 "IPNDMScheduler",
 "KarrasVeScheduler",
@@ -276,6 +278,8 @@ else:
 "StableCascadeCombinedPipeline",
 "StableCascadeDecoderPipeline",
 "StableCascadePriorPipeline",
+"StableDiffusion3Img2ImgPipeline",
+"StableDiffusion3Pipeline",
 "StableDiffusionAdapterPipeline",
 "StableDiffusionAttendAndExcitePipeline",
 "StableDiffusionControlNetImg2ImgPipeline",
@@ -497,6 +501,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 MultiAdapter,
 PixArtTransformer2DModel,
 PriorTransformer,
+SD3Transformer2DModel,
 T2IAdapter,
 T5FilmDecoder,
 Transformer2DModel,
@@ -559,6 +564,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 EDMEulerScheduler,
 EulerAncestralDiscreteScheduler,
 EulerDiscreteScheduler,
+FlowMatchEulerDiscreteScheduler,
 HeunDiscreteScheduler,
 IPNDMScheduler,
 KarrasVeScheduler,
@@ -660,6 +666,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 StableCascadeCombinedPipeline,
 StableCascadeDecoderPipeline,
 StableCascadePriorPipeline,
+StableDiffusion3Img2ImgPipeline,
+StableDiffusion3Pipeline,
 StableDiffusionAdapterPipeline,
 StableDiffusionAttendAndExcitePipeline,
 StableDiffusionControlNetImg2ImgPipeline,

src/diffusers/image_processor.py
@@ -86,6 +86,7 @@ class VaeImageProcessor(ConfigMixin):
 self,
 do_resize: bool = True,
 vae_scale_factor: int = 8,
+vae_latent_channels: int = 4,
 resample: str = "lanczos",
 do_normalize: bool = True,
 do_binarize: bool = False,

src/diffusers/loaders/__init__.py
@@ -59,7 +59,7 @@ if is_torch_available():
 _import_structure["utils"] = ["AttnProcsLayers"]
 if is_transformers_available():
 _import_structure["single_file"] = ["FromSingleFileMixin"]
-_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
+_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
 _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
 _import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -74,7 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 if is_transformers_available():
 from .ip_adapter import IPAdapterMixin
-from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
+from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
 from .single_file import FromSingleFileMixin
 from .textual_inversion import TextualInversionLoaderMixin

src/diffusers/loaders/lora.py
@@ -1337,3 +1337,393 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
 if getattr(self.text_encoder_2, "peft_config", None) is not None:
 del self.text_encoder_2.peft_config
 self.text_encoder_2._hf_peft_config_loaded = None
class SD3LoraLoaderMixin:
r"""
Load LoRA layers into [`SD3Transformer2DModel`].
"""
transformer_name = TRANSFORMER_NAME
num_fused_loras = 0
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded
into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
)
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
@classmethod
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
@staticmethod
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
save_function: Callable,
safe_serialization: bool,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE
else:
weight_name = LORA_WEIGHT_NAME
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
recurse_remove_peft_layers(transformer)
if hasattr(transformer, "peft_config"):
del transformer.peft_config
@classmethod
# Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
src/diffusers/loaders/single_file.py
@@ -234,7 +234,7 @@ def _download_diffusers_model_config_from_hub(
 local_files_only=None,
 token=None,
 ):
-allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
+allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
 cached_model_path = snapshot_download(
 pretrained_model_name_or_path,
 cache_dir=cache_dir,

src/diffusers/loaders/single_file_model.py
@@ -24,6 +24,7 @@ from .single_file_utils import (
 convert_controlnet_checkpoint,
 convert_ldm_unet_checkpoint,
 convert_ldm_vae_checkpoint,
+convert_sd3_transformer_checkpoint_to_diffusers,
 convert_stable_cascade_unet_single_file_to_diffusers,
 create_controlnet_diffusers_config_from_ldm,
 create_unet_diffusers_config_from_ldm,
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
 "checkpoint_mapping_fn": convert_controlnet_checkpoint,
 "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
 },
+"SD3Transformer2DModel": {
+    "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+    "default_subfolder": "transformer",
+},
 }
src/diffusers/loaders/single_file_utils.py
@@ -21,6 +21,7 @@ from io import BytesIO
 from urllib.parse import urlparse
 import requests
+import torch
 import yaml
 from ..models.modeling_utils import load_state_dict
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
 "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
 "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
 "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+"clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
 "open_clip": "cond_stage_model.model.token_embedding.weight",
 "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
 "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
 "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
 "stable_cascade_stage_c": "clip_txt_mapper.weight",
+"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
 }
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
 "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
 "subfolder": "prior_lite",
 },
+"sd3": {
+    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+},
 }
 # Use to configure model sample size when original config is provided
@@ -242,7 +249,11 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
+LDM_CLIP_PREFIX_TO_REMOVE = [
+    "cond_stage_model.transformer.",
+    "conditioner.embedders.0.transformer.",
+    "text_encoders.clip_l.transformer.",
+]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -366,6 +377,13 @@ def is_clip_sdxl_model(checkpoint):
     return False
+def is_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+        return True
+    return False
 def is_open_clip_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
         return True
@@ -380,8 +398,12 @@ def is_open_clip_sdxl_model(checkpoint):
     return False
+def is_open_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+    return False
 def is_open_clip_sdxl_refiner_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True
     return False
@@ -391,9 +413,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
 is_clip_in_checkpoint = any(
 [
 is_clip_model(checkpoint),
+is_clip_sd3_model(checkpoint),
 is_open_clip_model(checkpoint),
 is_open_clip_sdxl_model(checkpoint),
 is_open_clip_sdxl_refiner_model(checkpoint),
+is_open_clip_sd3_model(checkpoint),
 ]
 )
 if (
@@ -456,6 +480,9 @@ def infer_diffusers_model_type(checkpoint):
 ):
 model_type = "stable_cascade_stage_b"
+elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    model_type = "sd3"
 else:
 model_type = "v1"
@@ -1364,6 +1391,10 @@ def create_diffusers_clip_model_from_ldm(
 prefix = "conditioner.embedders.0.model."
 diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+elif is_open_clip_sd3_model(checkpoint):
+    prefix = "text_encoders.clip_g.transformer."
+    diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 else:
 raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@@ -1559,3 +1590,212 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
 )
 return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
# in diffusers it is split into scale, shift. Here we swap the linear projection weights so the diffusers implementation can be used.
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
caption_projection_dim = 1536
# Positional and patch embeddings.
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Timestep embeddings.
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
# Context projections.
converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
# Pooled context projection.
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
# Transformer blocks 🎸.
for i in range(num_layers):
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn.proj.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
)
else:
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
dim=caption_projection_dim,
)
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
dim=caption_projection_dim,
)
# ffs.
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.mlp.fc2.bias"
)
if not (i == num_layers - 1):
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
f"joint_blocks.{i}.context_block.mlp.fc2.bias"
)
# Final blocks.
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
)
return converted_state_dict
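For orientation, `swap_scale_shift` is defined earlier in this file and is not shown in this hunk; a minimal sketch of what it is assumed to do (reordering a fused adaLN parameter from the original (shift, scale) layout into the (scale, shift) layout diffusers expects) could look like the following. The name `_swap_scale_shift_sketch` is hypothetical and used only for illustration.

import torch

def _swap_scale_shift_sketch(weight: torch.Tensor, dim: int) -> torch.Tensor:
    # Split the fused parameter into its (shift, scale) halves and re-concatenate
    # them as (scale, shift). `dim` is accepted only to mirror the call sites above.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)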
def is_t5_in_single_file(checkpoint):
    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
        return True
    return False


def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}

    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                diffusers_key = key.replace(prefix, "")
                text_model_dict[diffusers_key] = checkpoint.get(key)

    return text_model_dict


def create_diffusers_t5_model_from_checkpoint(
    cls,
    checkpoint,
    subfolder="",
    config=None,
    torch_dtype=None,
    local_files_only=None,
):
    if config:
        config = {"pretrained_model_name_or_path": config}
    else:
        config = fetch_diffusers_config(checkpoint)

    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        model = cls(model_config)

    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

    if is_accelerate_available():
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
        if model._keys_to_ignore_on_load_unexpected is not None:
            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )
    else:
        model.load_state_dict(diffusers_format_checkpoint)
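As an illustration of the renaming performed above (the key below is an assumed example of an SD3 single-file checkpoint entry, not taken from this diff):

# Assumed example of the prefix stripping:
#   "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight"
#   -> "block.0.layer.0.SelfAttention.q.weight"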
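A hedged usage sketch of the helper above, assuming it ultimately returns the populated model (the tail of the function is not shown in this hunk). `T5EncoderModel` and the dtype are illustrative choices, not part of the commit.

from transformers import T5EncoderModel

# `checkpoint` is the flat state dict loaded from an SD3 single-file checkpoint.
if is_t5_in_single_file(checkpoint):
    text_encoder_3 = create_diffusers_t5_model_from_checkpoint(
        T5EncoderModel, checkpoint, torch_dtype=torch.float16
    )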
@@ -43,6 +43,7 @@ if is_torch_available():
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
     _import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -82,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanDiT2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        SD3Transformer2DModel,
         T5FilmDecoder,
         Transformer2DModel,
         TransformerTemporalModel,
......
@@ -20,7 +20,7 @@ from torch import nn
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
-from .attention_processor import Attention
+from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
@@ -85,6 +85,130 @@ class GatedSelfAttentionDense(nn.Module):
         return x
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )
        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim // num_attention_heads,
            heads=num_attention_heads,
            out_dim=attention_head_dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
......
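A hedged smoke-test sketch for the JointTransformerBlock added above. Shapes and sizes are illustrative assumptions, and PyTorch >= 2.0 is required for `scaled_dot_product_attention`.

import torch

block = JointTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=1536)

hidden_states = torch.randn(2, 1024, 1536)          # latent (image) tokens
encoder_hidden_states = torch.randn(2, 154, 1536)   # text tokens
temb = torch.randn(2, 1536)                         # combined timestep / pooled-text embedding

encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
# Both outputs keep their input shapes: (2, 154, 1536) and (2, 1024, 1536).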
@@ -81,9 +81,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        shift_factor: Optional[float] = None,
         latents_mean: Optional[Tuple[float]] = None,
         latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
+        use_quant_conv: bool = True,
+        use_post_quant_conv: bool = True,
     ):
         super().__init__()
@@ -110,8 +113,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             act_fn=act_fn,
         )
 
-        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
 
         self.use_slicing = False
         self.use_tiling = False
@@ -245,13 +248,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
-                tuple.
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
         Returns:
                 The latent representations of the encoded images. If `return_dict` is True, a
-                [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
-                returned.
+                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
             return self.tiled_encode(x, return_dict=return_dict)
@@ -262,7 +263,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         else:
             h = self.encoder(x)
 
-        moments = self.quant_conv(h)
+        if self.quant_conv is not None:
+            moments = self.quant_conv(h)
+        else:
+            moments = h
+
         posterior = DiagonalGaussianDistribution(moments)
 
         if not return_dict:
@@ -274,7 +279,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)
 
-        z = self.post_quant_conv(z)
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
         dec = self.decoder(z)
 
         if not return_dict:
@@ -283,7 +290,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         """
         Decode a batch of images.
@@ -302,7 +311,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z, return_dict=False)[0]
+            decoded = self._decode(z).sample
 
         if not return_dict:
             return (decoded,)
@@ -333,13 +342,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
-                If return_dict is True, a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned,
-                otherwise a plain `tuple` is returned.
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+                `tuple` is returned.
         """
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
......
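A hedged sketch of what the new AutoencoderKL flags enable, for example an SD3-style VAE with no 1x1 quant convolutions and a latent shift. Every config value below is an illustrative assumption, not read from this diff.

from diffusers import AutoencoderKL

vae = AutoencoderKL(
    latent_channels=16,          # assumed wider latent space
    use_quant_conv=False,        # skip the 1x1 conv before sampling
    use_post_quant_conv=False,   # skip the 1x1 conv before decoding
    scaling_factor=1.5305,       # assumed scaling used when the pipeline normalizes latents
    shift_factor=0.0609,         # assumed shift used when the pipeline normalizes latents
)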