Unverified commit dc5b4e23 authored by Haofan Wang, committed by GitHub

Update train_text_to_image_lora.py (#2767)

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* format
parent 0d7aac3e
train_text_to_image_lora.py

@@ -582,7 +582,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
-    if args.peft:
+    if args.use_peft:
         # Optimizer creation
         params_to_optimize = (
             itertools.chain(unet.parameters(), text_encoder.parameters())
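For context on the renamed flag: the hunks in this commit consistently read `args.use_peft` rather than `args.peft`. Below is a minimal sketch of how such a boolean switch is typically declared with `argparse`; the help text and the `--train_text_encoder` flag shown here are illustrative assumptions, since the script's actual argument definitions are not part of this diff.

```python
import argparse


def parse_args():
    # Sketch only: mirrors the flags referenced in this diff, not the script's full argument list.
    parser = argparse.ArgumentParser(description="LoRA text-to-image fine-tuning (illustrative)")
    parser.add_argument(
        "--use_peft",
        action="store_true",
        help="Route LoRA training through the PEFT library instead of diffusers attention processors.",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Also apply LoRA to (and train) the text encoder.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print(f"use_peft={args.use_peft}, train_text_encoder={args.train_text_encoder}")
```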
@@ -724,7 +724,7 @@ def main():
     )
 
     # Prepare everything with our `accelerator`.
-    if args.peft:
+    if args.use_peft:
         if args.train_text_encoder:
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -842,7 +842,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    if args.peft:
+                    if args.use_peft:
                         params_to_clip = (
                             itertools.chain(unet.parameters(), text_encoder.parameters())
                             if args.train_text_encoder
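The `params_to_clip` expression above chains the UNet and text-encoder parameters into one iterable before the gradient-clipping step that follows in the script. A self-contained toy sketch of that pattern; the two `nn.Linear` modules are stand-ins for the real models, and plain `torch.nn.utils.clip_grad_norm_` is used here instead of Accelerate's clipping utility to keep the example runnable on its own.

```python
import itertools

import torch

# Stand-ins for the real unet / text_encoder modules.
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)

loss = unet(torch.randn(2, 4)).sum() + text_encoder(torch.randn(2, 4)).sum()
loss.backward()

# Chain both parameter sets, mirroring params_to_clip above, then clip their combined gradient norm.
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=1.0)
print(f"gradient norm before clipping: {total_norm.item():.4f}")
```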
@@ -922,18 +922,22 @@ def main():
     if accelerator.is_main_process:
         if args.use_peft:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-            lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+            unwarpped_unet = accelerator.unwrap_model(unet)
+            state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+            lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
             if args.train_text_encoder:
+                unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
                 text_encoder_state_dict = get_peft_model_state_dict(
-                    text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                    unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
                 )
                 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
                 state_dict.update(text_encoder_state_dict)
-                lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+                lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
+                    inference=True
+                )
 
-            accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
-            with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
+            accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
                 json.dump(lora_config, f)
         else:
             unet = unet.to(torch.float32)
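The saving block above writes the UNet and text-encoder LoRA weights into one `{global_step}_lora.pt`, prefixing the text-encoder keys with `text_encoder_`, plus a `{global_step}_lora_config.json` holding the PEFT configs. A hedged sketch of how such a pair of files could be inspected offline; the directory name and step value are placeholders, not values taken from this diff.

```python
import json
import os

import torch

output_dir = "sd-model-finetuned-lora"  # placeholder output directory
global_step = 15000                     # placeholder step count used in the file names

with open(os.path.join(output_dir, f"{global_step}_lora_config.json"), "r") as f:
    lora_config = json.load(f)
print(lora_config.keys())  # "peft_config" and, if the text encoder was trained, "text_encoder_peft_config"

lora_checkpoint_sd = torch.load(os.path.join(output_dir, f"{global_step}_lora.pt"), map_location="cpu")

# Split the combined state dict back apart using the "text_encoder_" prefix added at save time.
unet_lora_sd = {k: v for k, v in lora_checkpoint_sd.items() if not k.startswith("text_encoder_")}
text_encoder_lora_sd = {
    k.removeprefix("text_encoder_"): v for k, v in lora_checkpoint_sd.items() if k.startswith("text_encoder_")
}
print(f"{len(unet_lora_sd)} UNet tensors, {len(text_encoder_lora_sd)} text-encoder tensors")
```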
@@ -957,12 +961,12 @@ def main():
 
     if args.use_peft:
 
-        def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
-            with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
+        def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
                 lora_config = json.load(f)
             print(lora_config)
 
-            checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
+            checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
             lora_checkpoint_sd = torch.load(checkpoint)
             unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
             text_encoder_lora_ds = {
@@ -985,9 +989,7 @@ def main():
             pipe.to(device)
             return pipe
 
-        pipeline = load_and_set_lora_ckpt(
-            pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
-        )
+        pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
 
     else:
         pipeline = pipeline.to(accelerator.device)
@@ -995,7 +997,10 @@ def main():
         pipeline.unet.load_attn_procs(args.output_dir)
 
     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    else:
+        generator = None
     images = []
     for _ in range(args.num_validation_images):
         images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
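The final hunk only seeds a `torch.Generator` when `--seed` was passed, so unseeded runs keep non-deterministic sampling. A standalone sketch of the same pattern with a Stable Diffusion pipeline; the model id, prompt, and seed value below are placeholders rather than values from this script.

```python
import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

seed = 42  # placeholder; use None to leave sampling unseeded, as the script now allows
generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None

# With a fixed generator, the sampled validation images are reproducible across runs.
images = [
    pipeline("a photo of an astronaut riding a horse", num_inference_steps=30, generator=generator).images[0]
    for _ in range(2)
]
images[0].save("validation_0.png")
```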