Unverified Commit 29dfe22a authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[advanced dreambooth lora sdxl training script] load pipeline for inference...


[advanced dreambooth lora sdxl training script] load pipeline for inference only if validation prompt is used (#6171)

* load pipeline for inference only if validation prompt is used

* move things outside

* load pipeline for inference only if validation prompt is used

* fix readme when validation prompt is used

---------
Co-authored-by: default avatarlinoytsaban <linoy@huggingface.co>
Co-authored-by: default avatarapolinário <joaopaulo.passos@gmail.com>
parent 56806cdb
...@@ -112,7 +112,7 @@ def save_model_card( ...@@ -112,7 +112,7 @@ def save_model_card(
repo_folder=None, repo_folder=None,
vae_path=None, vae_path=None,
): ):
img_str = "widget:\n" if images else "" img_str = "widget:\n"
for i, image in enumerate(images): for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png")) image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f""" img_str += f"""
...@@ -121,6 +121,10 @@ def save_model_card( ...@@ -121,6 +121,10 @@ def save_model_card(
url: url:
"image_{i}.png" "image_{i}.png"
""" """
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
trigger_str = f"You should use {instance_prompt} to trigger the image generation." trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = "" diffusers_imports_pivotal = ""
...@@ -157,8 +161,6 @@ tags: ...@@ -157,8 +161,6 @@ tags:
base_model: {base_model} base_model: {base_model}
instance_prompt: {instance_prompt} instance_prompt: {instance_prompt}
license: openrail++ license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
--- ---
""" """
...@@ -2010,43 +2012,42 @@ def main(args): ...@@ -2010,43 +2012,42 @@ def main(args):
text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers,
) )
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# Final inference # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
# Load previous pipeline scheduler_args = {}
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config: if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]: if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small" variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
# load attention processors # load attention processors
pipeline.load_lora_weights(args.output_dir) pipeline.load_lora_weights(args.output_dir)
# run inference # run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [ images = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment