Unverified Commit 07eac4d6 authored by dg845's avatar dg845 Committed by GitHub
Browse files

Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893)

* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
parent c079cae3
......@@ -657,6 +657,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
......@@ -1138,7 +1147,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
......
......@@ -677,6 +677,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
......@@ -1233,6 +1242,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
......@@ -1243,7 +1253,7 @@ def main(args):
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
......@@ -1308,7 +1318,7 @@ def main(args):
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment