"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "8abbe22ffbc306b7be0e2e09ba1ce167430f2c7f"
Unverified Commit d9227cf7 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding `use_safetensors` argument to give more control to users (#2123)



* Adding `use_safetensors` argument to give more control to users

about which weights they use.

* Doc style.

* Rebased (not functional).

* Rebased and functional with tests.

* Style.

* Apply suggestions from code review

* Style.

* Addressing comments.

* Update tests/test_pipelines.py
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>

* Black ???

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>
parent e8282327
...@@ -142,6 +142,17 @@ class UNet2DConditionLoadersMixin: ...@@ -142,6 +142,17 @@ class UNet2DConditionLoadersMixin:
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
user_agent = { user_agent = {
"file_type": "attn_procs_weights", "file_type": "attn_procs_weights",
...@@ -151,7 +162,7 @@ class UNet2DConditionLoadersMixin: ...@@ -151,7 +162,7 @@ class UNet2DConditionLoadersMixin:
model_file = None model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict): if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights # Let's first try to load .safetensors weights
if (is_safetensors_available() and weight_name is None) or ( if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors") weight_name is not None and weight_name.endswith(".safetensors")
): ):
try: try:
...@@ -169,10 +180,11 @@ class UNet2DConditionLoadersMixin: ...@@ -169,10 +180,11 @@ class UNet2DConditionLoadersMixin:
user_agent=user_agent, user_agent=user_agent,
) )
state_dict = safetensors.torch.load_file(model_file, device="cpu") state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError: except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights # try loading non-safetensors weights
pass pass
if model_file is None: if model_file is None:
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
......
...@@ -392,6 +392,10 @@ class ModelMixin(torch.nn.Module): ...@@ -392,6 +392,10 @@ class ModelMixin(torch.nn.Module):
variant (`str`, *optional*): variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`. ignored when using `from_flax`.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
`None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
*and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
<Tip> <Tip>
...@@ -423,6 +427,17 @@ class ModelMixin(torch.nn.Module): ...@@ -423,6 +427,17 @@ class ModelMixin(torch.nn.Module):
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
if low_cpu_mem_usage and not is_accelerate_available(): if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False low_cpu_mem_usage = False
...@@ -509,7 +524,7 @@ class ModelMixin(torch.nn.Module): ...@@ -509,7 +524,7 @@ class ModelMixin(torch.nn.Module):
model = load_flax_checkpoint_in_pytorch_model(model, model_file) model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else: else:
if is_safetensors_available(): if use_safetensors:
try: try:
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
...@@ -525,7 +540,9 @@ class ModelMixin(torch.nn.Module): ...@@ -525,7 +540,9 @@ class ModelMixin(torch.nn.Module):
user_agent=user_agent, user_agent=user_agent,
commit_hash=commit_hash, commit_hash=commit_hash,
) )
except: # noqa: E722 except IOError as e:
if not allow_pickle:
raise e
pass pass
if model_file is None: if model_file is None:
model_file = _get_model_file( model_file = _get_model_file(
......
...@@ -694,6 +694,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -694,6 +694,10 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error. setting this argument to `True` will raise an error.
use_safetensors (`bool`, *optional* ):
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines specific pipeline class. The overwritten components are then directly passed to the pipelines
...@@ -752,6 +756,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -752,6 +756,7 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
...@@ -1068,6 +1073,17 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1068,6 +1073,17 @@ class DiffusionPipeline(ConfigMixin):
from_flax = kwargs.pop("from_flax", False) from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None) custom_pipeline = kwargs.pop("custom_pipeline", None)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
pipeline_is_cached = False pipeline_is_cached = False
allow_patterns = None allow_patterns = None
...@@ -1123,9 +1139,17 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1123,9 +1139,17 @@ class DiffusionPipeline(ConfigMixin):
CUSTOM_PIPELINE_FILE_NAME, CUSTOM_PIPELINE_FILE_NAME,
] ]
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, variant=variant)
):
raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax: if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"] ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
......
...@@ -440,7 +440,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -440,7 +440,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
# LoRA and no LoRA should NOT be the same # LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4 assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_load_safetensors_load_torch(self): def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing # enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -475,6 +475,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -475,6 +475,43 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
new_model.to(torch_device) new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
def test_lora_save_torch_force_load_safetensors_error(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
def test_lora_on_off(self): def test_lora_on_off(self):
# enable deterministic behavior for gradient checkpointing # enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
...@@ -108,6 +108,17 @@ class DownloadTests(unittest.TestCase): ...@@ -108,6 +108,17 @@ class DownloadTests(unittest.TestCase):
# We need to never convert this tiny model to safetensors for this test to pass # We need to never convert this tiny model to safetensors for this test to pass
assert not any(f.endswith(".safetensors") for f in files) assert not any(f.endswith(".safetensors") for f in files)
def test_force_safetensors_error(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
with self.assertRaises(EnvironmentError):
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors",
safety_checker=None,
cache_dir=tmpdirname,
use_safetensors=True,
)
def test_returned_cached_folder(self): def test_returned_cached_folder(self):
prompt = "hello" prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment