README.md 6.19 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# HAT
## 论文
[HAT: Hybrid Attention Transformer for Image Restoration](https://arxiv.org/abs/2309.05239)

## 模型结构
HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。

<div align=center>
    <img src="./doc/model.png"/>
</div>

## 算法原理
HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互, 更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。

<div align=center>
    <img src="./doc/method.png"/>
</div>

## 环境配置
-v 路径、docker_name和imageID根据实际情况修改

### Docker(方法一)

```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38

docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash

cd /your_code_path/hat_pytorch
pip install -r requirements.txt
python setup.py develop
```

### Dockerfile(方法二)

```bash
cd ./docker
cp ../requirements.txt requirements.txt

docker build --no-cache -t hat:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash

cd /your_code_path/hat_pytorch
python setup.py develop
```

### Anaconda(方法三)

1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/

```bash
DTK软件栈:dtk23.10
python:python3.8
torch:1.13.1
torchvision:0.14.1
```

Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应

2、其他非特殊库直接按照requirements.txt安装

```
pip install -r requirements.txt
python setup.py develop
```

## 数据集
训练:
[ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php)
[DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
[Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar)

Tips: DF2K: DIV2K 和 Flickr2 数据的整合
训练数据处理请参考[BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md)

测试:
[Classical SR Testing](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing)


数据准备具体步骤如下:
Rayyyyy's avatar
Rayyyyy committed
81
82
83
84
85
86
87
88
89
1. 将数据存放在datasets目录下;

2. BSD100和urban100需要再各自目录下新GTmod4、LRbicx4两个新目录,并把原始图片存放进GTmod4目录下,然后再datasets目录下分别执行下面两条命令:
```bash
python gen_LRbicx4.py --file_name ./BSD100
python gen_LRbicx4.py --file_name ./urban100
```

3. 建数据集的目录结构如下:
Rayyyyy's avatar
Rayyyyy committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106

```
├── DF2K
│   ├── DF2K_HR # HR 数据
│   ├── DF2K_HR_sub # 生成的
│   ├── DF2K_bicx4 # train_LR_bicubic_X4 数据
│   ├── DF2K_bicx4_sub # 生成的
├── Set5
│   ├── GTmod12
│   ├── LRbicx2
│   ├── LRbicx3
│   ├── LRbicx4
├── Set14
│   ├── GTmod12
│   ├── LRbicx2
│   ├── LRbicx3
│   ├── LRbicx4
Rayyyyy's avatar
Rayyyyy committed
107
108
109
110
111
112
├── BSDS100
│   ├── GTmod4 # 原始图像
│   ├── LRbicx4
├── urban100
│   ├── GTmod4 # 原始图像
│   ├── LRbicx4
Rayyyyy's avatar
Rayyyyy committed
113
114
115
116
117
118
119
120
121
122
123
124
```

Tips: 项目提供了tiny_datasets用于快速上手学习, 如果实用tiny_datasets, 需要对下面的代码内的地址进行替换, 当前默认完整数据集的处理地址。

2. 因为 DF2K 数据集是 2K 分辨率的 (比如: 2048x1080), 而我们在训练的时候往往并不要那么大 (常见的是 128x128 或者 192x192 的训练patch). 因此我们可以先把2K的图片裁剪成有overlap的 480x480 的子图像块. 然后再由 dataloader 从这个 480x480 的子图像块中随机crop出 128x128 或者 192x192 的训练patch.

```bash
python extract_subimages.py # 将图片进行sub
```

3. 生成 meta_info_file
```bash
Rayyyyy's avatar
Rayyyyy committed
125
python generate_meta_info.py
Rayyyyy's avatar
Rayyyyy committed
126
127
128
```

## 训练
Rayyyyy's avatar
Rayyyyy committed
129
130
预训练模型下载地址:[Google Drive](https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1u2r4Lc2_EEeQqra2-w85Xg) (access code: qyrl)。

Rayyyyy's avatar
Rayyyyy committed
131
132
133
134
135
训练日志及weights保存在./experiments文件中

### 单机多卡

```bash
Rayyyyy's avatar
Rayyyyy committed
136
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
Rayyyyy's avatar
Rayyyyy committed
137
138
139
140
bash train.sh
```

### 多机多卡
Rayyyyy's avatar
Rayyyyy committed
141
142
143
使用多节点的情况下,需要将使用节点写入hostfile文件, 多节点每个节点一行, 例如: c1xxxxxx slots=4。

1. run_train_multi.sh中18行所需虚拟环境变量地址;
Rayyyyy's avatar
Rayyyyy committed
144
145
146
147
148

2. 修改single_process.sh中22行所需训练的yaml文件地址,如与默认一致,可不修改。

执行命令如下, 训练日志保存在logs文件夹下
```bash
Rayyyyy's avatar
Rayyyyy committed
149
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
Rayyyyy's avatar
Rayyyyy committed
150
bash run_train_multi.sh
Rayyyyy's avatar
Rayyyyy committed
151
152
153
154
155
```

## 推理
预训练模型下载地址:[Google Drive](https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1u2r4Lc2_EEeQqra2-w85Xg) (access code: qyrl)。

Rayyyyy's avatar
Rayyyyy committed
156
测试结果将保存到 ./results 路径下。options/test/HAT_SRx4_ImageNet-LR.yml 适用于不使用 ground truth image 的推理过程。
Rayyyyy's avatar
Rayyyyy committed
157
```bash
Rayyyyy's avatar
Rayyyyy committed
158
# 默认 HAT_SRx4_ImageNet-pretrain.yml
Rayyyyy's avatar
Rayyyyy committed
159
160
161
162
163
164
165
166
167
168
169
170
bash val.sh
```

## result
基于 Real_HAT_GAN_SRx4_sharper.pth 的测试结果展示

<div align=center>
    <img src="./doc/Visual_Results.png"/>
</div>

### 精度

Rayyyyy's avatar
Rayyyyy committed
171
172
173
174
| Model | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 |
| :------: | :------: | :------: | :------: |:------: | :------: | :------: |
| HAT | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 |
| HAT(our) | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 |
Rayyyyy's avatar
Rayyyyy committed
175
176
177
178
179
180
181
182
183
184
185
186
187

## 应用场景
### 算法类别
图像重建

### 热点应用行业
交通,公安,制造

## 源码仓库及问题反馈
- https://developer.hpccube.com/codes/modelzoo/hat_pytorch

## 参考资料
- https://github.com/XPixelGroup/HAT?tab=readme-ov-file