Unverified Commit f751b884, authored by Will Berman, committed by GitHub

update dreambooth lora to work with IF stage II (#3560)

parent abb89da4
Changes to the DreamBooth LoRA training script:

```diff
@@ -60,6 +60,7 @@ from diffusers.models.attention_processor import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
```
```diff
@@ -425,6 +426,19 @@ def parse_args(input_args=None):
         required=False,
         help="Whether to use attention mask for the text encoder",
     )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
```
```diff
@@ -1121,7 +1135,7 @@ def main(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
-                bsz = model_input.shape[0]
+                bsz, channels, height, width = model_input.shape
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1143,8 +1157,24 @@ def main(args):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )

+                if unet.config.in_channels > channels:
+                    needed_additional_channels = unet.config.in_channels - channels
+                    additional_latents = randn_tensor(
+                        (bsz, needed_additional_channels, height, width),
+                        device=noisy_model_input.device,
+                        dtype=noisy_model_input.dtype,
+                    )
+                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+
+                if args.class_labels_conditioning == "timesteps":
+                    class_labels = timesteps
+                else:
+                    class_labels = None
+
                 # Predict the noise residual
-                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+                model_pred = unet(
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+                ).sample

                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
```
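The channel-padding block exists because the IF stage II UNet expects the noisy upscaled image concatenated with extra conditioning channels, so `unet.config.in_channels` exceeds the channel count of the training images; random latents fill the gap, and `class_labels=timesteps` reuses the upscaler's noise-level conditioning input. A minimal standalone sketch with assumed shapes (a 6-channel UNet input against 3-channel images), where plain `torch.randn` stands in for diffusers' `randn_tensor` helper:

```python
import torch

# Assumed shapes: stage II UNet with 6 input channels, 3-channel batch.
unet_in_channels = 6                    # stands in for unet.config.in_channels
bsz, channels, height, width = 2, 3, 64, 64
noisy_model_input = torch.randn(bsz, channels, height, width)

if unet_in_channels > channels:
    needed = unet_in_channels - channels
    # random latents pad the conditioning channels the batch cannot provide
    additional_latents = torch.randn(
        bsz, needed, height, width,
        device=noisy_model_input.device, dtype=noisy_model_input.dtype,
    )
    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)

assert noisy_model_input.shape == (bsz, unet_in_channels, height, width)
```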
...@@ -1248,9 +1278,18 @@ def main(args): ...@@ -1248,9 +1278,18 @@ def main(args):
} }
else: else:
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
images = [
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) if args.validation_images is None:
] images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
else:
images = []
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
......
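In effect, the new branch lets validation exercise the image-to-image path: each file listed in `--validation_images` is opened with PIL and handed to the pipeline as the low-resolution input. A hedged sketch of the same flow outside the script, using a real stage II checkpoint but a placeholder image path:

```python
import torch
from PIL import Image
from diffusers import IFSuperResolutionPipeline

# Sketch only: stage II upscales a low-resolution image under a prompt.
# Inputs are expected at the stage I output resolution (64x64 for the
# standard IF checkpoints); the image path is a placeholder.
pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
).to("cuda")

low_res = Image.open("./dog/img0.png").convert("RGB").resize((64, 64))
image = pipe(
    prompt="a photo of sks dog",
    image=low_res,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("validation_0.png")
```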
The remaining changes apply the same two fixes to each of the three IF stage II pipelines: inherit from `LoraLoaderMixin` so LoRA weights can be loaded, and drop the predicted variance when the scheduler does not use a learned variance. First, `IFImg2ImgSuperResolutionPipeline`:

```diff
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
+class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
@@ -1047,6 +1048,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
```
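The guard above (repeated identically in the two pipelines below) handles a scheduler mismatch: the IF UNet predicts noise and variance stacked along the channel dimension, but a scheduler configured without learned variance expects only the noise channels in `step`. A tiny standalone illustration with assumed shapes:

```python
import torch

# Assumed shapes: 3-channel images, UNet output with noise (3) + variance (3).
intermediate_images = torch.randn(1, 3, 256, 256)
noise_pred = torch.randn(1, 6, 256, 256)

variance_type = "fixed_small"  # stands in for self.scheduler.config.variance_type
if variance_type not in ["learned", "learned_range"]:
    # keep only as many channels as the images have; discard the variance
    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

assert noise_pred.shape == intermediate_images.shape
```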
The same changes in `IFInpaintingSuperResolutionPipeline`:

```diff
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -114,7 +115,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
+class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
@@ -1154,6 +1155,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 prev_intermediate_images = intermediate_images
```
And in `IFSuperResolutionPipeline`:

```diff
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -70,7 +71,7 @@ EXAMPLE_DOC_STRING = """
 """


-class IFSuperResolutionPipeline(DiffusionPipeline):
+class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
@@ -903,6 +904,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
```
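Adding `LoraLoaderMixin` to the three stage II pipelines is what makes the trained weights usable at inference time: the mixin provides `load_lora_weights`. A hedged usage sketch, with the output directory a placeholder for wherever the training script saved its LoRA weights:

```python
import torch
from diffusers import IFSuperResolutionPipeline

pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
# placeholder path: the training run's --output_dir
pipe.load_lora_weights("./if-stage2-lora")
pipe.to("cuda")
```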