Unverified Commit 48706267 authored by Sayak Paul, committed by GitHub

[Examples] Improve the model card pushed from the `train_text_to_image.py` script (#3810)



* refactor: readme serialized from the example when push_to_hub is True.

* fix: batch size arg.

* a bit better formatting

* minor fixes.

* add note on env.

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* condition wandb info better

* make mixed_precision assignment in cli args explicit.

* separate inference block for sample images.

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* address more comments.

* autocast mode.

* correct `None` image type problem.

* fix: list assignment.

* minor fix.

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 66674330
@@ -55,11 +55,11 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin
 <!-- accelerate_snippet_start -->
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
 accelerate launch --mixed_precision="fp16" train_text_to_image.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
   --use_ema \
   --resolution=512 --center_crop --random_flip \
   --train_batch_size=1 \
@@ -133,11 +133,11 @@ for running distributed training with `accelerate`. Here is an example command:
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
 accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
   --use_ema \
   --resolution=512 --center_crop --random_flip \
   --train_batch_size=1 \
@@ -274,11 +274,11 @@ pip install -U -r requirements_flax.txt
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
 python train_text_to_image_flax.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
   --resolution=512 --center_crop --random_flip \
   --train_batch_size=1 \
   --mixed_precision="fp16" \
...
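The README hunks above only uppercase the dataset environment variable (`dataset_name` → `DATASET_NAME`) for consistency with `MODEL_NAME`. As a quick sanity check that the variable points at a loadable dataset, a minimal sketch using the `datasets` library the example already depends on:

```python
# Minimal sketch: verify the dataset referenced by DATASET_NAME loads.
# Assumes `datasets` is installed per the example's requirements file.
import os

from datasets import load_dataset

dataset_name = os.environ.get("DATASET_NAME", "lambdalabs/pokemon-blip-captions")
dataset = load_dataset(dataset_name, split="train")
print(dataset)  # this dataset exposes an `image` and a `text` column
```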
@@ -35,6 +35,7 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -62,6 +63,92 @@ DATASET_NAME_MAPPING = {
 }
+
+
+def make_image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
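For reference, the new `make_image_grid` helper tiles PIL images left to right, row by row. A hypothetical usage sketch (not part of the commit; assumes the helper above is in scope):

```python
# Tile four solid-color 64x64 images into a single-row strip
# and check the resulting canvas size.
from PIL import Image

imgs = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue", "white")]
grid = make_image_grid(imgs, rows=1, cols=4)
assert grid.size == (256, 64)
grid.save("grid_demo.png")
```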
+
+
+def save_model_card(
+    args,
+    repo_id: str,
+    images=None,
+    repo_folder=None,
+):
+    img_str = ""
+    if len(images) > 0:
+        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+    """
+    model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+"""
+    wandb_info = ""
+    if is_wandb_available():
+        wandb_run_url = None
+        if wandb.run is not None:
+            wandb_run_url = wandb.run.url
+
+        if wandb_run_url is not None:
+            wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+    model_card += wandb_info
+
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
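To see what the serialized card looks like without running a full training job, the helper can be exercised directly with stand-in arguments. A hedged sketch: all values below are hypothetical, and it assumes the example's requirements are installed and `examples/text_to_image` is the working directory so the script is importable:

```python
# Hypothetical smoke test for save_model_card (values are stand-ins, not defaults).
import os
from types import SimpleNamespace

from train_text_to_image import save_model_card  # assumes examples/text_to_image is the CWD

args = SimpleNamespace(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    dataset_name="lambdalabs/pokemon-blip-captions",
    validation_prompts=["yoda"],
    num_train_epochs=100,
    learning_rate=1e-05,
    train_batch_size=1,
    gradient_accumulation_steps=4,
    resolution=512,
    mixed_precision="fp16",
)
os.makedirs("card_demo", exist_ok=True)
# Pass an empty image list so no grid is rendered; "user/sd-pokemon-demo" is a made-up repo_id.
save_model_card(args, repo_id="user/sd-pokemon-demo", images=[], repo_folder="card_demo")
with open("card_demo/README.md") as f:
    print(f.read())
```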
 def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
     logger.info("Running validation... ")
@@ -112,6 +199,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
     del pipeline
     torch.cuda.empty_cache()
+
+    return images
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -747,8 +836,10 @@ def main():
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
+        args.mixed_precision = accelerator.mixed_precision
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
+        args.mixed_precision = accelerator.mixed_precision
 
     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
@@ -970,7 +1061,29 @@ def main():
         )
         pipeline.save_pretrained(args.output_dir)
 
+        # Run a final round of inference.
+        images = []
+        if args.validation_prompts is not None:
+            logger.info("Running inference for collecting generated images...")
+            pipeline = pipeline.to(accelerator.device)
+            pipeline.torch_dtype = weight_dtype
+            pipeline.set_progress_bar_config(disable=True)
+
+            if args.enable_xformers_memory_efficient_attention:
+                pipeline.enable_xformers_memory_efficient_attention()
+
+            if args.seed is None:
+                generator = None
+            else:
+                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+            for i in range(len(args.validation_prompts)):
+                with torch.autocast("cuda"):
+                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+                images.append(image)
+
         if args.push_to_hub:
+            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
...
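The new final-inference block above corresponds to the standalone pattern below (a sketch, not the committed code: "sd-pokemon-model" stands in for `args.output_dir`, and it assumes a CUDA device and a pipeline already saved by the script):

```python
# Seeded, autocast inference over the trained pipeline, mirroring the
# script's final round of image generation before pushing to the Hub.
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("sd-pokemon-model")  # hypothetical output_dir
pipeline = pipeline.to("cuda")
pipeline.set_progress_bar_config(disable=True)

# A fixed seed (as with --seed in the script) makes the samples reproducible.
generator = torch.Generator(device="cuda").manual_seed(42)

with torch.autocast("cuda"):
    image = pipeline("yoda", num_inference_steps=20, generator=generator).images[0]
image.save("validation_sample.png")
```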