Unverified Commit 4da810b9 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Remove insecure `torch.load` calls (#7393)

update
parent 161c6e14
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open from safetensors import safe_open
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import ( from ..utils import (
_get_model_file, _get_model_file,
is_accelerate_available, is_accelerate_available,
...@@ -182,7 +182,7 @@ class IPAdapterMixin: ...@@ -182,7 +182,7 @@ class IPAdapterMixin:
elif key.startswith("ip_adapter."): elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else: else:
state_dict = torch.load(model_file, map_location="cpu") state_dict = load_state_dict(model_file)
else: else:
state_dict = pretrained_model_name_or_path_or_dict state_dict = pretrained_model_name_or_path_or_dict
......
...@@ -25,7 +25,7 @@ from packaging import version ...@@ -25,7 +25,7 @@ from packaging import version
from torch import nn from torch import nn
from .. import __version__ from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import ( from ..utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
_get_model_file, _get_model_file,
...@@ -281,7 +281,7 @@ class LoraLoaderMixin: ...@@ -281,7 +281,7 @@ class LoraLoaderMixin:
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
) )
state_dict = torch.load(model_file, map_location="cpu") state_dict = load_state_dict(model_file)
else: else:
state_dict = pretrained_model_name_or_path_or_dict state_dict = pretrained_model_name_or_path_or_dict
......
...@@ -18,6 +18,7 @@ import torch ...@@ -18,6 +18,7 @@ import torch
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from torch import nn from torch import nn
from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
...@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) ...@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
) )
state_dict = torch.load(model_file, map_location="cpu") state_dict = load_state_dict(model_file)
else: else:
state_dict = pretrained_model_name_or_path state_dict = pretrained_model_name_or_path
......
...@@ -31,7 +31,7 @@ from ..models.embeddings import ( ...@@ -31,7 +31,7 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection, IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection, MultiIPAdapterImageProjection,
) )
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import ( from ..utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
_get_model_file, _get_model_file,
...@@ -214,7 +214,7 @@ class UNet2DConditionLoadersMixin: ...@@ -214,7 +214,7 @@ class UNet2DConditionLoadersMixin:
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
) )
state_dict = torch.load(model_file, map_location="cpu") state_dict = load_state_dict(model_file)
else: else:
state_dict = pretrained_model_name_or_path_or_dict state_dict = pretrained_model_name_or_path_or_dict
......
...@@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ ...@@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
if file_extension == SAFETENSORS_FILE_EXTENSION: if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu") return safetensors.torch.load_file(checkpoint_file, device="cpu")
else: else:
return torch.load(checkpoint_file, map_location="cpu") weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
except Exception as e: except Exception as e:
try: try:
with open(checkpoint_file) as f: with open(checkpoint_file) as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment