Commit ae060fc4 in renzhc/diffusers_dcu (unverified)
Authored Jan 05, 2024 by Sayak Paul; committed via GitHub on Jan 05, 2024
Parent: 9d945b2b

[feat] introduce `unload_lora()`. (#6451)

* introduce unload_lora.
* fix-copies
Showing 4 changed files with 38 additions and 10 deletions (+38 -10)
src/diffusers/models/unet_2d_condition.py                                      +11 -0
src/diffusers/models/unet_3d_condition.py                                      +13 -1
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py   +11 -0
tests/lora/test_lora_layers_old_backend.py                                      +3 -9
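For context, a minimal usage sketch of the new method. This is not part of the commit; the checkpoint id and LoRA path below are illustrative placeholders, and `load_attn_procs()` is the loader used by the old (non-peft) LoRA backend.

# Illustrative only: attach LoRA attention processors to a UNet, then revert them
# with the new unload_lora(). The ids/paths below are placeholders.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.load_attn_procs("path/to/lora")  # old (non-peft) LoRA backend loader
# ... run inference with the LoRA layers active ...
unet.unload_lora()  # detaches every LoRA layer from modules exposing set_lora_layer()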
src/diffusers/models/unet_2d_condition.py

@@ -829,6 +829,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
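A note on the loop above: under the old (non-peft) LoRA backend, linear and conv projections are `LoRACompatibleLinear` / `LoRACompatibleConv` modules that expose `set_lora_layer()`. A minimal sketch of the mechanism the method relies on (not part of the diff; the class names are from `diffusers.models.lora` as of this commit):

# Sketch of what unload_lora() does per sub-module, assuming the old LoRA backend.
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

proj = LoRACompatibleLinear(32, 32)
proj.set_lora_layer(LoRALinearLayer(in_features=32, out_features=32, rank=4))  # attach a LoRA layer
proj.set_lora_layer(None)  # detach; this is the call unload_lora() makes for each module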
src/diffusers/models/unet_3d_condition.py

@@ -22,7 +22,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,

@@ -503,6 +503,18 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

@@ -1034,6 +1034,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
tests/lora/test_lora_layers_old_backend.py

@@ -151,9 +151,7 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
     unet_lora_sd = unet_lora_state_dict(unet)

     # Unload LoRA.
-    for module in unet.modules():
-        if hasattr(module, "set_lora_layer"):
-            module.set_lora_layer(None)
+    unet.unload_lora()

     return unet_lora_parameters, unet_lora_sd

@@ -230,9 +228,7 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
     unet_lora_sd = unet_lora_state_dict(unet)

     # Unload LoRA.
-    for module in unet.modules():
-        if hasattr(module, "set_lora_layer"):
-            module.set_lora_layer(None)
+    unet.unload_lora()

     return unet_lora_sd

@@ -1545,9 +1541,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

         # Unload LoRA.
-        for module in model.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
+        model.unload_lora()

         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
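As the message in the diffs indicates, `unload_lora()` is deprecated from the start in favour of installing `peft` and calling `disable_adapters()`. A minimal sketch of observing that notice (not from the commit): the tiny UNet config mirrors the kind these tests construct, and it assumes `diffusers.utils.deprecate` emits a `FutureWarning` and a diffusers version from around this commit (earlier than 0.28.0).

import warnings

from diffusers import UNet2DConditionModel

# Small illustrative config, in the spirit of the test models used in this file.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    unet.unload_lora()  # no LoRA attached, so this only emits the deprecation notice

assert any(issubclass(w.category, FutureWarning) for w in caught)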