# Datasets

coco2017

```
https://cocodataset.org/#home
http://images.cocodataset.org/zips/train2017.zip # train dataset
http://images.cocodataset.org/zips/val2017.zip # validation dataset
http://images.cocodataset.org/zips/test2017.zip # test dataset
http://images.cocodataset.org/zips/unlabeled2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip
http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip
```
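Only the training images and the caption annotations are strictly needed for the workflow below; a download sketch, assuming `wget` and `unzip` are available:

```shell
# Fetch the COCO 2017 training images and caption annotations
# (the other zips listed above are optional for this workflow).
mkdir -p coco2017 && cd coco2017
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip -q train2017.zip
unzip -q annotations_trainval2017.zip   # yields annotations/captions_train2017.json
```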

Pokémon
```
https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions

```


## Dataset preprocessing

`deal_coco.py` converts the COCO caption annotations (JSON) into a `metadata.csv` with `file_name,text` columns:
```
#!/usr/bin/env python
import argparse
import json
import os
from typing import Iterator, Tuple


def read_coco_annotations(path: str, image_root: str) -> Iterator[Tuple[str, str]]:
    with open(path, "r") as f:
        content = json.load(f)["annotations"]
        for record in content:
            image_id = record["image_id"]
            caption = record["caption"]
            image_name = f"{image_id:012d}.jpg"
            image_path = os.path.join(image_root, image_name)
            if not os.path.isfile(image_path):
                print(f"Cannot find `{image_path}`, skip.")
                continue
            caption = caption.strip().replace(",", "").replace("\n", "")
            yield image_name, caption


def main():
    parser = argparse.ArgumentParser(description="Convert COCO JSON caption annotations into a plain CSV file")
    parser.add_argument("--label", required=True, help="Path of the caption annotation JSON file")
    parser.add_argument("--image", required=True, help="Path of the image root directory")
    parser.add_argument("--out", default="metadata.csv", help="Output path of the CSV file")
    args = parser.parse_args()

    with open(args.out, "w") as f:
        f.write("file_name,text\n")
        for image_name, caption in read_coco_annotations(args.label, args.image):
            f.write(f"{image_name},{caption}\n")


if __name__ == "__main__":
    main()
```

Run the conversion:
```
python deal_coco.py --label captions_train2017.json --image train2017/ --out metadata.csv
```
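Commas and newlines are removed from each caption because the output file is comma-separated; the cleaning step can be checked in isolation:

```python
# Mirrors the caption cleaning in read_coco_annotations():
# strip surrounding whitespace, then drop commas (the CSV delimiter)
# and embedded newlines so each caption stays on one CSV row.
def clean_caption(caption: str) -> str:
    return caption.strip().replace(",", "").replace("\n", "")

print(clean_caption(" A man, riding a horse.\n"))  # -> A man riding a horse.
```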





# Model
```
git clone https://github.com/huggingface/diffusers.git
```




# Environment setup

```
# Comment out the torchvision line in requirements.txt first
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
python setup.py install
```


# Training


SDXL fine-tuning (note: `train_text_to_image_sdxl.py` performs full fine-tuning; the LoRA variant in the same examples directory is `train_text_to_image_lora_sdxl.py`)
```
export MODEL_NAME="/datasets/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/datasets/custom_datasets/coco2017/images/train2017"
export VAE_NAME="/datasets/custom_model/sdxl-vae-fp16-fix"
accelerate launch --multi_gpu  examples/text_to_image/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir=$OUTPUT_DIR

```

# Inference
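A minimal generation sketch, assuming the fine-tuned weights were saved to the `OUTPUT_DIR` used above (`/path/to/sd_xl` is that placeholder) and a CUDA GPU is available:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Load the fine-tuned SDXL checkpoint saved by train_text_to_image_sdxl.py
# ("/path/to/sd_xl" is the OUTPUT_DIR placeholder from the training step).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "/path/to/sd_xl", torch_dtype=torch.float16
).to("cuda")

image = pipe(prompt="a cute Sundar Pichai creature").images[0]
image.save("sample.png")
```

(Not runnable without the trained weights and a GPU; adjust the path and prompt to your setup.)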



# References

https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md