Unverified Commit 17cece07 authored by dg845's avatar dg845 Committed by GitHub
Browse files

Fix bug in LCM Distillation Scripts when args.unet_time_cond_proj_dim is used (#6523)

* Fix bug where unet's time_cond_proj_dim is not set correctly if using args.unet_time_cond_proj_dim.

* make style
parent a551ddf9
...@@ -921,10 +921,12 @@ def main(args): ...@@ -921,10 +921,12 @@ def main(args):
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None: time_cond_proj_dim = (
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim teacher_unet.config.time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim if teacher_unet.config.time_cond_proj_dim is not None
unet = UNet2DConditionModel(**teacher_unet.config) else args.unet_time_cond_proj_dim
)
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
# load teacher_unet weights into unet # load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train() unet.train()
......
...@@ -980,10 +980,12 @@ def main(args): ...@@ -980,10 +980,12 @@ def main(args):
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None: time_cond_proj_dim = (
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim teacher_unet.config.time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim if teacher_unet.config.time_cond_proj_dim is not None
unet = UNet2DConditionModel(**teacher_unet.config) else args.unet_time_cond_proj_dim
)
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
# load teacher_unet weights into unet # load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train() unet.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment