Unverified Commit ed616bd8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[LoRA] Add LoRA training script (#1884)



* [Lora] first upload

* add first lora version

* upload

* more

* first training

* up

* correct

* improve

* finish loaders and inference

* up

* up

* fix more

* up

* finish more

* finish more

* up

* up

* change year

* revert year change

* Change lines

* Add cloneofsimo as co-author.
Co-authored-by: default avatarSimo Ryu <cloneofsimo@gmail.com>

* finish

* fix docs

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* upload

* finish
Co-authored-by: default avatarSimo Ryu <cloneofsimo@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent ac3fc649
...@@ -90,6 +90,8 @@ ...@@ -90,6 +90,8 @@
title: Configuration title: Configuration
- local: api/outputs - local: api/outputs
title: Outputs title: Outputs
- local: api/loaders
title: Loaders
title: Main Classes title: Main Classes
- sections: - sections:
- local: api/pipelines/overview - local: api/pipelines/overview
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Loaders
There are many weights to train adapter neural networks for diffusion models, such as
- [Textual Inversion](./training/text_inversion.mdx)
- [LoRA](https://github.com/cloneofsimo/lora)
- [Hypernetworks](https://arxiv.org/abs/1609.09106)
Such adapter neural networks often only consist of a fraction of the number of weights compared
to the pretrained model and as such are very portable. The Diffusers library offers an easy-to-use
API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py).
**Note**: This module is still highly experimental and prone to future changes.
## LoaderMixins
### UNet2DConditionLoadersMixin
[[autodoc]] loaders.UNet2DConditionLoadersMixin
<!--Copyright 2020 The HuggingFace Team. All rights reserved. <!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at the License. You may obtain a copy of the License at
......
...@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a ...@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a
## Running locally with PyTorch ## Running locally with PyTorch
### Installing the dependencies ### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies: Before running the scripts, make sure to install the library's training dependencies:
...@@ -235,6 +236,102 @@ image.save("dog-bucket.png") ...@@ -235,6 +236,102 @@ image.save("dog-bucket.png")
You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it. You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
## Training with Low-Rank Adaptation of Large Language Models (LoRA)
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted torwards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
### Training
Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
First, you need to set-up your dreambooth training example as is explained in the [installation section](#Installing-the-dependencies).
Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so
we need to be logged in and add the `--push_to_hub` flag.
```bash
huggingface-cli login
```
Now we can start training!
```bash
accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=100 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=50 \
--seed="0" \
--push_to_hub
```
**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
use *1e-4* instead of the usual *2e-6*.___**
The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size which is orders of magnitudes smaller than the original model.**
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
### Inference
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
load the original pipeline:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```
Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
```python
pipe.load_attn_procs("patrickvonplaten/lora")
```
Finally, we can run the model in inference.
```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
## Training with Flax/JAX ## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse import argparse
import hashlib import hashlib
import itertools import itertools
......
This diff is collapsed.
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse import argparse
import copy import copy
import logging import logging
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse import argparse
import logging import logging
import math import math
......
# Copyright 2020 The HuggingFace Team. All rights reserved. # Copyright 2022 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union
import torch
from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
num = int(key.split(".")[1]) # 0 is always "layers"
new_key = key.replace(f"layers.{num}", module.mapping[num])
new_state_dict[new_key] = value
return new_state_dict
def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = key.split(".processor")[0] + ".processor"
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
self._register_state_dict_hook(map_to)
self._register_load_state_dict_pre_hook(map_from, with_module=True)
class UNet2DConditionLoadersMixin:
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
defined in
[cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
and be a `torch.nn.Module` class.
<Tip warning={true}>
This function is experimental and might change in the future
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
# fill attn processors
attn_processors = {}
is_lora = all("lora" in k for k in state_dict.keys())
if is_lora:
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
else:
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
# set layers
self.set_attn_processor(attn_processors)
def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = LORA_WEIGHT_NAME,
save_function: Callable = None,
):
r"""
Save an attention procesor to a directory, so that it can be re-loaded using the
`[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
model_to_save = AttnProcsLayers(self.attn_processors)
# Save the model
state_dict = model_to_save.state_dict()
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
...@@ -246,6 +246,68 @@ class CrossAttnProcessor: ...@@ -246,6 +246,68 @@ class CrossAttnProcessor:
return hidden_states return hidden_states
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
self.scale = 1.0
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
return up_hidden_states.to(orig_dtype)
class LoRACrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CrossAttnAddedKVProcessor: class CrossAttnAddedKVProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states residual = hidden_states
...@@ -312,6 +374,41 @@ class XFormersCrossAttnProcessor: ...@@ -312,6 +374,41 @@ class XFormersCrossAttnProcessor:
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
# dropout # dropout
hidden_states = attn.to_out[1](hidden_states) hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LoRAXFormersCrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim, rank=4):
super().__init__()
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states return hidden_states
......
...@@ -438,7 +438,7 @@ class ModelMixin(torch.nn.Module): ...@@ -438,7 +438,7 @@ class ModelMixin(torch.nn.Module):
model_file = None model_file = None
if from_flax: if from_flax:
model_file = cls._get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME, weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
...@@ -474,7 +474,7 @@ class ModelMixin(torch.nn.Module): ...@@ -474,7 +474,7 @@ class ModelMixin(torch.nn.Module):
else: else:
if is_safetensors_available(): if is_safetensors_available():
try: try:
model_file = cls._get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME, weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
...@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module): ...@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
except: except:
pass pass
if model_file is None: if model_file is None:
model_file = cls._get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME, weights_name=WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
...@@ -599,92 +599,6 @@ class ModelMixin(torch.nn.Module): ...@@ -599,92 +599,6 @@ class ModelMixin(torch.nn.Module):
return model return model
@classmethod
def _get_model_file(
cls,
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
return model_file
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
@classmethod @classmethod
def _load_pretrained_model( def _load_pretrained_model(
cls, cls,
...@@ -848,7 +762,9 @@ def _get_model_file( ...@@ -848,7 +762,9 @@ def _get_model_file(
revision, revision,
): ):
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name) model_file = os.path.join(pretrained_model_name_or_path, weights_name)
......
...@@ -19,6 +19,7 @@ import torch.nn as nn ...@@ -19,6 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps from .embeddings import TimestepEmbedding, Timesteps
...@@ -49,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput): ...@@ -49,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor sample: torch.FloatTensor
class UNet2DConditionModel(ModelMixin, ConfigMixin): class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r""" r"""
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output. and returns sample shaped output.
...@@ -266,17 +267,59 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -266,17 +267,59 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attn_processor(self, processor: AttnProcessor): @property
def attn_processors(self) -> Dict[str, AttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively # set recursively
def fn_recursive_attn_processor(module: torch.nn.Module): processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
if hasattr(module, "set_processor"): if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor) module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for child in module.children(): for sub_name, child in module.named_children():
fn_recursive_attn_processor(child) fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for module in self.children(): for name, module in self.named_children():
fn_recursive_attn_processor(module) fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size):
r""" r"""
......
...@@ -353,17 +353,59 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -353,17 +353,59 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attn_processor(self, processor: AttnProcessor): @property
def attn_processors(self) -> Dict[str, AttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively # set recursively
def fn_recursive_attn_processor(module: torch.nn.Module): processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
if hasattr(module, "set_processor"): if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor) module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for child in module.children(): for sub_name, child in module.named_children():
fn_recursive_attn_processor(child) fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for module in self.children(): for name, module in self.named_children():
fn_recursive_attn_processor(module) fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size):
r""" r"""
......
...@@ -58,6 +58,7 @@ from .import_utils import ( ...@@ -58,6 +58,7 @@ from .import_utils import (
is_transformers_available, is_transformers_available,
is_transformers_version, is_transformers_version,
is_unidecode_available, is_unidecode_available,
is_wandb_available,
is_xformers_available, is_xformers_available,
requires_backends, requires_backends,
) )
......
...@@ -217,6 +217,13 @@ try: ...@@ -217,6 +217,13 @@ try:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_k_diffusion_available = False _k_diffusion_available = False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
logger.debug(f"Successfully imported k-diffusion version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
...@@ -274,6 +281,10 @@ def is_k_diffusion_available(): ...@@ -274,6 +281,10 @@ def is_k_diffusion_available():
return _k_diffusion_available return _k_diffusion_available
def is_wandb_available():
return _wandb_available
# docstyle-ignore # docstyle-ignore
FLAX_IMPORT_ERROR = """ FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
...@@ -328,6 +339,12 @@ K_DIFFUSION_IMPORT_ERROR = """ ...@@ -328,6 +339,12 @@ K_DIFFUSION_IMPORT_ERROR = """
install k-diffusion` install k-diffusion`
""" """
# docstyle-ignore
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
[ [
...@@ -340,6 +357,7 @@ BACKENDS_MAPPING = OrderedDict( ...@@ -340,6 +357,7 @@ BACKENDS_MAPPING = OrderedDict(
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
] ]
) )
......
# coding=utf-8 # coding=utf-8
# Copyright 2020 Optuna, Hugging Face # Copyright 2022 Optuna, Hugging Face
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
This diff is collapsed.
This diff is collapsed.
# coding=utf-8 # coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. # Copyright 2022 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment