"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "24895a1f494062d73028e31880c8848c6a674750"
Unverified commit fbdf26ba, authored by Linoy Tsaban and committed by GitHub

[dreambooth lora sdxl] add sdxl micro conditioning (#6795)



* add micro conditioning

* remove redundant lines

* style

* fix missing 's'

* fix missing shape bug due to missing RGB if statement

* remove redundant if, change arg order

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 13001ee3
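
For context on what the change does: SDXL's UNet takes six extra micro-conditioning scalars per image, the original (height, width) of the source image, the (top, left) corner of the crop that was applied, and the target (height, width). Previously the script fed constants for all six (the training resolution, plus the `--crops_coords_top_left_h/w` flags); with this change the first four are recorded per image during preprocessing. A minimal sketch of the conditioning vector itself, with illustrative values that are not taken from the script:

```python
import torch

# The six SDXL micro-conditioning scalars for a single sample.
original_size = (768, 1024)   # (height, width) of the source image before resizing
crop_top_left = (12, 140)     # (top, left) corner of the crop actually applied
target_size = (1024, 1024)    # the training resolution

add_time_ids = torch.tensor([list(original_size + crop_top_left + target_size)])
print(add_time_ids.shape)  # torch.Size([1, 6])
```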
@@ -19,6 +19,7 @@ import itertools
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@ from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -669,6 +664,41 @@ class DreamBoothDataset(Dataset):
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
+
+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
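
The preprocessing above records, for each instance image, its original (height, width) and the (top, left) corner of the crop that was actually applied. The key torchvision detail is that `RandomCrop.get_params` samples a crop box without applying it, so the offset can be saved and the identical crop applied with the functional `crop`. A standalone sketch with a dummy image (sizes are arbitrary):

```python
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop

image = Image.new("RGB", (1280, 1152))  # dummy 1280x1152 image, illustrative only
top, left, h, w = transforms.RandomCrop.get_params(image, (1024, 1024))
cropped = crop(image, top, left, h, w)  # applies exactly the sampled box
crop_top_left = (top, left)             # saved for the time-id embedding
```

Note that the flip uses `RandomHorizontalFlip(p=1.0)` gated by an explicit `random.random() < 0.5`, so the script rather than the transform decides the coin flip.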
@@ -698,12 +728,12 @@ class DreamBoothDataset(Dataset):
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]

-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left

         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
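
Since the transforms now run once in `__init__`, `__getitem__` only indexes pre-computed values, which also means each image keeps the same crop and flip for the whole training run instead of being re-augmented every epoch. A returned example looks roughly like this (illustrative shapes, hypothetical prompt):

```python
import torch

example = {
    "instance_images": torch.randn(3, 1024, 1024),  # stand-in for the normalized pixel tensor
    "original_size": (896, 1280),                   # (height, width) before resizing
    "crop_top_left": (0, 128),                      # (top, left) of the applied crop
    "instance_prompt": "a photo of sks dog",        # hypothetical custom prompt
}
```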
@@ -730,6 +760,8 @@ class DreamBoothDataset(Dataset):
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
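
The two new fields travel through `collate_fn` as plain Python lists of tuples; they are only turned into a tensor later, in the training loop, by `compute_time_ids`. For a batch of two, the collated batch would look roughly like this (shapes and values illustrative):

```python
import torch

batch = {
    "pixel_values": torch.randn(2, 3, 1024, 1024),            # stacked, contiguous float tensor
    "prompts": ["a photo of sks dog", "a photo of sks dog"],  # hypothetical prompts
    "original_sizes": [(896, 1280), (1024, 1024)],            # one (height, width) per sample
    "crop_top_lefts": [(0, 128), (37, 0)],                    # one (top, left) per sample
}
```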
@@ -1233,11 +1270,9 @@ def main(args):
     # pooled text embeddings
     # time ids
-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
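
`compute_time_ids` now takes the per-sample records instead of reading constants from `args`, returning one row of six values per call; the training loop then concatenates one row per batch element. A standalone sketch of the same logic (device/dtype transfer omitted, and `target_size` passed as a parameter rather than read from `args.resolution`):

```python
import torch

def compute_time_ids(original_size, crops_coords_top_left, target_size=(1024, 1024)):
    # (orig_h, orig_w) + (top, left) + (target_h, target_w) -> one row of six scalars
    return torch.tensor([list(original_size + crops_coords_top_left + target_size)])

add_time_ids = torch.cat(
    [compute_time_ids((896, 1280), (0, 128)), compute_time_ids((1024, 1024), (37, 0))]
)
print(add_time_ids.shape)  # torch.Size([2, 6])
```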
@@ -1254,9 +1289,6 @@ def main(args):
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@ def main(args):
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@ def main(args):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)

     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz

                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1459,7 +1493,7 @@ def main(args):
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
...
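
Because `add_time_ids` is now built with one row per batch element, the old `.repeat(elems_to_repeat_time_ids, 1)` expansion (and the `elems_to_repeat_time_ids` bookkeeping) is dropped and the tensor is passed to the UNet unchanged. Roughly, the shapes reaching the UNet's added conditions are (random stand-ins; 1280 is the pooled-embedding width of SDXL's second text encoder):

```python
import torch

bsz = 2
unet_added_conditions = {
    "time_ids": torch.randn(bsz, 6),        # one micro-conditioning row per sample
    "text_embeds": torch.randn(bsz, 1280),  # stand-in for pooled prompt embeddings
}
# In the script these reach the UNet as:
# model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input,
#                   added_cond_kwargs=unet_added_conditions, return_dict=False)[0]
```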