"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8f2c7b4df02686d01d2601f128f09e84750ca970"
Unverified Commit 1f4deb69 authored by Nicolas Patry, committed by GitHub

Adding support for `safetensors` and LoRA. (#2448)

* Adding support for `safetensors` and LoRA.

* Adding metadata.
parent f20c8f5a
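
For context, a minimal sketch of how the new options are used end to end. The checkpoint id is a placeholder, and the LoRA processor wiring mirrors the test added below; this is an illustration, not code from the diff:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Placeholder checkpoint id, for illustration only.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Attach one LoRA processor per attention processor, as in the new test below.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_attn_procs)

# New in this PR: write pytorch_lora_weights.safetensors instead of .bin ...
unet.save_attn_procs("./lora", safe_serialization=True)
# ... and load it back; the .safetensors name is tried first, then the .bin one.
unet.load_attn_procs("./lora")
```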
@@ -19,13 +19,18 @@ import torch
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+
+
+if is_safetensors_available():
+    import safetensors
 
 
 logger = logging.get_logger(__name__)
 
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
 
 class AttnProcsLayers(torch.nn.Module):
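
The guarded import keeps `safetensors` an optional dependency. `is_safetensors_available` itself is not shown in this diff; a sketch of the kind of check such a helper typically performs (an assumption about its shape, not a copy of the real one in `diffusers.utils`):

```python
import importlib.util

def is_safetensors_available() -> bool:
    # Hypothetical re-implementation: true when the `safetensors`
    # package can be imported in the current environment.
    return importlib.util.find_spec("safetensors") is not None
```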
@@ -136,28 +141,53 @@ class UNet2DConditionLoadersMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+        weight_name = kwargs.pop("weight_name", None)
 
         user_agent = {
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
 
+        model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            if is_safetensors_available():
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME_SAFE
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except EnvironmentError:
+                    if weight_name == LORA_WEIGHT_NAME_SAFE:
+                        weight_name = None
+            if model_file is None:
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
         else:
             state_dict = pretrained_model_name_or_path_or_dict
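
The resolution order this hunk introduces: when `safetensors` is installed and no explicit `weight_name` is given, the loader first asks for `pytorch_lora_weights.safetensors`; if that lookup raises `EnvironmentError` (no such file published), it falls back to `pytorch_lora_weights.bin` via `torch.load`. Distilled into a standalone sketch, with a hypothetical `fetch()` standing in for `_get_model_file`:

```python
import torch
import safetensors.torch

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

def load_lora_state_dict(fetch, safetensors_available=True):
    """fetch(name) -> local path; raises EnvironmentError when the file is absent."""
    if safetensors_available:
        try:
            path = fetch(LORA_WEIGHT_NAME_SAFE)
            return safetensors.torch.load_file(path, device="cpu")
        except EnvironmentError:
            pass  # no .safetensors file published; fall back to .bin
    return torch.load(fetch(LORA_WEIGHT_NAME), map_location="cpu")
```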
@@ -195,8 +225,9 @@ class UNet2DConditionLoadersMixin:
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
+        weights_name: str = None,
         save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@ class UNet2DConditionLoadersMixin:
             return
 
         if save_function is None:
-            save_function = torch.save
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
 
         os.makedirs(save_directory, exist_ok=True)
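
`safetensors.torch.save_file` takes a flat `str -> Tensor` dict plus optional string-to-string metadata, which is what the closure above relies on; the `{"format": "pt"}` entry is what the commit message's "Adding metadata" refers to. A round-trip sketch (tensor names and file name are arbitrary):

```python
import torch
import safetensors
import safetensors.torch

weights = {"to_q_lora.up.weight": torch.zeros(4, 4)}
safetensors.torch.save_file(weights, "lora.safetensors", metadata={"format": "pt"})

# The metadata survives the round trip and marks the file as PyTorch-flavored.
with safetensors.safe_open("lora.safetensors", framework="pt") as f:
    assert f.metadata() == {"format": "pt"}
    assert f.get_tensor("to_q_lora.up.weight").shape == (4, 4)
```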
@@ -237,6 +274,12 @@ class UNet2DConditionLoadersMixin:
             if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
+        if weights_name is None:
+            if safe_serialization:
+                weights_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weights_name = LORA_WEIGHT_NAME
+
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weights_name))
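
Net effect of the new defaults, continuing the usage sketch above (`unet` as before):

```python
unet.save_attn_procs("out")                             # out/pytorch_lora_weights.bin
unet.save_attn_procs("out", safe_serialization=True)    # out/pytorch_lora_weights.safetensors
unet.save_attn_procs("out", weights_name="custom.bin")  # an explicit name wins either way
```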
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import os
 import tempfile
 import unittest
@@ -372,6 +373,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+            with torch.no_grad():
+                new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
+
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
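
One subtlety in the new test: the `hidden_size` for each LoRA processor is derived from the processor's name, and up blocks index into the reversed channel list. A worked example, assuming `block_out_channels=(32, 64)` (an assumption; the actual dummy config comes from `prepare_init_args_and_inputs_for_common`, and the single-character indexing only holds for single-digit block ids, which these test configs satisfy):

```python
block_out_channels = (32, 64)

# A processor under up_blocks.0 sits at the deepest decoder level ...
name = "up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
block_id = int(name[len("up_blocks.")])           # -> 0
assert list(reversed(block_out_channels))[block_id] == 64

# ... while down_blocks.1 is the deepest encoder level, also 64 channels.
name = "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"
block_id = int(name[len("down_blocks.")])         # -> 1
assert block_out_channels[block_id] == 64
```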