Unverified Commit 9df3d843 authored by dg845's avatar dg845 Committed by GitHub
Browse files

Fix LCM distillation bug when creating the guidance scale embeddings using multiple GPUs. (#6279)



Fix bug when creating the guidance embeddings using multiple GPUs.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent c7514490
......@@ -889,6 +889,7 @@ def main(args):
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
......@@ -1175,7 +1176,7 @@ def main(args):
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
......
......@@ -948,6 +948,7 @@ def main(args):
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
......@@ -1273,7 +1274,7 @@ def main(args):
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment