Unverified Commit 07eac4d6 authored by dg845, committed by GitHub
Browse files

Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893)

* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
parent c079cae3
...@@ -657,6 +657,15 @@ def parse_args(): ...@@ -657,6 +657,15 @@ def parse_args():
default=0.001, default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.", help="The huber loss parameter. Only used if `--loss_type=huber`.",
) )
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)---- # ----Exponential Moving Average (EMA)----
parser.add_argument( parser.add_argument(
"--ema_decay", "--ema_decay",
...@@ -1138,7 +1147,7 @@ def main(args): ...@@ -1138,7 +1147,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim) w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1) w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype # Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype) w = w.to(device=latents.device, dtype=latents.dtype)
......
...@@ -677,6 +677,15 @@ def parse_args(): ...@@ -677,6 +677,15 @@ def parse_args():
default=0.001, default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.", help="The huber loss parameter. Only used if `--loss_type=huber`.",
) )
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)---- # ----Exponential Moving Average (EMA)----
parser.add_argument( parser.add_argument(
"--ema_decay", "--ema_decay",
...@@ -1233,6 +1242,7 @@ def main(args): ...@@ -1233,6 +1242,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1) w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype) w = w.to(device=latents.device, dtype=latents.dtype)
...@@ -1243,7 +1253,7 @@ def main(args): ...@@ -1243,7 +1253,7 @@ def main(args):
noise_pred = unet( noise_pred = unet(
noisy_model_input, noisy_model_input,
start_timesteps, start_timesteps,
timestep_cond=None, timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(), encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text, added_cond_kwargs=encoded_text,
).sample ).sample
...@@ -1308,7 +1318,7 @@ def main(args): ...@@ -1308,7 +1318,7 @@ def main(args):
target_noise_pred = target_unet( target_noise_pred = target_unet(
x_prev.float(), x_prev.float(),
timesteps, timesteps,
timestep_cond=None, timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(), encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text, added_cond_kwargs=encoded_text,
).sample ).sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment