Unverified Commit bfa0aa4f authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[SD3-5 dreambooth lora] update model cards (#9749)



* improve readme

* style

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent ab1b7b20
...@@ -86,6 +86,15 @@ def save_model_card( ...@@ -86,6 +86,15 @@ def save_model_card(
validation_prompt=None, validation_prompt=None,
repo_folder=None, repo_folder=None,
): ):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]
widget_dict = [] widget_dict = []
if images is not None: if images is not None:
for i, image in enumerate(images): for i, image in enumerate(images):
...@@ -95,7 +104,7 @@ def save_model_card( ...@@ -95,7 +104,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# SD3 DreamBooth LoRA - {repo_id} # {model_variant} DreamBooth LoRA - {repo_id}
<Gallery /> <Gallery />
...@@ -120,7 +129,7 @@ You should use `{instance_prompt}` to trigger the image generation. ...@@ -120,7 +129,7 @@ You should use `{instance_prompt}` to trigger the image generation.
```py ```py
from diffusers import AutoPipelineForText2Image from diffusers import AutoPipelineForText2Image
import torch import torch
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda') pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
``` ```
...@@ -135,7 +144,7 @@ For more details, including weighting, merging and fusing LoRAs, check the [docu ...@@ -135,7 +144,7 @@ For more details, including weighting, merging and fusing LoRAs, check the [docu
## License ## License
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE). Please adhere to the licensing terms as described [here]({license_url}).
""" """
model_card = load_or_create_model_card( model_card = load_or_create_model_card(
repo_id_or_path=repo_id, repo_id_or_path=repo_id,
...@@ -151,11 +160,11 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/ ...@@ -151,11 +160,11 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
"diffusers-training", "diffusers-training",
"diffusers", "diffusers",
"lora", "lora",
"sd3",
"sd3-diffusers",
"template:sd-lora", "template:sd-lora",
] ]
tags += variant_tags
model_card = populate_model_card(model_card, tags=tags) model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md")) model_card.save(os.path.join(repo_folder, "README.md"))
......
...@@ -77,6 +77,15 @@ def save_model_card( ...@@ -77,6 +77,15 @@ def save_model_card(
validation_prompt=None, validation_prompt=None,
repo_folder=None, repo_folder=None,
): ):
if "large" in base_model:
model_variant = "SD3.5-Large"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
else:
model_variant = "SD3"
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
variant_tags = ["sd3", "sd3-diffusers"]
widget_dict = [] widget_dict = []
if images is not None: if images is not None:
for i, image in enumerate(images): for i, image in enumerate(images):
...@@ -86,7 +95,7 @@ def save_model_card( ...@@ -86,7 +95,7 @@ def save_model_card(
) )
model_description = f""" model_description = f"""
# SD3 DreamBooth - {repo_id} # {model_variant} DreamBooth - {repo_id}
<Gallery /> <Gallery />
...@@ -113,7 +122,7 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}' ...@@ -113,7 +122,7 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
## License ## License
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`. Please adhere to the licensing terms as described `[here]({license_url})`.
""" """
model_card = load_or_create_model_card( model_card = load_or_create_model_card(
repo_id_or_path=repo_id, repo_id_or_path=repo_id,
...@@ -128,10 +137,9 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co ...@@ -128,10 +137,9 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co
"text-to-image", "text-to-image",
"diffusers-training", "diffusers-training",
"diffusers", "diffusers",
"sd3",
"sd3-diffusers",
"template:sd-lora", "template:sd-lora",
] ]
tags += variant_tags
model_card = populate_model_card(model_card, tags=tags) model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md")) model_card.save(os.path.join(repo_folder, "README.md"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment