"vscode:/vscode.git/clone" did not exist on "e05f03ae41540123a99afdacff86d26170c1315d"
Unverified Commit 48706267 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples] Improve the model card pushed from the `train_text_to_image.py` script (#3810)



* refactor: readme serialized from the example when push_to_hub is True.

* fix: batch size arg.

* a bit better formatting

* minor fixes.

* add note on env.

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* condition wandb info better

* make mixed_precision assignment in cli args explicit.

* separate inference block for sample images.

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* address more comments.

* autocast mode.

* correct none image type problem.

* ifx: list assignment.

* minor fix.

---------
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 66674330
......@@ -55,11 +55,11 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin
<!-- accelerate_snippet_start -->
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--dataset_name=$DATASET_NAME \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
......@@ -133,11 +133,11 @@ for running distributed training with `accelerate`. Here is an example command:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--dataset_name=$DATASET_NAME \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
......@@ -274,11 +274,11 @@ pip install -U -r requirements_flax.txt
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--dataset_name=$DATASET_NAME \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
......
......@@ -35,6 +35,7 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
......@@ -62,6 +63,92 @@ DATASET_NAME_MAPPING = {
}
def make_image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def save_model_card(
args,
repo_id: str,
images=None,
repo_folder=None,
):
img_str = ""
if len(images) > 0:
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card = f"""
# Text-to-image finetuning - {repo_id}
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
{img_str}
## Pipeline usage
You can use the pipeline like so:
```python
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
prompt = "{args.validation_prompts[0]}"
image = pipeline(prompt).images[0]
image.save("my_image.png")
```
## Training info
These are the key hyperparameters used during training:
* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}
"""
wandb_info = ""
if is_wandb_available():
wandb_run_url = None
if wandb.run is not None:
wandb_run_url = wandb.run.url
if wandb_run_url is not None:
wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""
model_card += wandb_info
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
......@@ -112,6 +199,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
del pipeline
torch.cuda.empty_cache()
return images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
......@@ -747,8 +836,10 @@ def main():
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision
# Move text_encode and vae to gpu and cast to weight_dtype
text_encoder.to(accelerator.device, dtype=weight_dtype)
......@@ -970,7 +1061,29 @@ def main():
)
pipeline.save_pretrained(args.output_dir)
# Run a final round of inference.
images = []
if args.validation_prompts is not None:
logger.info("Running inference for collecting generated images...")
pipeline = pipeline.to(accelerator.device)
pipeline.torch_dtype = weight_dtype
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
images.append(image)
if args.push_to_hub:
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment