Unverified Commit f0dba33d authored by Sayak Paul, committed by GitHub

[training] show how metadata stuff should be incorporated in training scripts. (#11707)



* show how metadata stuff should be incorporated in training scripts.

* typing

* fix

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent d1db4f85
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import sys
@@ -20,6 +21,8 @@ import tempfile
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -234,3 +237,45 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
             run_command(self._launch_args + resume_run_args)
 
             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
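For reference, the metadata check in the new test can be reproduced on any checkpoint produced by the updated script. The snippet below is a minimal sketch (the checkpoint path is a placeholder); it relies only on the same `LORA_ADAPTER_METADATA_KEY` and key layout (`transformer.r`, `transformer.lora_alpha`) that the test asserts on.

```python
import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

# Placeholder path: point this at the --output_dir used for training.
state_dict_file = "output/pytorch_lora_weights.safetensors"

with safe_open(state_dict_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
    adapter_metadata = json.loads(raw)
    # Keys are namespaced per component, e.g. "transformer.r" and "transformer.lora_alpha".
    print(adapter_metadata.get("transformer.r"), adapter_metadata.get("transformer.lora_alpha"))
```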
@@ -27,7 +27,6 @@ from pathlib import Path
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@ from diffusers import (
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
     parser.add_argument(
@@ -1238,7 +1243,7 @@
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1280,6 +1287,7 @@
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1889,16 +1897,19 @@
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None
@@ -1906,6 +1917,7 @@
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
......
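A note on why the new `--lora_alpha` flag matters: in PEFT, the LoRA update is scaled by `lora_alpha / r` under the default scaling rule, so hard-coding `lora_alpha=args.rank` (the previous behaviour) always pinned that scale to 1.0. The sketch below is illustrative only; the values and the hard-coded `target_modules` list are examples, not defaults taken from the script.

```python
from peft import LoraConfig

rank = 4
lora_alpha = 8  # effective scale = lora_alpha / rank = 2.0 under PEFT's default scaling

# Mirrors how the training script now builds its transformer LoRA config,
# except that target_modules here is an illustrative hard-coded list.
transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
print(transformer_lora_config.to_dict()["lora_alpha"])  # 8 -- the value that ends up in the saved metadata
```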
@@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+
+
+def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
+    metadatas = {}
+    for module_name, module in modules_to_save.items():
+        if module is not None:
+            metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+    return metadatas
 def compute_density_for_timestep_sampling(
     weighting_scheme: str,
     batch_size: int,
......
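To see what `**_collate_lora_metadata(modules_to_save)` expands into, here is a small, self-contained sketch. The toy `torch.nn.Sequential` model is purely illustrative, and this assumes a diffusers version that ships the helper; the returned keys of the form `<module>_lora_adapter_metadata` are exactly the keyword arguments the training script forwards to `save_lora_weights`.

```python
import torch
from peft import LoraConfig, get_peft_model

from diffusers.training_utils import _collate_lora_metadata

# Toy stand-in for a PEFT-wrapped transformer; only its peft_config matters here.
base = torch.nn.Sequential(torch.nn.Linear(8, 8))
peft_model = get_peft_model(base, LoraConfig(r=4, lora_alpha=8, target_modules=["0"]))

# Entries whose value is None (e.g. an untrained text encoder) are skipped.
metadata_kwargs = _collate_lora_metadata({"transformer": peft_model, "text_encoder": None})
print(list(metadata_kwargs))                                      # ['transformer_lora_adapter_metadata']
print(metadata_kwargs["transformer_lora_adapter_metadata"]["r"])  # 4
```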