# Dataset: coco2017

```
https://cocodataset.org/#home
http://images.cocodataset.org/zips/train2017.zip       # train set
http://images.cocodataset.org/zips/val2017.zip         # validation set
http://images.cocodataset.org/zips/test2017.zip        # test set
http://images.cocodataset.org/zips/unlabeled2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip
http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip
```

Pokémon

```
https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions
```

## Dataset processing

deal_coco.py

```
#!/usr/bin/env python
import argparse
import json
import os
from typing import Iterator, Tuple


def read_coco_annotations(path: str, image_root: str) -> Iterator[Tuple[str, str]]:
    """Yield (image_name, caption) pairs from a COCO captions JSON file."""
    with open(path, "r") as f:
        content = json.load(f)["annotations"]
    for record in content:
        image_id = record["image_id"]
        caption = record["caption"]
        # COCO image files are named by the zero-padded 12-digit image id.
        image_name = f"{image_id:012d}.jpg"
        image_path = os.path.join(image_root, image_name)
        if not os.path.isfile(image_path):
            print(f"Cannot find `{image_path}`, skip.")
            continue
        # Strip commas and newlines so the caption is safe to store in a CSV.
        caption = caption.strip().replace(",", "").replace("\n", "")
        yield image_name, caption


def main():
    parser = argparse.ArgumentParser(description="Convert COCO JSON caption annotations into a metadata.csv file")
    parser.add_argument("--label", required=True, help="Path of the label file")
    parser.add_argument("--image", required=True, help="Path of the image root")
    parser.add_argument("--out", default="metadata.csv", help="Output path of the csv file")
    args = parser.parse_args()

    with open(args.out, "w") as f:
        f.write("file_name,text\n")
        for image_name, caption in read_coco_annotations(args.label, args.image):
            f.write(f"{image_name},{caption}\n")


if __name__ == "__main__":
    main()
```

Process the data. The generated metadata.csv must live inside the image directory so the training scripts can pick it up as caption metadata:

```
python deal_coco.py --label captions_train2017.json --image train2017/ --out train2017/metadata.csv
```
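To confirm the processed folder is in the shape the training scripts expect, it can be loaded through the Hugging Face `datasets` library as an `imagefolder` dataset with captions in the `text` column. A minimal sanity-check sketch, assuming the `train2017/` layout produced above:

```
# Minimal sanity check: the training scripts read the image folder through
# the `datasets` library, taking captions from the "text" column of
# metadata.csv. The train2017/ path matches the processing step above.
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="train2017/", split="train")
print(dataset)              # expect columns: image, text
print(dataset[0]["text"])   # first caption
```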
# Model

```
https://github.com/huggingface/diffusers.git
```

# Environment setup

```
git clone https://github.com/huggingface/diffusers.git
cd diffusers
# Comment out torchvision in requirements.txt first
python setup.py install
cd examples/text_to_image
pip install -r requirements_sdxl.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```

# Training

## LoRA training

```
export MODEL_NAME="/datasets/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/datasets/custom_datasets/coco2017/images/train2017"
export VAE_NAME="/datasets/custom_model/sdxl-vae-fp16-fix"

accelerate launch --multi_gpu examples/text_to_image/train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir=$OUTPUT_DIR
```

## Full fine-tuning

```
export MODEL_NAME="/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/mnt/fs/user/llama/custom_datasets/coco2017/images/train2017"
export VAE_NAME="/mnt/fs/user/llama/custom_model/sdxl-vae-fp16-fix"

accelerate launch examples/text_to_image/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --train_data_dir=$DATASET_NAME --caption_column="text" \
  --resolution=512 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir=$OUTPUT_DIR \
  --validation_prompt="cute dragon creature"
```

## Multi-node training: run the same command on each machine, changing --machine_rank per node

```
export MODEL_NAME="/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/mnt/fs/user/llama/custom_datasets/coco2017/images/"
export VAE_NAME="/mnt/fs/user/llama/custom_model/sdxl-vae-fp16-fix"

# --machine_rank is 0 on the main node (node21 here) and 1 on the second node.
accelerate launch --multi_gpu --num_processes 16 --num_machines 2 --machine_rank 1 \
  --rdzv_backend static --main_process_ip node21 --main_process_port 11223 \
  examples/text_to_image/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --train_data_dir=$DATASET_NAME --caption_column="text" \
  --resolution=512 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir=$OUTPUT_DIR \
  --validation_prompt="cute dragon creature"
```

# Inference

```
from diffusers import DiffusionPipeline
import torch

model_path = "/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"  # <-- change this
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")
```

# References

https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md
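One closing note: the inference snippet above samples from the vanilla base model. To try the adapter produced by the LoRA run, the weights saved to `$OUTPUT_DIR` can be loaded on top of the base pipeline. A minimal sketch, assuming the paths from the LoRA training section:

```
# Minimal sketch: base SDXL pipeline plus the LoRA weights trained above.
# Paths are assumptions matching the exports in the LoRA training section.
from diffusers import DiffusionPipeline
import torch

base_model = "/datasets/custom_model/stable-diffusion-xl-base-1.0"
lora_dir = "/path/to/sd_xl"  # $OUTPUT_DIR of the LoRA run

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_dir)  # reads pytorch_lora_weights.safetensors
pipe.to("cuda")

image = pipe("a cute Sundar Pichai creature", num_inference_steps=30).images[0]
image.save("lora_sample.png")
```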