README.md 5.25 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# CatVTON_OpenPose

在原有的`catvton`中添加了`openpose`特征用于指导和衣物重叠的肢体生成。

## 论文

`CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models`

* https://arxiv.org/pdf/2407.15886


## 模型结构

该模型基于`stable diffusion`结构,在原有的`catvton`中添加了`openpose`特征用于指导和衣物重叠的肢体生成。

![alt text](readme_imgs/arch.png)


## 算法原理

该算法基于`stable diffusion`,去除了多余的网络结构,直接将控制条件作为`Unet`的输入。

![alt text](readme_imgs/alg.png)

## 环境配置

### Docker(方法一)
    
    docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu22.04-dtk24.04.2-py3.10

    docker run --shm-size 50g --network=host --name=catvton --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

### Dockerfile(方法二)

    docker build -t <IMAGE_NAME>:<TAG> .

    docker run --shm-size 50g --network=host --name=catvton --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash

    pip install -r requirements.txt

### Anaconda (方法三)

1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装:
https://developer.hpccube.com/tool/

    DTK驱动:dtk24.04.2
    python:python3.10
    torch: 2.1.0
    torchvision: 0.16.0

Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应

2、其它非特殊库参照requirements.txt安装

    pip install -r requirements.txt

## 数据集

本项目已提供用于测试的tiny_datasets,见`datasets/tiny_datasets`,该数据既可用作训练集,也可用于测试使用。完整数据集见[SCNet高速下载通道](http://113.200.138.88:18080/aidatasets/project-dependency/viton-hd)

```
datasets
├── test
│   ├── test_data.jsonl
│   └── vitonhd
│       ├── agnostic-mask
│       ├── cloth
│       ├── image
│       └── openpose_img
└── train
    ├── eval_data.jsonl
    ├── train_data.jsonl
    └── vitonhd
        ├── agnostic-mask
        ├── cloth
        ├── image
        └── openpose_img
```
注意:`xxx.jsonl`不包含在数据集中,由给定脚本生成(见训练,推理部分)。


## 训练

### 数据路径文件准备

```bash
cd tools

python prepare_data_record.py \
--person_image_root="/path/to/person_image_dir" \
--cloth_image_root="/path/to/cloth_dir" \
--mask_root="/path/to/mask_dir" \
--extra_condition_image_root="/path/to/extra_condition_image_root" \
--extra_condition_key="e.g. openpose" \
--eval_nums=[用于验证的数据量] \
--save_root="/path/to/save_xxx.jsonl"
```

dcuai's avatar
dcuai committed
101
注意:`mask_root``agnostic-mask``train``image`文件夹相应的文件名需要进行修改,可使用
mashun1's avatar
mashun1 committed
102
103
104
105
106

```bash
find /path/to/files -type f -name "*_rendered.png" -exec bash -c 'mv "$0" "${0/_rendered.png/.png}"' {} \;
```

mashun1's avatar
mashun1 committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
### 启动训练

```bash
# 执行以下命令进行训练配置
accelerate config
```

```bash
cd run && bash scripts/train.sh
```

注意:在运行前请检查并确保`scripts/train.sh`中配置正确。需要说明的是,该训练代码并非官方代码(未开源),而是根据论文进行复现并进行了优化(添加数据增广等)。

## 推理

```bash
export HF_ENDPOINT=https://hf-mirror.com
```

```bash
cd run && bash scripts/generate_test_sample.sh
```

mashun1's avatar
mashun1 committed
130
注意:在运行前请检查并确保`scripts/generate_test_sample.sh`中配置正确。其中的`test_data.jsonl`文件生成可参考`训练-数据路径准备`(令eval_nums=0,然后将使用测试数据生成的`train_data.jsonl`重命名为`test_data.jsonl`即可)。
mashun1's avatar
mashun1 committed
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169

### 指标计算

```bash
cd run && bash scripts/cal_metrics.sh
```
注意:在运行前请检查并确保`scripts/cal_metrics.sh`中配置正确。

## result

|ground truth|openpose|w/o openpose|
|:---:|:---:|:---:|
|![alt txt](readme_imgs/gt/05087_00.jpg)|![alt txt](readme_imgs/openpose/05087_00.png)|![alt txt](readme_imgs/empty/05087_00.png)|
|![alt txt](readme_imgs/gt/06563_00.jpg)|![alt txt](readme_imgs/openpose/06563_00.png)|![alt txt](readme_imgs/empty/06563_00.png)|
|![alt txt](readme_imgs/gt/11791_00.jpg)|![alt txt](readme_imgs/openpose/11791_00.png)|![alt txt](readme_imgs/empty/11791_00.png)|

### 精度

本项目在dcu加速卡上完成训练。

`ssim`外,其余指标越低越好。

|| fid   | kid    | ssim  | lpips  |
|---| ----- | ------ | ----- | ------ |
|w/o openpose| 5.42  | 0.41   | 0.87  | 0.0565 |
|w openpose| 5.35  | 0.1327 | 0.865 | 0.06   |

## 应用场景

### 算法类别

`AIGC`

### 热点应用行业

`电商,绘画,广媒`

## 预训练权重

chenzk's avatar
chenzk committed
170
stable-diffusion-inpainting: [huggingface](https://hf-mirror.com/booksforcharlie/stable-diffusion-inpainting/tree/main) 
mashun1's avatar
mashun1 committed
171
172
173
174
175
176
177
178
179
180


## 源码仓库及问题反馈

* https://developer.sourcefind.cn/codes/modelzoo/catvton_openpose_pytorch

## 参考资料

* https://github.com/Zheng-Chong/CatVTON