"docs/vscode:/vscode.git/clone" did not exist on "6e0f84d44d19651446e1ecb197f6e583426a47fc"
Unverified Commit 1f4deb69 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for `safetensors` and LoRa. (#2448)

* Adding support for `safetensors` and LoRa.

* Adding metadata.
parent f20c8f5a
......@@ -19,13 +19,18 @@ import torch
from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
if is_safetensors_available():
import safetensors
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
class AttnProcsLayers(torch.nn.Module):
......@@ -136,28 +141,53 @@ class UNet2DConditionLoadersMixin:
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
weight_name = kwargs.pop("weight_name", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
if is_safetensors_available():
if weight_name is None:
weight_name = LORA_WEIGHT_NAME_SAFE
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
if weight_name == LORA_WEIGHT_NAME_SAFE:
weight_name = None
if model_file is None:
if weight_name is None:
weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
......@@ -195,8 +225,9 @@ class UNet2DConditionLoadersMixin:
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = LORA_WEIGHT_NAME,
weights_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
......@@ -219,7 +250,13 @@ class UNet2DConditionLoadersMixin:
return
if save_function is None:
save_function = torch.save
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
......@@ -237,6 +274,12 @@ class UNet2DConditionLoadersMixin:
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
if weights_name is None:
if safe_serialization:
weights_name = LORA_WEIGHT_NAME_SAFE
else:
weights_name = LORA_WEIGHT_NAME
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import gc
import os
import tempfile
import unittest
......@@ -372,6 +373,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 1e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_load_safetensors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment