renzhc / diffusers_dcu · Commit 1f4deb69 (unverified)
Authored Mar 03, 2023 by Nicolas Patry; committed by GitHub on Mar 03, 2023

Adding support for `safetensors` and LoRa. (#2448)

* Adding support for `safetensors` and LoRa.
* Adding metadata.

Parent: f20c8f5a
Changes: 2 changed files, 121 additions and 18 deletions (+121 -18)

src/diffusers/loaders.py                        +61 -18
tests/models/test_models_unet_2d_condition.py   +60  -0
src/diffusers/loaders.py
@@ -19,13 +19,18 @@ import torch
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+
+
+if is_safetensors_available():
+    import safetensors
 
 
 logger = logging.get_logger(__name__)
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
 
 class AttnProcsLayers(torch.nn.Module):
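The guarded import is the usual optional-dependency pattern: `safetensors` is only imported when it is installed, and every code path that touches it is fenced behind `is_safetensors_available()`. A minimal self-contained sketch of the pattern (the `importlib` probe is an illustration, not necessarily how `is_safetensors_available` is implemented in `diffusers.utils`):

    import importlib.util

    def is_safetensors_available() -> bool:
        # Probe for the package without actually importing it.
        return importlib.util.find_spec("safetensors") is not None

    if is_safetensors_available():
        import safetensors  # only executed when the package exists

    # One weight-file name per serialization format.
    LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
    LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"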
@@ -136,14 +141,39 @@ class UNet2DConditionLoadersMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+        weight_name = kwargs.pop("weight_name", None)
 
         user_agent = {
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
 
+        model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
+            if is_safetensors_available():
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME_SAFE
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except EnvironmentError:
+                    if weight_name == LORA_WEIGHT_NAME_SAFE:
+                        weight_name = None
+            if model_file is None:
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
 ...
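In prose: when `safetensors` is installed and the caller did not pin a `weight_name`, the loader first tries `pytorch_lora_weights.safetensors`; if that file cannot be fetched, it falls back to the pickled `.bin` checkpoint. A simplified local-directory sketch of that resolution order (illustrative only; it stands in for `_get_model_file` and the Hub download machinery):

    import os
    from typing import Optional

    import torch

    try:
        import safetensors.torch
        HAS_SAFETENSORS = True
    except ImportError:
        HAS_SAFETENSORS = False

    LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
    LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

    def resolve_lora_state_dict(directory: str, weight_name: Optional[str] = None):
        # Prefer the .safetensors file when the library is available and the
        # caller did not request a specific file name.
        if HAS_SAFETENSORS and weight_name is None:
            candidate = os.path.join(directory, LORA_WEIGHT_NAME_SAFE)
            if os.path.isfile(candidate):
                return safetensors.torch.load_file(candidate, device="cpu")
        # Fall back to the pickled torch checkpoint.
        path = os.path.join(directory, weight_name or LORA_WEIGHT_NAME)
        return torch.load(path, map_location="cpu")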
@@ -195,8 +225,9 @@ class UNet2DConditionLoadersMixin:
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
+        weights_name: str = None,
         save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
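Note the signature change: `weights_name` now defaults to `None` rather than `LORA_WEIGHT_NAME`, a sentinel that lets the method pick the file name only after it knows whether `safe_serialization` was requested. A toy sketch of that defaulting logic (the constants mirror the diff; the function itself is illustrative):

    LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
    LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

    def pick_weights_name(weights_name=None, safe_serialization=False):
        # A None sentinel defers the choice until both arguments are known.
        if weights_name is None:
            weights_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME
        return weights_name

    assert pick_weights_name() == "pytorch_lora_weights.bin"
    assert pick_weights_name(safe_serialization=True) == "pytorch_lora_weights.safetensors"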
@@ -219,6 +250,12 @@ class UNet2DConditionLoadersMixin:
             return
 
         if save_function is None:
-            save_function = torch.save
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
 
         os.makedirs(save_directory, exist_ok=True)
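The closure wraps `safetensors.torch.save_file` with `metadata={"format": "pt"}`, which writes a small key/value header into the file; this is the "Adding metadata" part of the commit message. A quick round trip with the two calls used in this diff, plus `safe_open` to read the header back:

    import torch
    import safetensors.torch
    from safetensors import safe_open

    weights = {"to_q_lora.up.weight": torch.ones(4, 4)}

    # Save with a format tag in the header, as the save_function closure does.
    safetensors.torch.save_file(weights, "demo.safetensors", metadata={"format": "pt"})

    # Load back onto CPU without going through pickle.
    restored = safetensors.torch.load_file("demo.safetensors", device="cpu")
    assert torch.equal(restored["to_q_lora.up.weight"], weights["to_q_lora.up.weight"])

    # The header metadata is exposed by the low-level safe_open API.
    with safe_open("demo.safetensors", framework="pt") as f:
        assert f.metadata() == {"format": "pt"}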
@@ -237,6 +274,12 @@ class UNet2DConditionLoadersMixin:
             if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
+        if weights_name is None:
+            if safe_serialization:
+                weights_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weights_name = LORA_WEIGHT_NAME
+
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weights_name))
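Taken together, the save-side changes mean the on-disk file name follows the serialization format. A hedged usage sketch, assuming `unet` is a `UNet2DConditionModel` with LoRA attention processors already attached (as constructed in the tests below):

    # Default path: pickled torch checkpoint.
    unet.save_attn_procs("lora_out")                           # writes pytorch_lora_weights.bin
    # New path: safetensors with a {"format": "pt"} metadata header.
    unet.save_attn_procs("lora_out", safe_serialization=True)  # writes pytorch_lora_weights.safetensors

    # load_attn_procs now prefers the .safetensors file when the library is
    # installed and silently falls back to the .bin file otherwise.
    unet.load_attn_procs("lora_out")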
tests/models/test_models_unet_2d_condition.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import gc
+import os
 import tempfile
 import unittest
@@ -372,6 +373,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
+
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
 ...
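The new test mirrors test_lora_save_load but asserts the `.safetensors` file name instead of `.bin`. One thing it does not inspect is the metadata header written by the new save path; a small hypothetical helper (not part of the commit) that such a test could call on its temporary directory:

    import os
    from safetensors import safe_open

    def assert_safetensors_metadata(directory: str) -> None:
        # Verify the {"format": "pt"} header written by save_attn_procs(safe_serialization=True).
        path = os.path.join(directory, "pytorch_lora_weights.safetensors")
        with safe_open(path, framework="pt") as f:
            assert f.metadata() == {"format": "pt"}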