Unverified Commit a757b2db authored by Will Berman, committed by GitHub

if dreambooth lora (#3360)

* update IF stage I pipelines

add fixed variance schedulers and lora loading

* add added-KV LoRA attn processor

* allow loading into alternative lora attn processor

* make vae optional

* throw away predicted variance

* allow loading into added kv lora layer

* allow loading T5

* allow pre-computing text embeddings

* set new variance type in schedulers

* fix copies

* refactor all prompt embedding code

class prompts are now included in pre-encoding code
max tokenizer length is now configurable
embedding attention mask is now configurable

* fix for when variance type is not defined on scheduler

* do not pre-compute validation prompt embeddings when no validation prompt is given

* add example test for IF LoRA DreamBooth

* add check for using --train_text_encoder together with pre-computed text embeddings
parent 571bc1ea
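The IF stage I pipelines now inherit `LoraLoaderMixin`, so DreamBooth LoRA weights produced by `train_dreambooth_lora.py` can be loaded straight into them. A minimal usage sketch, not part of this diff: the model id, device, and prompt are illustrative, and `path/to/lora-output` stands in for the training `--output_dir`.

```python
# Illustrative sketch: load DreamBooth LoRA weights into an IF stage I pipeline
# via the LoraLoaderMixin added in this commit.
import torch
from diffusers import IFPipeline

pipe = IFPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora-output")  # reads pytorch_lora_weights.bin
pipe.to("cuda")

# Stage I produces a 64x64 image; stages II/III upscale it separately.
image = pipe("a photo of sks dog", num_inference_steps=50).images[0]
```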
@@ -292,6 +292,41 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
             self.assertTrue(is_correct_naming)
 
+    def test_dreambooth_lora_if_model(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
     def test_custom_diffusion(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
...
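Outside the test harness, the checkpoint written by the script can be inspected the same way the test above does it. A hedged sketch, where `path/to/lora-output` is a placeholder for the training `--output_dir`:

```python
# Hedged sketch: inspect a trained IF DreamBooth LoRA checkpoint.
import os
import torch

state_dict = torch.load(os.path.join("path/to/lora-output", "pytorch_lora_weights.bin"))

# Every parameter belongs to a LoRA layer, and since the text encoder was not
# trained in this run, every key is namespaced under "unet" so LoraLoaderMixin
# can route it to the right model when loading.
assert all("lora" in k for k in state_dict)
assert all(k.startswith("unet") for k in state_dict)
```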
@@ -21,9 +21,13 @@ import torch
 from huggingface_hub import hf_hub_download
 
 from .models.attention_processor import (
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    SlicedAttnAddedKVProcessor,
 )
 from .utils import (
     DIFFUSERS_CACHE,
@@ -250,10 +254,22 @@ class UNet2DConditionLoadersMixin:
         for key, value_dict in lora_grouped_dict.items():
             rank = value_dict["to_k_lora.down.weight"].shape[0]
-            cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
             hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
-            attn_processors[key] = LoRAAttnProcessor(
+            attn_processor = self
+            for sub_key in key.split("."):
+                attn_processor = getattr(attn_processor, sub_key)
+
+            if isinstance(
+                attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
+            ):
+                cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
+                attn_processor_class = LoRAAttnAddedKVProcessor
+            else:
+                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                attn_processor_class = LoRAAttnProcessor
+
+            attn_processors[key] = attn_processor_class(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
             )
             attn_processors[key].load_state_dict(value_dict)
...
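The dispatch above can be read in isolation. A standalone sketch of the same decision follows; the names mirror the diff, but this helper is illustrative, not the library code itself:

```python
# Hedged, standalone sketch of the dispatch in load_attn_procs: pick the LoRA
# processor class from the kind of attention processor currently set on the module.
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    SlicedAttnAddedKVProcessor,
)


def pick_lora_processor_class(current_processor, value_dict):
    """Return (lora_class, cross_attention_dim) for one grouped LoRA state dict."""
    if isinstance(
        current_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
    ):
        # Added-KV attention (used by IF's UNet): text conditioning enters through
        # add_k_proj/add_v_proj, so those projections carry the cross-attention dim.
        return LoRAAttnAddedKVProcessor, value_dict["add_k_proj_lora.down.weight"].shape[1]
    # Standard attention (e.g. Stable Diffusion): to_k carries it.
    return LoRAAttnProcessor, value_dict["to_k_lora.down.weight"].shape[1]
```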
@@ -669,6 +669,73 @@ class AttnAddedKVProcessor2_0:
         return hidden_states
 
 
+class LoRAAttnAddedKVProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
+            encoder_hidden_states
+        )
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
+            encoder_hidden_states
+        )
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
+            value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
@@ -1022,6 +1089,7 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
 ]
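On the training side, these processors are meant to be installed on the IF UNet before training. A hedged sketch of roughly how that wiring looks; the hidden-size bookkeeping is the usual diffusers pattern and an assumption here, not copied from this commit:

```python
# Hedged sketch: attach LoRAAttnAddedKVProcessor to every attention layer of an
# IF stage I UNet (IF uses added-KV attention throughout).
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor

unet = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-if-pipe", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Derive the hidden size of the block this processor lives in (simplified,
    # assumed bookkeeping; the real training script does the equivalent).
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=unet.config.cross_attention_dim,
        rank=4,
    )

unet.set_attn_processor(lora_attn_procs)
```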
@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -85,7 +86,7 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class IFPipeline(DiffusionPipeline):
+class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
 
@@ -804,6 +805,9 @@ class IFPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
...
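The new guard only keeps the predicted-variance channels when the scheduler actually learns the variance; the DreamBooth LoRA training sets a fixed variance type, so at inference the extra channels are thrown away before `scheduler.step()`. A self-contained sketch of the shape bookkeeping with made-up tensors:

```python
# Illustrative shapes only: IF's UNet predicts 2*C channels (epsilon + variance).
import torch

model_input = torch.randn(2, 3, 64, 64)   # hypothetical stage I input
noise_pred = torch.randn(2, 6, 64, 64)    # epsilon + predicted variance

variance_type = "fixed_small"             # e.g. what the LoRA training script sets
if variance_type not in ["learned", "learned_range"]:
    # Drop the predicted variance so the scheduler only sees the noise prediction.
    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

assert noise_pred.shape[1] == model_input.shape[1]
```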
@@ -9,6 +9,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -109,7 +110,7 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class IFImg2ImgPipeline(DiffusionPipeline):
+class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
 
@@ -929,6 +930,9 @@ class IFImg2ImgPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
...
@@ -9,6 +9,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class IFInpaintingPipeline(DiffusionPipeline):
+class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
 
@@ -1044,6 +1045,9 @@ class IFInpaintingPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 prev_intermediate_images = intermediate_images
 
...