"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "1530a66469f32dbc790d9fda9d838cd6773d4e32"
Unverified Commit 49db233b authored by dg845, committed by GitHub

Clean Up Comments in LCM(-LoRA) Distillation Scripts. (#6145)



* Clean up comments in LCM(-LoRA) distillation scripts.

* Calculate predicted source noise noise_pred correctly for all prediction_types.

* make style

* apply suggestions from review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 93ea26f2
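For context on the `noise_pred` fix above: with the noise schedule written as z_t = alpha_t * x_0 + sigma_t * eps, each supported `prediction_type` can be converted into estimates of both the original sample x_0 and the source noise eps. The sketch below is a minimal illustrative restatement of the `get_predicted_original_sample` and `get_predicted_noise` helpers added in the diff; the combined helper name `to_x0_and_eps` and the assumption that `alphas` and `sigmas` are already indexed per timestep and broadcast to the sample's shape are illustrative, not part of the scripts.

# Illustrative sketch only: convert a diffusion model output into estimates of the
# original sample x_0 and the source noise eps. Assumes `alphas` and `sigmas` are
# already gathered per timestep and broadcast to `sample`'s shape.
def to_x0_and_eps(model_output, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_x_0 = model_output
        pred_epsilon = (sample - alphas * model_output) / sigmas
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
        pred_epsilon = alphas * model_output + sigmas * sample
    else:
        raise ValueError(f"Unsupported prediction_type: {prediction_type}")
    return pred_x_0, pred_epsilon

Before this change, the scripts used the raw teacher outputs directly as the noise estimate in the CFG combination, which is only correct for `epsilon` prediction; the diff converts both teacher outputs with `get_predicted_noise` before computing `pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)`.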
@@ -156,7 +156,7 @@ class WebdatasetFilter:
             return False


-class Text2ImageDataset:
+class SDText2ImageDataset:
     def __init__(
         self,
         train_shards_path_or_url: Union[str, List[str]],
@@ -359,19 +359,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=

 # Compare LCMScheduler.step, Step 4
-def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
     if prediction_type == "epsilon":
-        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
-        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "sample":
+        pred_x_0 = model_output
     elif prediction_type == "v_prediction":
-        pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+        pred_x_0 = alphas * sample - sigmas * model_output
     else:
-        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )

     return pred_x_0


+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_epsilon = model_output
+    elif prediction_type == "sample":
+        pred_epsilon = (sample - alphas * model_output) / sigmas
+    elif prediction_type == "v_prediction":
+        pred_epsilon = alphas * model_output + sigmas * sample
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
+
+    return pred_epsilon
+
+
 def extract_into_tensor(a, t, x_shape):
     b, *_ = t.shape
     out = a.gather(-1, t)
@@ -835,34 +859,35 @@ def main(args):
         args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
     )

-    # The scheduler calculates the alpha and sigma schedule for us
+    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+    # Initialize the DDIM ODE solver for distillation.
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
        ddim_timesteps=args.num_ddim_timesteps,
    )

-    # 2. Load tokenizers from SD-XL checkpoint.
+    # 2. Load tokenizers from SD 1.X/2.X checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
    )

-    # 3. Load text encoders from SD-1.5 checkpoint.
+    # 3. Load text encoders from SD 1.X/2.X checkpoint.
    # import correct text encoder classes
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
    )

-    # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+    # 4. Load VAE from SD 1.X/2.X checkpoint
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_teacher_model,
        subfolder="vae",
        revision=args.teacher_revision,
    )

-    # 5. Load teacher U-Net from SD-XL checkpoint
+    # 5. Load teacher U-Net from SD 1.X/2.X checkpoint
    teacher_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
    )
@@ -872,7 +897,7 @@ def main(args):
    text_encoder.requires_grad_(False)
    teacher_unet.requires_grad_(False)

-    # 7. Create online (`unet`) student U-Nets.
+    # 7. Create online student U-Net.
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
    )
@@ -935,6 +960,7 @@ def main(args):
    # Also move the alpha and sigma noise schedules to accelerator.device.
    alpha_schedule = alpha_schedule.to(accelerator.device)
    sigma_schedule = sigma_schedule.to(accelerator.device)
+    # Move the ODE solver to accelerator.device.
    solver = solver.to(accelerator.device)

    # 10. Handle saving and loading of checkpoints
@@ -1011,13 +1037,14 @@ def main(args):
        eps=args.adam_epsilon,
    )

+    # 13. Dataset creation and data processing
    # Here, we compute not just the text embeddings but also the additional embeddings
    # needed for the SD XL UNet to operate.
    def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
        prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
        return {"prompt_embeds": prompt_embeds}

-    dataset = Text2ImageDataset(
+    dataset = SDText2ImageDataset(
        train_shards_path_or_url=args.train_shards_path_or_url,
        num_train_examples=args.max_train_samples,
        per_gpu_batch_size=args.train_batch_size,
@@ -1037,6 +1064,7 @@ def main(args):
        tokenizer=tokenizer,
    )

+    # 14. LR Scheduler creation
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
@@ -1051,6 +1079,7 @@ def main(args):
        num_training_steps=args.max_train_steps,
    )

+    # 15. Prepare for training
    # Prepare everything with our `accelerator`.
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
@@ -1072,7 +1101,7 @@ def main(args):
    ).input_ids.to(accelerator.device)
    uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

-    # Train!
+    # 16. Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
@@ -1123,6 +1152,7 @@ def main(args):
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
+                # 1. Load and process the image and text conditioning
                image, text = batch

                image = image.to(accelerator.device, non_blocking=True)
@@ -1140,37 +1170,37 @@ def main(args):
                latents = latents * vae.config.scaling_factor
                latents = latents.to(weight_dtype)

-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

-                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
                start_timesteps = solver.ddim_timesteps[index]
                timesteps = start_timesteps - topk
                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

-                # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+                # 3. Get boundary scalings for start_timesteps and (end) timesteps.
                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

-                # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                noise = torch.randn_like(latents)
                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)

-                # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+                # 5. Sample a random guidance scale w from U[w_min, w_max]
+                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
                w = w.reshape(bsz, 1, 1, 1)
                w = w.to(device=latents.device, dtype=latents.dtype)

-                # 20.4.8. Prepare prompt embeds and unet_added_conditions
+                # 6. Prepare prompt embeds and unet_added_conditions
                prompt_embeds = encoded_text.pop("prompt_embeds")

-                # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+                # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
                noise_pred = unet(
                    noisy_model_input,
                    start_timesteps,
@@ -1179,7 +1209,7 @@ def main(args):
                    added_cond_kwargs=encoded_text,
                ).sample

-                pred_x_0 = predicted_origin(
+                pred_x_0 = get_predicted_original_sample(
                    noise_pred,
                    start_timesteps,
                    noisy_model_input,
@@ -1190,17 +1220,27 @@ def main(args):

                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0

-                # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
-                # noisy_latents with both the conditioning embedding c and unconditional embedding 0
-                # Get teacher model prediction on noisy_latents and conditional embedding
+                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+                # solver timestep.
                with torch.no_grad():
                    with torch.autocast("cuda"):
+                        # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                        cond_teacher_output = teacher_unet(
                            noisy_model_input.to(weight_dtype),
                            start_timesteps,
                            encoder_hidden_states=prompt_embeds.to(weight_dtype),
                        ).sample
-                        cond_pred_x0 = predicted_origin(
+                        cond_pred_x0 = get_predicted_original_sample(
+                            cond_teacher_output,
+                            start_timesteps,
+                            noisy_model_input,
+                            noise_scheduler.config.prediction_type,
+                            alpha_schedule,
+                            sigma_schedule,
+                        )
+                        cond_pred_noise = get_predicted_noise(
                            cond_teacher_output,
                            start_timesteps,
                            noisy_model_input,
@@ -1209,13 +1249,21 @@ def main(args):
                            sigma_schedule,
                        )

-                        # Get teacher model prediction on noisy_latents and unconditional embedding
+                        # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
                        uncond_teacher_output = teacher_unet(
                            noisy_model_input.to(weight_dtype),
                            start_timesteps,
                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
                        ).sample
-                        uncond_pred_x0 = predicted_origin(
+                        uncond_pred_x0 = get_predicted_original_sample(
+                            uncond_teacher_output,
+                            start_timesteps,
+                            noisy_model_input,
+                            noise_scheduler.config.prediction_type,
+                            alpha_schedule,
+                            sigma_schedule,
+                        )
+                        uncond_pred_noise = get_predicted_noise(
                            uncond_teacher_output,
                            start_timesteps,
                            noisy_model_input,
@@ -1224,12 +1272,17 @@ def main(args):
                            sigma_schedule,
                        )

-                        # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
-                        pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+
+                        # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+                        # augmented PF-ODE trajectory (solving backward in time)
+                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)

-                # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+                # Note that we do not use a separate target network for LCM-LoRA distillation.
                with torch.no_grad():
                    with torch.autocast("cuda", dtype=weight_dtype):
                        target_noise_pred = unet(
@@ -1238,7 +1291,7 @@ def main(args):
                            timestep_cond=None,
                            encoder_hidden_states=prompt_embeds.float(),
                        ).sample
-                        pred_x_0 = predicted_origin(
+                        pred_x_0 = get_predicted_original_sample(
                            target_noise_pred,
                            timesteps,
                            x_prev,
@@ -1248,7 +1301,7 @@ def main(args):
                        )
                        target = c_skip * x_prev + c_out * pred_x_0

-                # 20.4.13. Calculate loss
+                # 10. Calculate loss
                if args.loss_type == "l2":
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                elif args.loss_type == "huber":
@@ -1256,7 +1309,7 @@ def main(args):
                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
                    )

-                # 20.4.14. Backpropagate on the online student model (`unet`)
+                # 11. Backpropagate on the online student model (`unet`)
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
...
@@ -162,7 +162,7 @@ class WebdatasetFilter:
             return False


-class Text2ImageDataset:
+class SDXLText2ImageDataset:
     def __init__(
         self,
         train_shards_path_or_url: Union[str, List[str]],
@@ -346,19 +346,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=

 # Compare LCMScheduler.step, Step 4
-def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
     if prediction_type == "epsilon":
-        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
-        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "sample":
+        pred_x_0 = model_output
     elif prediction_type == "v_prediction":
-        pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+        pred_x_0 = alphas * sample - sigmas * model_output
     else:
-        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )

     return pred_x_0


+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_epsilon = model_output
+    elif prediction_type == "sample":
+        pred_epsilon = (sample - alphas * model_output) / sigmas
+    elif prediction_type == "v_prediction":
+        pred_epsilon = alphas * model_output + sigmas * sample
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
+
+    return pred_epsilon
+
+
 def extract_into_tensor(a, t, x_shape):
     b, *_ = t.shape
     out = a.gather(-1, t)
@@ -830,9 +854,10 @@ def main(args):
        args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
    )

-    # The scheduler calculates the alpha and sigma schedule for us
+    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+    # Initialize the DDIM ODE solver for distillation.
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
@@ -886,7 +911,7 @@ def main(args):
    text_encoder_two.requires_grad_(False)
    teacher_unet.requires_grad_(False)

-    # 7. Create online (`unet`) student U-Nets.
+    # 7. Create online student U-Net.
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
    )
@@ -950,6 +975,7 @@ def main(args):
    # Also move the alpha and sigma noise schedules to accelerator.device.
    alpha_schedule = alpha_schedule.to(accelerator.device)
    sigma_schedule = sigma_schedule.to(accelerator.device)
+    # Move the ODE solver to accelerator.device.
    solver = solver.to(accelerator.device)

    # 10. Handle saving and loading of checkpoints
@@ -1057,7 +1083,7 @@ def main(args):
        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

-    dataset = Text2ImageDataset(
+    dataset = SDXLText2ImageDataset(
        train_shards_path_or_url=args.train_shards_path_or_url,
        num_train_examples=args.max_train_samples,
        per_gpu_batch_size=args.train_batch_size,
@@ -1175,6 +1201,7 @@ def main(args):
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
+                # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
                image, text, orig_size, crop_coords = batch

                image = image.to(accelerator.device, non_blocking=True)
@@ -1196,37 +1223,37 @@ def main(args):
                latents = latents * vae.config.scaling_factor
                if args.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

-                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
                start_timesteps = solver.ddim_timesteps[index]
                timesteps = start_timesteps - topk
                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

-                # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+                # 3. Get boundary scalings for start_timesteps and (end) timesteps.
                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

-                # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                noise = torch.randn_like(latents)
                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)

-                # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+                # 5. Sample a random guidance scale w from U[w_min, w_max]
+                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
                w = w.reshape(bsz, 1, 1, 1)
                w = w.to(device=latents.device, dtype=latents.dtype)

-                # 20.4.8. Prepare prompt embeds and unet_added_conditions
+                # 6. Prepare prompt embeds and unet_added_conditions
                prompt_embeds = encoded_text.pop("prompt_embeds")

-                # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+                # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
                noise_pred = unet(
                    noisy_model_input,
                    start_timesteps,
@@ -1235,7 +1262,7 @@ def main(args):
                    added_cond_kwargs=encoded_text,
                ).sample

-                pred_x_0 = predicted_origin(
+                pred_x_0 = get_predicted_original_sample(
                    noise_pred,
                    start_timesteps,
                    noisy_model_input,
@@ -1246,18 +1273,28 @@ def main(args):

                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0

-                # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
-                # noisy_latents with both the conditioning embedding c and unconditional embedding 0
-                # Get teacher model prediction on noisy_latents and conditional embedding
+                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+                # solver timestep.
                with torch.no_grad():
                    with torch.autocast("cuda"):
+                        # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                        cond_teacher_output = teacher_unet(
                            noisy_model_input.to(weight_dtype),
                            start_timesteps,
                            encoder_hidden_states=prompt_embeds.to(weight_dtype),
                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
                        ).sample
-                        cond_pred_x0 = predicted_origin(
+                        cond_pred_x0 = get_predicted_original_sample(
+                            cond_teacher_output,
+                            start_timesteps,
+                            noisy_model_input,
+                            noise_scheduler.config.prediction_type,
+                            alpha_schedule,
+                            sigma_schedule,
+                        )
+                        cond_pred_noise = get_predicted_noise(
                            cond_teacher_output,
                            start_timesteps,
                            noisy_model_input,
@@ -1266,7 +1303,7 @@ def main(args):
                            sigma_schedule,
                        )

-                        # Get teacher model prediction on noisy_latents and unconditional embedding
+                        # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
                        uncond_added_conditions = copy.deepcopy(encoded_text)
                        uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
                        uncond_teacher_output = teacher_unet(
@@ -1275,7 +1312,15 @@ def main(args):
                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
                        ).sample
-                        uncond_pred_x0 = predicted_origin(
+                        uncond_pred_x0 = get_predicted_original_sample(
+                            uncond_teacher_output,
+                            start_timesteps,
+                            noisy_model_input,
+                            noise_scheduler.config.prediction_type,
+                            alpha_schedule,
+                            sigma_schedule,
+                        )
+                        uncond_pred_noise = get_predicted_noise(
                            uncond_teacher_output,
                            start_timesteps,
                            noisy_model_input,
@@ -1284,12 +1329,17 @@ def main(args):
                            sigma_schedule,
                        )

-                        # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
-                        pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+
+                        # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+                        # augmented PF-ODE trajectory (solving backward in time)
+                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)

-                # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+                # Note that we do not use a separate target network for LCM-LoRA distillation.
                with torch.no_grad():
                    with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
                        target_noise_pred = unet(
@@ -1299,7 +1349,7 @@ def main(args):
                            encoder_hidden_states=prompt_embeds.float(),
                            added_cond_kwargs=encoded_text,
                        ).sample
-                        pred_x_0 = predicted_origin(
+                        pred_x_0 = get_predicted_original_sample(
                            target_noise_pred,
                            timesteps,
                            x_prev,
@@ -1309,7 +1359,7 @@ def main(args):
                        )
                        target = c_skip * x_prev + c_out * pred_x_0

-                # 20.4.13. Calculate loss
+                # 10. Calculate loss
                if args.loss_type == "l2":
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                elif args.loss_type == "huber":
@@ -1317,7 +1367,7 @@ def main(args):
                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
                    )

-                # 20.4.14. Backpropagate on the online student model (`unet`)
+                # 11. Backpropagate on the online student model (`unet`)
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
...
...@@ -138,7 +138,7 @@ class WebdatasetFilter: ...@@ -138,7 +138,7 @@ class WebdatasetFilter:
return False return False
class Text2ImageDataset: class SDText2ImageDataset:
def __init__( def __init__(
self, self,
train_shards_path_or_url: Union[str, List[str]], train_shards_path_or_url: Union[str, List[str]],
...@@ -336,19 +336,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= ...@@ -336,19 +336,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4 # Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon": if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction": elif prediction_type == "v_prediction":
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output pred_x_0 = alphas * sample - sigmas * model_output
else: else:
raise ValueError(f"Prediction type {prediction_type} currently not supported.") raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_x_0 return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
def extract_into_tensor(a, t, x_shape): def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape b, *_ = t.shape
out = a.gather(-1, t) out = a.gather(-1, t)
...@@ -823,34 +847,35 @@ def main(args): ...@@ -823,34 +847,35 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
) )
# The scheduler calculates the alpha and sigma schedule for us # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver( solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(), noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps, timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps, ddim_timesteps=args.num_ddim_timesteps,
) )
# 2. Load tokenizers from SD-XL checkpoint. # 2. Load tokenizers from SD 1.X/2.X checkpoint.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
) )
# 3. Load text encoders from SD-1.5 checkpoint. # 3. Load text encoders from SD 1.X/2.X checkpoint.
# import correct text encoder classes # import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained( text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
) )
# 4. Load VAE from SD-XL checkpoint (or more stable VAE) # 4. Load VAE from SD 1.X/2.X checkpoint
vae = AutoencoderKL.from_pretrained( vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model, args.pretrained_teacher_model,
subfolder="vae", subfolder="vae",
revision=args.teacher_revision, revision=args.teacher_revision,
) )
# 5. Load teacher U-Net from SD-XL checkpoint # 5. Load teacher U-Net from SD 1.X/2.X checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained( teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
) )
...@@ -860,7 +885,7 @@ def main(args): ...@@ -860,7 +885,7 @@ def main(args):
text_encoder.requires_grad_(False) text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False) teacher_unet.requires_grad_(False)
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.) # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None: if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
...@@ -869,8 +894,8 @@ def main(args): ...@@ -869,8 +894,8 @@ def main(args):
unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train() unet.train()
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging). # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from unet # Initialize from (online) unet
target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict()) target_unet.load_state_dict(unet.state_dict())
target_unet.train() target_unet.train()
...@@ -887,7 +912,7 @@ def main(args): ...@@ -887,7 +912,7 @@ def main(args):
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
) )
# 10. Handle mixed precision and device placement # 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weigths to half-precision # For mixed precision training we cast all non-trainable weigths to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required. # as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32 weight_dtype = torch.float32
...@@ -914,7 +939,7 @@ def main(args): ...@@ -914,7 +939,7 @@ def main(args):
sigma_schedule = sigma_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device)
solver = solver.to(accelerator.device) solver = solver.to(accelerator.device)
# 11. Handle saving and loading of checkpoints # 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving # `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
...@@ -948,7 +973,7 @@ def main(args): ...@@ -948,7 +973,7 @@ def main(args):
accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)
# 12. Enable optimizations # 11. Enable optimizations
if args.enable_xformers_memory_efficient_attention: if args.enable_xformers_memory_efficient_attention:
if is_xformers_available(): if is_xformers_available():
import xformers import xformers
...@@ -994,13 +1019,14 @@ def main(args): ...@@ -994,13 +1019,14 @@ def main(args):
eps=args.adam_epsilon, eps=args.adam_epsilon,
) )
# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings # Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate. # needed for the SD XL UNet to operate.
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds} return {"prompt_embeds": prompt_embeds}
dataset = Text2ImageDataset( dataset = SDText2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url, train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples, num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size, per_gpu_batch_size=args.train_batch_size,
...@@ -1020,6 +1046,7 @@ def main(args): ...@@ -1020,6 +1046,7 @@ def main(args):
tokenizer=tokenizer, tokenizer=tokenizer,
) )
# 14. LR Scheduler creation
# Scheduler and math around the number of training steps. # Scheduler and math around the number of training steps.
overrode_max_train_steps = False overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
...@@ -1034,6 +1061,7 @@ def main(args): ...@@ -1034,6 +1061,7 @@ def main(args):
num_training_steps=args.max_train_steps, num_training_steps=args.max_train_steps,
) )
# 15. Prepare for training
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
...@@ -1055,7 +1083,7 @@ def main(args): ...@@ -1055,7 +1083,7 @@ def main(args):
).input_ids.to(accelerator.device) ).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
# Train! # 16. Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****") logger.info("***** Running training *****")
...@@ -1106,6 +1134,7 @@ def main(args): ...@@ -1106,6 +1134,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs): for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
# 1. Load and process the image and text conditioning
image, text = batch image, text = batch
image = image.to(accelerator.device, non_blocking=True) image = image.to(accelerator.device, non_blocking=True)
...@@ -1123,29 +1152,28 @@ def main(args): ...@@ -1123,29 +1152,28 @@ def main(args):
latents = latents * vae.config.scaling_factor latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype) latents = latents.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index] start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. # 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1) w = w.reshape(bsz, 1, 1, 1)
...@@ -1153,10 +1181,10 @@ def main(args): ...@@ -1153,10 +1181,10 @@ def main(args):
w = w.to(device=latents.device, dtype=latents.dtype) w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# 20.4.8. Prepare prompt embeds and unet_added_conditions # 6. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds") prompt_embeds = encoded_text.pop("prompt_embeds")
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
noise_pred = unet( noise_pred = unet(
noisy_model_input, noisy_model_input,
start_timesteps, start_timesteps,
...@@ -1165,7 +1193,7 @@ def main(args): ...@@ -1165,7 +1193,7 @@ def main(args):
added_cond_kwargs=encoded_text, added_cond_kwargs=encoded_text,
).sample ).sample
pred_x_0 = predicted_origin( pred_x_0 = get_predicted_original_sample(
noise_pred, noise_pred,
start_timesteps, start_timesteps,
noisy_model_input, noisy_model_input,
...@@ -1176,17 +1204,27 @@ def main(args): ...@@ -1176,17 +1204,27 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# noisy_latents with both the conditioning embedding c and unconditional embedding 0 # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# Get teacher model prediction on noisy_latents and conditional embedding # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad(): with torch.no_grad():
with torch.autocast("cuda"): with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet( cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype), noisy_model_input.to(weight_dtype),
start_timesteps, start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype), encoder_hidden_states=prompt_embeds.to(weight_dtype),
).sample ).sample
cond_pred_x0 = predicted_origin( cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_teacher_output, cond_teacher_output,
start_timesteps, start_timesteps,
noisy_model_input, noisy_model_input,
...@@ -1195,13 +1233,21 @@ def main(args): ...@@ -1195,13 +1233,21 @@ def main(args):
sigma_schedule, sigma_schedule,
) )
# Get teacher model prediction on noisy_latents and unconditional embedding # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
uncond_teacher_output = teacher_unet( uncond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype), noisy_model_input.to(weight_dtype),
start_timesteps, start_timesteps,
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
).sample ).sample
uncond_pred_x0 = predicted_origin( uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_teacher_output, uncond_teacher_output,
start_timesteps, start_timesteps,
noisy_model_input, noisy_model_input,
...@@ -1210,12 +1256,16 @@ def main(args): ...@@ -1210,12 +1256,16 @@ def main(args):
sigma_schedule, sigma_schedule,
) )
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
x_prev = solver.ddim_step(pred_x0, pred_noise, index) x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad(): with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype): with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = target_unet( target_noise_pred = target_unet(
...@@ -1224,7 +1274,7 @@ def main(args): ...@@ -1224,7 +1274,7 @@ def main(args):
timestep_cond=w_embedding, timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(), encoder_hidden_states=prompt_embeds.float(),
).sample ).sample
pred_x_0 = predicted_origin( pred_x_0 = get_predicted_original_sample(
target_noise_pred,
timesteps,
x_prev,
...@@ -1234,7 +1284,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
...@@ -1242,7 +1292,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
...@@ -1252,7 +1302,7 @@ def main(args):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. Make EMA update to target student model parameters (`target_unet`)
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1
...
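For the huber branch above, the loss is the pseudo-Huber form, which behaves like a scaled L2 loss for small errors and like L1 for large ones. A self-contained sketch (hypothetical helper name and illustrative c value; the script reads the constant from args.huber_c):

import torch

def pseudo_huber_loss(pred, target, c=0.001):
    # sqrt(err**2 + c**2) - c  ~=  err**2 / (2 * c) for |err| << c, and ~= |err| for |err| >> c
    return torch.mean(torch.sqrt((pred - target) ** 2 + c**2) - c)

pred, target = torch.randn(4, 4), torch.randn(4, 4)
print(pseudo_huber_loss(pred, target))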
...@@ -144,7 +144,7 @@ class WebdatasetFilter:
return False
class SDXLText2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
...@@ -324,19 +324,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
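As a quick illustration of how these two helpers relate (toy schedule and shapes, not part of the training script), every supported prediction_type yields a consistent pair satisfying sample = alpha_t * pred_x_0 + sigma_t * pred_epsilon:

import torch

toy_alphas = torch.sqrt(torch.linspace(0.9999, 0.0001, 1000))  # stand-in for alpha_schedule
toy_sigmas = torch.sqrt(1 - toy_alphas**2)                      # stand-in for sigma_schedule
t = torch.tensor([10, 500, 900])
sample = torch.randn(3, 4, 8, 8)
model_output = torch.randn(3, 4, 8, 8)
for prediction_type in ["epsilon", "sample", "v_prediction"]:
    x0 = get_predicted_original_sample(model_output, t, sample, prediction_type, toy_alphas, toy_sigmas)
    eps = get_predicted_noise(model_output, t, sample, prediction_type, toy_alphas, toy_sigmas)
    a = extract_into_tensor(toy_alphas, t, sample.shape)
    s = extract_into_tensor(toy_sigmas, t, sample.shape)
    assert torch.allclose(a * x0 + s * eps, sample, atol=1e-4)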
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
...@@ -863,9 +887,10 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
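A small self-contained check of what these schedules mean (illustrative only; assumes diffusers' default DDPMScheduler config and its standard add_noise behaviour):

import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler()
alpha = torch.sqrt(sched.alphas_cumprod)
sigma = torch.sqrt(1 - sched.alphas_cumprod)
x0, eps = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
t = torch.tensor([123])
noisy = sched.add_noise(x0, eps, t)
# Forward diffusion: z_t = alpha_t * x_0 + sigma_t * eps
assert torch.allclose(noisy, alpha[t].view(-1, 1, 1, 1) * x0 + sigma[t].view(-1, 1, 1, 1) * eps, atol=1e-5)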
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
...@@ -919,7 +944,7 @@ def main(args):
text_encoder_two.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
...@@ -928,8 +953,8 @@ def main(args):
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
...@@ -971,6 +996,7 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
...@@ -1084,7 +1110,7 @@ def main(args):
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
dataset = SDXLText2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
...@@ -1202,6 +1228,7 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
image, text, orig_size, crop_coords = batch
image = image.to(accelerator.device, non_blocking=True)
...@@ -1223,38 +1250,39 @@ def main(args):
latents = latents * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
latents = latents.to(weight_dtype)
bsz = latents.shape[0]
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
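As a worked example of the skipping-step sampling above (assumed illustrative values, not the script's defaults):

import numpy as np

n_train, n_ddim = 1000, 50                         # example timestep counts
k = n_train // n_ddim                              # topk = 20
ddim_schedule = np.arange(1, n_ddim + 1) * k - 1   # [19, 39, ..., 999], same set as [T - 1, T - k - 1, ...]
start_t = ddim_schedule[30]                        # t_{n + k} = 619
end_t = max(start_t - k, 0)                        # t_n = 599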
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
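These scalings make the student exact at t = 0. A hypothetical sketch of the standard consistency-model scalings (sigma_data = 0.5 matches the helper's signature earlier in this diff; timestep_scaling = 10.0 is an assumed value), shown only to illustrate why c_skip and c_out enforce the boundary condition f(x, 0) = x:

import torch

def scalings_sketch(timestep, sigma_data=0.5, timestep_scaling=10.0):
    scaled_t = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_t**2 + sigma_data**2)
    c_out = scaled_t / (scaled_t**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

print(scalings_sketch(torch.tensor([0.0, 999.0])))  # c_skip ~ [1, 0], c_out ~ [0, 1]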
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# 6. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
noise_pred = unet(
noisy_model_input,
start_timesteps,
...@@ -1263,7 +1291,7 @@ def main(args):
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
noise_pred,
start_timesteps,
noisy_model_input,
...@@ -1274,18 +1302,28 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
).sample
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_teacher_output,
start_timesteps,
noisy_model_input,
...@@ -1294,7 +1332,7 @@ def main(args):
sigma_schedule,
)
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
uncond_added_conditions = copy.deepcopy(encoded_text)
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
uncond_teacher_output = teacher_unet(
...@@ -1303,7 +1341,15 @@ def main(args):
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
).sample
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
...@@ -1312,12 +1358,16 @@ def main(args):
sigma_schedule,
)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = target_unet(
...@@ -1327,7 +1377,7 @@ def main(args):
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
target_noise_pred,
timesteps,
x_prev,
...@@ -1337,7 +1387,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
...@@ -1345,7 +1395,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
...@@ -1355,7 +1405,7 @@ def main(args):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. Make EMA update to target student model parameters (`target_unet`)
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1
...
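The EMA update in step 12 keeps the target student as a slowly moving average of the online student. A minimal sketch of that update rule (hypothetical helper name and illustrative decay value, shown only to make the target/online relationship explicit):

import torch

@torch.no_grad()
def ema_update_sketch(target_params, online_params, decay=0.95):
    # target <- decay * target + (1 - decay) * online  (Polyak averaging)
    for targ, src in zip(target_params, online_params):
        targ.mul_(decay).add_(src, alpha=1.0 - decay)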