renzhc / diffusers_dcu · Commits · f653ded7 (unverified)

[LoRA] Make sure LoRA can be disabled after it's run (#2128)

Authored Jan 26, 2023 by Patrick von Platen; committed by GitHub on Jan 26, 2023.
Parent: e92d43fe
Showing 2 changed files with 71 additions and 25 deletions:

  src/diffusers/models/cross_attention.py        +14  -0
  tests/models/test_models_unet_2d_condition.py  +57 -25
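Note: the sketch below is not part of the commit; it is a minimal, self-contained illustration of the usage pattern the change enables, mirroring the new `test_lora_on_off` test further down. The tiny UNet configuration and tensor shapes are assumptions chosen only to keep the snippet runnable.

```python
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor

# Deliberately small UNet (assumed config, similar to the dummy config used by the unit tests).
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
unet.eval()

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 4, 32)

with torch.no_grad():
    baseline = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample

# Attach one LoRA processor per attention layer (same sizing logic as the new
# create_lora_layers helper in the test file below).
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        hidden_size = list(reversed(unet.config.block_out_channels))[int(name[len("up_blocks.")])]
    else:
        hidden_size = unet.config.block_out_channels[int(name[len("down_blocks.")])]
    lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)

# Turning LoRA "off" again, two ways:
with torch.no_grad():
    # 1) keep the LoRA processors but scale their contribution to zero for this call
    off_by_scale = unet(
        sample, timestep, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs={"scale": 0.0}
    ).sample
    # 2) swap the plain processor back in -- the case this commit fixes
    unet.set_attn_processor(CrossAttnProcessor())
    off_by_reset = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample

assert (off_by_scale - baseline).abs().max() < 1e-4
assert (off_by_reset - baseline).abs().max() < 1e-4
```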
src/diffusers/models/cross_attention.py (view file @ f653ded7)

@@ -17,9 +17,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
 from ..utils.import_utils import is_xformers_available
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 if is_xformers_available():
     import xformers
     import xformers.ops

@@ -151,6 +155,16 @@ class CrossAttention(nn.Module):
         self.set_processor(processor)
 
     def set_processor(self, processor: "AttnProcessor"):
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
         self.processor = processor
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
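Why the new `_modules.pop("processor")` is needed (a minimal sketch of the underlying PyTorch behavior, not part of the commit; `Holder` and `PlainProcessor` are made-up stand-ins): once an attribute has been registered as a child module, `nn.Module.__setattr__` refuses to overwrite it with a non-`nn.Module` value, so going from a `LoRACrossAttnProcessor` (an `nn.Module`) back to a plain `CrossAttnProcessor` would raise unless the entry is popped first.

```python
import torch
from torch import nn


class Holder(nn.Module):
    """Stand-in for CrossAttention; just an empty module."""


class PlainProcessor:
    """Stand-in for CrossAttnProcessor: a plain object, not an nn.Module."""


holder = Holder()
holder.processor = nn.Linear(4, 4)  # registered as a child module in holder._modules

try:
    holder.processor = PlainProcessor()  # what plain attribute assignment would do
except TypeError as err:
    print(err)  # roughly: cannot assign ... as child module 'processor' (torch.nn.Module or None expected)

holder._modules.pop("processor")     # what the patched set_processor does first
holder.processor = PlainProcessor()  # now succeeds as a regular attribute
```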
tests/models/test_models_unet_2d_condition.py (view file @ f653ded7)

@@ -20,7 +20,7 @@ import unittest
 import torch
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -40,6 +40,34 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def create_lora_layers(model):
+    lora_attn_procs = {}
+    for name in model.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        # add 1 to weights to mock trained weights
+        with torch.no_grad():
+            lora_attn_procs[name].to_q_lora.up.weight += 1
+            lora_attn_procs[name].to_k_lora.up.weight += 1
+            lora_attn_procs[name].to_v_lora.up.weight += 1
+            lora_attn_procs[name].to_out_lora.up.weight += 1
+
+    return lora_attn_procs
+
+
 class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
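As an aside (illustrative, not part of the diff): `model.attn_processors` is keyed by module paths along the lines of `down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor`, which is why the helper can recover the block's hidden size from the `down_blocks.<i>` / `up_blocks.<i>` / `mid_block` prefix and tell self-attention (`attn1`, no cross-attention dim) from cross-attention (`attn2`) by the suffix. A quick way to see the names it iterates over:

```python
# Illustrative only; assumes `model` is a UNet2DConditionModel as in the tests above.
for name, proc in model.attn_processors.items():
    kind = "self-attn " if name.endswith("attn1.processor") else "cross-attn"
    print(kind, type(proc).__name__, name)
```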
@@ -336,30 +364,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample
 
-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRACrossAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-            )
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)
 
         with torch.no_grad():
@@ -380,6 +385,33 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_on_off(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(CrossAttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):