Unverified Commit 0bee4d33 authored by dg845's avatar dg845 Committed by GitHub
Browse files

LCM Distill Scripts Fix Bug when Initializing Target U-Net (#6848)



* Initialize target_unet from unet rather than teacher_unet so that we correctly add time_embedding.cond_proj if necessary.

* Use UNet2DConditionModel.from_config to initialize target_unet from unet's config.

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 42f25d60
......@@ -945,7 +945,7 @@ def main(args):
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet = UNet2DConditionModel.from_config(unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)
......
......@@ -1004,7 +1004,7 @@ def main(args):
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet = UNet2DConditionModel.from_config(unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment