Unverified Commit e2d9a9be authored by Haiwen Huang's avatar Haiwen Huang Committed by GitHub
Browse files

fix the in-place modification in unet condition when using controlnet (#2586)



* fix the in-place modification in unet condition when using controlnet, which will cause backprop errors when training

* add clone to mid block

* fix-copies

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarWilliam Berman <WLBberman@gmail.com>
parent f9cfb5ab
...@@ -598,7 +598,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -598,7 +598,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for down_block_res_sample, down_block_additional_residual in zip( for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals down_block_res_samples, down_block_additional_residuals
): ):
down_block_res_sample += down_block_additional_residual down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples += (down_block_res_sample,) new_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples down_block_res_samples = new_down_block_res_samples
...@@ -614,7 +614,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -614,7 +614,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
) )
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
sample += mid_block_additional_residual sample = sample + mid_block_additional_residual
# 5. up # 5. up
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
......
...@@ -688,7 +688,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -688,7 +688,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
for down_block_res_sample, down_block_additional_residual in zip( for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals down_block_res_samples, down_block_additional_residuals
): ):
down_block_res_sample += down_block_additional_residual down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples += (down_block_res_sample,) new_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples down_block_res_samples = new_down_block_res_samples
...@@ -704,7 +704,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -704,7 +704,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
) )
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
sample += mid_block_additional_residual sample = sample + mid_block_additional_residual
# 5. up # 5. up
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment