Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4da810b9
Unverified
Commit
4da810b9
authored
Mar 20, 2024
by
Dhruv Nair
Committed by
GitHub
Mar 19, 2024
Browse files
Remove insecure `torch.load` calls (#7393)
update
parent
161c6e14
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
14 additions
and
8 deletions
+14
-8
src/diffusers/loaders/ip_adapter.py
src/diffusers/loaders/ip_adapter.py
+2
-2
src/diffusers/loaders/lora.py
src/diffusers/loaders/lora.py
+2
-2
src/diffusers/loaders/textual_inversion.py
src/diffusers/loaders/textual_inversion.py
+2
-1
src/diffusers/loaders/unet.py
src/diffusers/loaders/unet.py
+2
-2
src/diffusers/models/modeling_utils.py
src/diffusers/models/modeling_utils.py
+6
-1
No files found.
src/diffusers/loaders/ip_adapter.py
View file @
4da810b9
...
...
@@ -19,7 +19,7 @@ import torch
from
huggingface_hub.utils
import
validate_hf_hub_args
from
safetensors
import
safe_open
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
,
load_state_dict
from
..utils
import
(
_get_model_file
,
is_accelerate_available
,
...
...
@@ -182,7 +182,7 @@ class IPAdapterMixin:
elif
key
.
startswith
(
"ip_adapter."
):
state_dict
[
"ip_adapter"
][
key
.
replace
(
"ip_adapter."
,
""
)]
=
f
.
get_tensor
(
key
)
else
:
state_dict
=
torch
.
load
(
model_file
,
map_location
=
"cpu"
)
state_dict
=
load_state_dict
(
model_file
)
else
:
state_dict
=
pretrained_model_name_or_path_or_dict
...
...
src/diffusers/loaders/lora.py
View file @
4da810b9
...
...
@@ -25,7 +25,7 @@ from packaging import version
from
torch
import
nn
from
..
import
__version__
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
,
load_state_dict
from
..utils
import
(
USE_PEFT_BACKEND
,
_get_model_file
,
...
...
@@ -281,7 +281,7 @@ class LoraLoaderMixin:
subfolder
=
subfolder
,
user_agent
=
user_agent
,
)
state_dict
=
torch
.
load
(
model_file
,
map_location
=
"cpu"
)
state_dict
=
load_state_dict
(
model_file
)
else
:
state_dict
=
pretrained_model_name_or_path_or_dict
...
...
src/diffusers/loaders/textual_inversion.py
View file @
4da810b9
...
...
@@ -18,6 +18,7 @@ import torch
from
huggingface_hub.utils
import
validate_hf_hub_args
from
torch
import
nn
from
..models.modeling_utils
import
load_state_dict
from
..utils
import
_get_model_file
,
is_accelerate_available
,
is_transformers_available
,
logging
...
...
@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
subfolder
=
subfolder
,
user_agent
=
user_agent
,
)
state_dict
=
torch
.
load
(
model_file
,
map_location
=
"cpu"
)
state_dict
=
load_state_dict
(
model_file
)
else
:
state_dict
=
pretrained_model_name_or_path
...
...
src/diffusers/loaders/unet.py
View file @
4da810b9
...
...
@@ -31,7 +31,7 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection
,
MultiIPAdapterImageProjection
,
)
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
,
load_model_dict_into_meta
from
..models.modeling_utils
import
_LOW_CPU_MEM_USAGE_DEFAULT
,
load_model_dict_into_meta
,
load_state_dict
from
..utils
import
(
USE_PEFT_BACKEND
,
_get_model_file
,
...
...
@@ -214,7 +214,7 @@ class UNet2DConditionLoadersMixin:
subfolder
=
subfolder
,
user_agent
=
user_agent
,
)
state_dict
=
torch
.
load
(
model_file
,
map_location
=
"cpu"
)
state_dict
=
load_state_dict
(
model_file
)
else
:
state_dict
=
pretrained_model_name_or_path_or_dict
...
...
src/diffusers/models/modeling_utils.py
View file @
4da810b9
...
...
@@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
if
file_extension
==
SAFETENSORS_FILE_EXTENSION
:
return
safetensors
.
torch
.
load_file
(
checkpoint_file
,
device
=
"cpu"
)
else
:
return
torch
.
load
(
checkpoint_file
,
map_location
=
"cpu"
)
weights_only_kwarg
=
{
"weights_only"
:
True
}
if
is_torch_version
(
">="
,
"1.13"
)
else
{}
return
torch
.
load
(
checkpoint_file
,
map_location
=
"cpu"
,
**
weights_only_kwarg
,
)
except
Exception
as
e
:
try
:
with
open
(
checkpoint_file
)
as
f
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment