Unverified Commit 79859134 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] fix FreeU disable method (#5552)

* disable freeu debug

* debug

* potentially fix.

* finish

* manually remove the spaces

* remove tab
parent f912f39b
...@@ -791,7 +791,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -791,7 +791,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys: for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None) setattr(upsample_block, k, None)
def forward( def forward(
......
...@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys: for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None) setattr(upsample_block, k, None)
def forward( def forward(
......
...@@ -1001,7 +1001,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -1001,7 +1001,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys: for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None) setattr(upsample_block, k, None)
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment