Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
f653ded7
Unverified
Commit
f653ded7
authored
Jan 26, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 26, 2023
Browse files
[LoRA] Make sure LoRA can be disabled after it's run (#2128)
parent
e92d43fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
25 deletions
+71
-25
src/diffusers/models/cross_attention.py
src/diffusers/models/cross_attention.py
+14
-0
tests/models/test_models_unet_2d_condition.py
tests/models/test_models_unet_2d_condition.py
+57
-25
No files found.
src/diffusers/models/cross_attention.py
View file @
f653ded7
...
...
@@ -17,9 +17,13 @@ import torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
..utils
import
logging
from
..utils.import_utils
import
is_xformers_available
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
if
is_xformers_available
():
import
xformers
import
xformers.ops
...
...
@@ -151,6 +155,16 @@ class CrossAttention(nn.Module):
self
.
set_processor
(
processor
)
def set_processor(self, processor: "AttnProcessor"):
    """Install `processor` as this layer's attention processor.

    When the processor currently held is a `torch.nn.Module` and the incoming
    one is not, the existing `"processor"` entry is first removed from
    `self._modules` (and an info message is logged, since the removed module
    may carry trained weights). In every case `self.processor` ends up bound
    to `processor`.

    Args:
        processor: the attention processor to install; may or may not be a
            `torch.nn.Module`.
    """
    module_cls = torch.nn.Module

    # Only a module -> non-module swap needs the registry entry dropped.
    holds_module = hasattr(self, "processor") and isinstance(self.processor, module_cls)
    if holds_module and not isinstance(processor, module_cls):
        logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
        self._modules.pop("processor")

    self.processor = processor
def
forward
(
self
,
hidden_states
,
encoder_hidden_states
=
None
,
attention_mask
=
None
,
**
cross_attention_kwargs
):
...
...
tests/models/test_models_unet_2d_condition.py
View file @
f653ded7
...
...
@@ -20,7 +20,7 @@ import unittest
import
torch
from
diffusers
import
UNet2DConditionModel
from
diffusers.models.cross_attention
import
LoRACrossAttnProcessor
from
diffusers.models.cross_attention
import
CrossAttnProcessor
,
LoRACrossAttnProcessor
from
diffusers.utils
import
(
floats_tensor
,
load_hf_numpy
,
...
...
@@ -40,6 +40,34 @@ logger = logging.get_logger(__name__)
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
def create_lora_layers(model):
    """Create a LoRA attention processor for every attention layer of `model`.

    One `LoRACrossAttnProcessor` is built per key of `model.attn_processors`,
    moved to `model.device`, and the `up` weight of each of its four LoRA
    projections (q, k, v, out) is incremented by 1 to mock trained weights.

    Args:
        model: object exposing `attn_processors` (a mapping keyed by processor
            name), `config.cross_attention_dim`, `config.block_out_channels`
            and `device`.

    Returns:
        dict mapping each processor name to its new `LoRACrossAttnProcessor`.

    Raises:
        ValueError: if a processor name does not start with `mid_block`,
            `up_blocks` or `down_blocks`, so no hidden size can be derived.
    """
    lora_projections = ("to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora")

    lora_attn_procs = {}
    for name in model.attn_processors:
        # Names ending in "attn1.processor" get no cross-attention dimension
        # (presumably the self-attention layers -- confirm against the model).
        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = model.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Parse the whole index segment so block ids >= 10 are handled too.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = model.config.block_out_channels[block_id]
        else:
            # Without this branch `hidden_size` would be unbound, or silently
            # reused from the previous iteration.
            raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")

        lora_proc = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        lora_proc = lora_proc.to(model.device)

        # add 1 to weights to mock trained weights
        with torch.no_grad():
            for projection in lora_projections:
                getattr(lora_proc, projection).up.weight += 1

        lora_attn_procs[name] = lora_proc

    return lora_attn_procs
class
UNet2DConditionModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNet2DConditionModel
...
...
@@ -336,30 +364,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
with
torch
.
no_grad
():
old_sample
=
model
(
**
inputs_dict
).
sample
lora_attn_procs
=
{}
for
name
in
model
.
attn_processors
.
keys
():
cross_attention_dim
=
None
if
name
.
endswith
(
"attn1.processor"
)
else
model
.
config
.
cross_attention_dim
if
name
.
startswith
(
"mid_block"
):
hidden_size
=
model
.
config
.
block_out_channels
[
-
1
]
elif
name
.
startswith
(
"up_blocks"
):
block_id
=
int
(
name
[
len
(
"up_blocks."
)])
hidden_size
=
list
(
reversed
(
model
.
config
.
block_out_channels
))[
block_id
]
elif
name
.
startswith
(
"down_blocks"
):
block_id
=
int
(
name
[
len
(
"down_blocks."
)])
hidden_size
=
model
.
config
.
block_out_channels
[
block_id
]
lora_attn_procs
[
name
]
=
LoRACrossAttnProcessor
(
hidden_size
=
hidden_size
,
cross_attention_dim
=
cross_attention_dim
)
lora_attn_procs
[
name
]
=
lora_attn_procs
[
name
].
to
(
model
.
device
)
# add 1 to weights to mock trained weights
with
torch
.
no_grad
():
lora_attn_procs
[
name
].
to_q_lora
.
up
.
weight
+=
1
lora_attn_procs
[
name
].
to_k_lora
.
up
.
weight
+=
1
lora_attn_procs
[
name
].
to_v_lora
.
up
.
weight
+=
1
lora_attn_procs
[
name
].
to_out_lora
.
up
.
weight
+=
1
lora_attn_procs
=
create_lora_layers
(
model
)
model
.
set_attn_processor
(
lora_attn_procs
)
with
torch
.
no_grad
():
...
...
@@ -380,6 +385,33 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
# LoRA and no LoRA should NOT be the same
assert
(
sample
-
old_sample
).
abs
().
max
()
>
1e-4
def test_lora_on_off(self):
    """Check that LoRA can be switched off after it has been installed.

    Three outputs are compared on the same inputs: the model before any LoRA
    processors are set, the model with LoRA processors run at `scale=0.0`,
    and the model after the processors are replaced by `CrossAttnProcessor`.
    The scaled-to-zero output must match both of the others within 1e-4.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    init_dict["attention_head_dim"] = (8, 16)

    # Fixed seed ahead of model construction.
    torch.manual_seed(0)
    model = self.model_class(**init_dict)
    model.to(torch_device)

    def run_model(**extra_kwargs):
        # Forward pass without gradient tracking; returns the `.sample` output.
        with torch.no_grad():
            return model(**inputs_dict, **extra_kwargs).sample

    def max_abs_diff(left, right):
        return (left - right).abs().max()

    baseline_sample = run_model()

    model.set_attn_processor(create_lora_layers(model))
    lora_disabled_sample = run_model(cross_attention_kwargs={"scale": 0.0})

    model.set_attn_processor(CrossAttnProcessor())
    restored_sample = run_model()

    assert max_abs_diff(lora_disabled_sample, restored_sample) < 1e-4
    assert max_abs_diff(lora_disabled_sample, baseline_sample) < 1e-4
@
slow
class
UNet2DConditionModelIntegrationTests
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment