Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
ae060fc4
Unverified
Commit
ae060fc4
authored
Jan 05, 2024
by
Sayak Paul
Committed by
GitHub
Jan 05, 2024
Browse files
[feat] introduce `unload_lora()`. (#6451)
* introduce unload_lora. * fix-copies
parent
9d945b2b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
10 deletions
+38
-10
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+11
-0
src/diffusers/models/unet_3d_condition.py
src/diffusers/models/unet_3d_condition.py
+13
-1
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
...ines/deprecated/versatile_diffusion/modeling_text_unet.py
+11
-0
tests/lora/test_lora_layers_old_backend.py
tests/lora/test_lora_layers_old_backend.py
+3
-9
No files found.
src/diffusers/models/unet_2d_condition.py
View file @
ae060fc4
...
...
@@ -829,6 +829,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if
self
.
original_attn_processors
is
not
None
:
self
.
set_attn_processor
(
self
.
original_attn_processors
)
def
unload_lora
(
self
):
"""Unloads LoRA weights."""
deprecate
(
"unload_lora"
,
"0.28.0"
,
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()."
,
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
src/diffusers/models/unet_3d_condition.py
View file @
ae060fc4
...
...
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..loaders
import
UNet2DConditionLoadersMixin
from
..utils
import
BaseOutput
,
logging
from
..utils
import
BaseOutput
,
deprecate
,
logging
from
.activations
import
get_activation
from
.attention_processor
import
(
ADDED_KV_ATTENTION_PROCESSORS
,
...
...
@@ -503,6 +503,18 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if
hasattr
(
upsample_block
,
k
)
or
getattr
(
upsample_block
,
k
,
None
)
is
not
None
:
setattr
(
upsample_block
,
k
,
None
)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
def
unload_lora
(
self
):
"""Unloads LoRA weights."""
deprecate
(
"unload_lora"
,
"0.28.0"
,
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()."
,
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
View file @
ae060fc4
...
...
@@ -1034,6 +1034,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if
self
.
original_attn_processors
is
not
None
:
self
.
set_attn_processor
(
self
.
original_attn_processors
)
def
unload_lora
(
self
):
"""Unloads LoRA weights."""
deprecate
(
"unload_lora"
,
"0.28.0"
,
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()."
,
)
for
module
in
self
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
tests/lora/test_lora_layers_old_backend.py
View file @
ae060fc4
...
...
@@ -151,9 +151,7 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
unet_lora_sd
=
unet_lora_state_dict
(
unet
)
# Unload LoRA.
for
module
in
unet
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
unet
.
unload_lora
()
return
unet_lora_parameters
,
unet_lora_sd
...
...
@@ -230,9 +228,7 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
unet_lora_sd
=
unet_lora_state_dict
(
unet
)
# Unload LoRA.
for
module
in
unet
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
unet
.
unload_lora
()
return
unet_lora_sd
...
...
@@ -1545,9 +1541,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
sample
=
model
(
**
inputs_dict
,
cross_attention_kwargs
=
{
"scale"
:
0.0
}).
sample
# Unload LoRA.
for
module
in
model
.
modules
():
if
hasattr
(
module
,
"set_lora_layer"
):
module
.
set_lora_layer
(
None
)
model
.
unload_lora
()
with
torch
.
no_grad
():
new_sample
=
model
(
**
inputs_dict
).
sample
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment