# Dataset

coco2017

```
https://cocodataset.org/#home
http://images.cocodataset.org/zips/train2017.zip # train dataset
http://images.cocodataset.org/zips/val2017.zip # validation dataset
http://images.cocodataset.org/zips/test2017.zip # test dataset
http://images.cocodataset.org/zips/unlabeled2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip
http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip
```

Pokémon
```
https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions

```


## Dataset preprocessing

deal_coco.py 
```
#!/usr/bin/env python
import argparse
import json
import os
from typing import Iterator, Tuple


def read_coco_annotations(path: str, image_root: str) -> Iterator[Tuple[str, str]]:
    with open(path, "r") as f:
        content = json.load(f)["annotations"]
        for record in content:
            image_id = record["image_id"]
            caption = record["caption"]
            image_name = f"{image_id:012d}.jpg"
            image_path = os.path.join(image_root, image_name)
            if not os.path.isfile(image_path):
                print(f"Cannot find `{image_path}`, skip.")
                continue
            caption = caption.strip().replace(",", "").replace("\n", "")
            yield image_name, caption


def main():
    parser = argparse.ArgumentParser(description="Convert COCO JSON caption annotations into a metadata CSV")
    parser.add_argument("--label", required=True, help="Path of the annotation JSON file")
    parser.add_argument("--image", required=True, help="Path of the image root directory")
    parser.add_argument("--out", default="metadata.csv", help="Output path of the metadata CSV")
    args = parser.parse_args()

    with open(args.out, "w") as f:
        f.write("file_name,text\n")
        for image_name, caption in read_coco_annotations(args.label, args.image):
            f.write(f"{image_name},{caption}\n")


if __name__ == "__main__":
    main()
```

Process the data:
```
python deal_coco.py --label captions_train2017.json --image train2017/
```
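The cleaning step in `deal_coco.py` matters because the output is a comma-separated file: commas inside a caption would be read as extra columns. A minimal sketch of that logic (the sample caption is hypothetical):

```python
# Mirror of the caption cleaning in deal_coco.py: strip surrounding whitespace,
# then drop commas and newlines so each caption fits in one CSV field.
def clean_caption(caption: str) -> str:
    return caption.strip().replace(",", "").replace("\n", "")

# COCO image ids are zero-padded to 12 digits to form the file name.
image_name = f"{123:012d}.jpg"
row = f"{image_name},{clean_caption('A dog, running in a park.')}"
print(row)  # 000000000123.jpg,A dog running in a park.
```

Each such row lands in `metadata.csv` under the `file_name,text` header expected by the diffusers training scripts.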





# Model
```
https://github.com/huggingface/diffusers.git
```

# Environment setup

```
# Comment out torchvision in requirements.txt first
python setup.py install
cd examples/text_to_image
pip install -r requirements_sdxl.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```


# Training


## LoRA training

```
export MODEL_NAME="/datasets/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/datasets/custom_datasets/coco2017/images/train2017"
export VAE_NAME="/datasets/custom_model/sdxl-vae-fp16-fix"
accelerate launch --multi_gpu  examples/text_to_image/train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir=$OUTPUT_DIR

```
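The effective batch size of the LoRA run above is the per-device batch multiplied by the gradient accumulation steps and the number of processes `accelerate` launches. A quick sanity check (the 8-GPU count is an assumption; it depends on your machine and accelerate config):

```python
# Effective batch size of the LoRA command above.
train_batch_size = 1            # --train_batch_size
gradient_accumulation_steps = 4 # --gradient_accumulation_steps
num_processes = 8               # assumption: one 8-GPU node, not set in the command

effective_batch = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch)  # 32
```

With `--max_train_steps=10000`, that corresponds to roughly 320k images seen over the whole run.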



## Full fine-tuning


```
export MODEL_NAME="/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/mnt/fs/user/llama/custom_datasets/coco2017/images/train2017"
export VAE_NAME="/mnt/fs/user/llama/custom_model/sdxl-vae-fp16-fix"

accelerate launch examples/text_to_image/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --train_data_dir=$DATASET_NAME --caption_column="text" \
  --resolution=512 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir="/path/to/sd-pokemon-model-lora-sdxl" \
  --validation_prompt="cute dragon creature"
```
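Since the fine-tuning run is epoch-based rather than step-based, the total step count depends on the dataset size. A rough estimate, assuming COCO train2017's ~118,287 images (verify against your local copy) and a single GPU with batch size 1:

```python
import math

# Rough training schedule for the fine-tuning command above.
num_images = 118_287        # assumption: full COCO train2017
train_batch_size = 1        # --train_batch_size
num_train_epochs = 2        # --num_train_epochs
checkpointing_steps = 500   # --checkpointing_steps

steps_per_epoch = math.ceil(num_images / train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_checkpoints = total_steps // checkpointing_steps
print(total_steps, num_checkpoints)  # 236574 473
```

At that checkpoint frequency the run writes several hundred checkpoints, so consider raising `--checkpointing_steps` or limiting `--checkpoints_total_limit` if disk space is tight.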




## Multi-node training
Launch the same command on each machine, setting `--machine_rank` to that node's rank (0 on the main node, 1 on the second).

```
export MODEL_NAME="/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="/path/to/sd_xl"
#export DATASET_NAME="/datasets/custom_datasets/pokemon-blip-captions"
export DATASET_NAME="/mnt/fs/user/llama/custom_datasets/coco2017/images/"
export VAE_NAME="/mnt/fs/user/llama/custom_model/sdxl-vae-fp16-fix"

accelerate launch --multi_gpu --num_processes 16 --num_machines 2 --machine_rank 1 --rdzv_backend static --main_process_ip node21 --main_process_port 11223 examples/text_to_image/train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --train_data_dir=$DATASET_NAME --caption_column="text" \
  --resolution=512 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir=$OUTPUT_DIR \
  --validation_prompt="cute dragon creature"
```
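One detail of the launch above that is easy to get wrong: `--num_processes` is the TOTAL worker count across all machines, not per node. A quick check of the implied per-node layout:

```python
# Process layout implied by the multi-node command above:
# 16 total processes over 2 machines means 8 GPU workers per node.
num_processes = 16  # --num_processes (total, across all machines)
num_machines = 2    # --num_machines

assert num_processes % num_machines == 0, "total processes must divide evenly across machines"
per_machine = num_processes // num_machines
print(per_machine)  # 8
```

`--machine_rank` is 0 on the main node (here `node21`, which must be reachable on port 11223 from the others) and 1 on the second node.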




# Inference


```
from diffusers import DiffusionPipeline
import torch

model_path = "/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0" # <-- change this
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

```
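The example above runs the base model only. To sample with the adapter produced by the LoRA run, diffusers pipelines expose `load_lora_weights`. A minimal sketch, assuming the LoRA `OUTPUT_DIR` from the training section:

```python
from diffusers import DiffusionPipeline
import torch

base_model = "/mnt/fs/user/llama/custom_model/stable-diffusion-xl-base-1.0"
lora_dir = "/path/to/sd_xl"  # OUTPUT_DIR of the LoRA training run (assumption)

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.load_lora_weights(lora_dir)  # picks up the saved LoRA weights
pipe.to("cuda")

image = pipe("a cute Sundar Pichai creature",
             num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("lora_sample.png")
```

The prompt reuses the `--validation_prompt` from the LoRA command so the result is comparable to the validation images logged during training.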




# References
 https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md