Unverified Commit 9d974407 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script (#7258)

* fix: save_model_card utility.

* fix a little more to make it more lenient.

* remove lower()
parent d9a3b698
......@@ -114,7 +114,7 @@ def save_model_card(
)
model_description = f"""
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
<Gallery />
......@@ -139,7 +139,7 @@ Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
if "playgroundai" in args.pretrained_model_name_or_path:
if "playground" in base_model:
model_description += """\n
## License
......@@ -148,7 +148,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
......@@ -162,7 +162,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
"lora" if not use_dora else "dora",
"template:sd-lora",
]
if "playgroundai" in base_model:
if "playground" in base_model:
tags.extend(["playground", "playground-diffusers"])
else:
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
......@@ -206,7 +206,7 @@ def log_validation(
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)
with inference_ctx:
......@@ -1509,7 +1509,7 @@ def main(args):
if accelerator.is_main_process:
tracker_name = (
"dreambooth-lora-sd-xl"
if "playgroundai" not in args.pretrained_model_name_or_path
if "playground" not in args.pretrained_model_name_or_path
else "dreambooth-lora-playground"
)
accelerator.init_trackers(tracker_name, config=vars(args))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment