Unverified Commit 4f14b363 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

Full Dreambooth IF stage II upscaling (#3561)

* update dreambooth lora to work with IF stage II

* Update dreambooth script for IF stage II upscaler
parent f751b884
...@@ -52,6 +52,7 @@ from diffusers import ( ...@@ -52,6 +52,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
if is_wandb_available(): if is_wandb_available():
...@@ -114,16 +115,17 @@ def log_validation( ...@@ -114,16 +115,17 @@ def log_validation(
pipeline_args = {} pipeline_args = {}
if text_encoder is not None:
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
if vae is not None: if vae is not None:
pipeline_args["vae"] = vae pipeline_args["vae"] = vae
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
# create pipeline (note: unet and vae are loaded again in float32) # create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet), unet=accelerator.unwrap_model(unet),
revision=args.revision, revision=args.revision,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
...@@ -156,10 +158,16 @@ def log_validation( ...@@ -156,10 +158,16 @@ def log_validation(
# run inference # run inference
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = [] images = []
for _ in range(args.num_validation_images): if args.validation_images is None:
with torch.autocast("cuda"): for _ in range(args.num_validation_images):
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] with torch.autocast("cuda"):
images.append(image) image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
images.append(image)
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
...@@ -525,6 +533,19 @@ def parse_args(input_args=None): ...@@ -525,6 +533,19 @@ def parse_args(input_args=None):
parser.add_argument( parser.add_argument(
"--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
) )
parser.add_argument(
"--validation_images",
required=False,
default=None,
nargs="+",
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
)
parser.add_argument(
"--class_labels_conditioning",
required=False,
default=None,
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -1169,7 +1190,7 @@ def main(args): ...@@ -1169,7 +1190,7 @@ def main(args):
) )
else: else:
noise = torch.randn_like(model_input) noise = torch.randn_like(model_input)
bsz = model_input.shape[0] bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint( timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
...@@ -1191,8 +1212,24 @@ def main(args): ...@@ -1191,8 +1212,24 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
) )
if unet.config.in_channels > channels:
needed_additional_channels = unet.config.in_channels - channels
additional_latents = randn_tensor(
(bsz, needed_additional_channels, height, width),
device=noisy_model_input.device,
dtype=noisy_model_input.dtype,
)
noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None
# Predict the noise residual # Predict the noise residual
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
if model_pred.shape[1] == 6: if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1) model_pred, _ = torch.chunk(model_pred, 2, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment