"tools/vscode:/vscode.git/clone" did not exist on "655b7f7fd1d9e183955baedf4e6036c52db9a0ff"
Unverified Commit 67603002 authored by imbr92, committed by GitHub

Add --lora_alpha and metadata handling to train_dreambooth_lora_sana.py (#11744)


Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 798265f2
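For context: this change stops hard-wiring `lora_alpha` to the LoRA rank (an effective scale of 1.0) and exposes it as a CLI flag, and it serializes the adapter configuration as metadata next to the LoRA weights. A minimal sketch of why the flag matters, assuming the `peft` package; the rank/alpha values and target modules below are illustrative, not taken from the script:

```python
from peft import LoraConfig

# PEFT scales every LoRA update by lora_alpha / r, so decoupling alpha from the
# rank lets the update strength be tuned independently of adapter size.
rank, lora_alpha = 4, 8
config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,  # previously fixed to `rank`, i.e. a scale of 1.0
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v"],  # illustrative target modules
)
print(lora_alpha / rank)  # effective scaling applied to the LoRA update: 2.0
```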
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import sys
@@ -20,6 +21,8 @@ import tempfile
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -204,3 +207,42 @@ class DreamBoothLoRASANA(ExamplesTestsAccelerate):
             run_command(self._launch_args + resume_run_args)
 
             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_sana_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
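Outside the test harness, the serialized adapter metadata can be inspected directly from a saved checkpoint. A small sketch assuming a `pytorch_lora_weights.safetensors` file produced by this training script; the path is a placeholder:

```python
import json

import safetensors.torch
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

path = "output/pytorch_lora_weights.safetensors"  # placeholder path
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

# The adapter config is stored as a JSON string under a single metadata key.
adapter_config = json.loads(metadata[LORA_ADAPTER_METADATA_KEY])
print(adapter_config["transformer.r"], adapter_config["transformer.lora_alpha"])
```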
@@ -52,6 +52,7 @@ from diffusers import (
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1039,10 +1044,11 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
+            modules_to_save = {}
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1052,6 +1058,7 @@ def main(args):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
            )
 
     def load_model_hook(models, input_dir):
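The hook forwards each tracked module's PEFT config to `save_lora_weights` through `_collate_lora_metadata`. Roughly, a collation step of this kind builds one metadata entry per module name; the sketch below is an illustration of the idea, not the actual diffusers implementation:

```python
def collate_lora_metadata_sketch(modules_to_save):
    # Hypothetical helper: for every module that carries a PEFT LoRA adapter,
    # expose its adapter config under "<name>_lora_adapter_metadata".
    metadata = {}
    for name, module in modules_to_save.items():
        if module is not None:
            metadata[f"{name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadata
```

Expanded with `**`, this would turn the `modules_to_save["transformer"] = model` entry populated above into a transformer-specific metadata keyword argument for `save_lora_weights`.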
@@ -1507,15 +1514,18 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
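As a usage sketch: after training with `--lora_alpha` set, the saved weights can be loaded back into a Sana pipeline, and the intent of the stored metadata is that the adapter is rebuilt with the trained rank and alpha rather than library defaults. The model id and output directory below are placeholders, and this assumes `SanaPipeline` exposes the standard `load_lora_weights` LoRA loader:

```python
import torch
from diffusers import SanaPipeline

# Placeholder model id and LoRA directory; adjust to your own checkpoint.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("output_dir")  # picks up pytorch_lora_weights.safetensors
image = pipe(prompt="a photo of sks dog", num_inference_steps=20).images[0]
image.save("sks_dog.png")
```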