# HAT
## Paper
`HAT: Hybrid Attention Transformer for Image Restoration`
- https://arxiv.org/abs/2309.05239

## Model Structure

HAT consists of three parts: shallow feature extraction, deep feature extraction, and image reconstruction.
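This three-stage layout can be summarized in a minimal PyTorch-style sketch (hypothetical class and module names, not the repository's actual code; plain convolution blocks stand in for the hybrid attention groups of the real model):

```python
import torch
import torch.nn as nn


class TinyHATLikeSR(nn.Module):
    """Minimal sketch of the HAT layout: shallow feature extraction,
    deep feature extraction, and image reconstruction (upsampling)."""

    def __init__(self, dim: int = 64, scale: int = 4, num_blocks: int = 4):
        super().__init__()
        # 1) shallow feature extraction: a single 3x3 convolution
        self.shallow = nn.Conv2d(3, dim, 3, padding=1)
        # 2) deep feature extraction: simple conv blocks stand in for the
        #    hybrid attention groups of the real model
        self.deep = nn.Sequential(*[
            nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.GELU(),
                          nn.Conv2d(dim, dim, 3, padding=1))
            for _ in range(num_blocks)
        ])
        # 3) image reconstruction: pixel-shuffle upsampling to the HR resolution
        self.reconstruct = nn.Sequential(
            nn.Conv2d(dim, 3 * scale ** 2, 3, padding=1),
            nn.PixelShuffle(scale),
        )

    def forward(self, lr: torch.Tensor) -> torch.Tensor:
        feat = self.shallow(lr)
        feat = feat + self.deep(feat)   # global residual connection
        return self.reconstruct(feat)


if __name__ == "__main__":
    sr = TinyHATLikeSR()(torch.randn(1, 3, 64, 64))
    print(sr.shape)  # torch.Size([1, 3, 256, 256])
```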

<div align=center>
    <img src="./doc/model.png"/>
</div>

## Method

HAT combines channel attention and window-based self-attention to exploit their complementary strengths. It also introduces an overlapping cross-attention module that strengthens the interaction between neighboring window features and aggregates cross-window information more effectively. During training, HAT additionally adopts a same-task pre-training strategy to further exploit the model's potential. Thanks to these designs, HAT activates more pixels for reconstruction and achieves significantly better performance.
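As a rough illustration of how the two attention schemes can be combined inside one block, here is a heavily simplified sketch (hypothetical names; the real hybrid attention block additionally uses shifted windows, relative position bias, an MLP sub-layer, and the overlapping cross-attention module mentioned above):

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention."""

    def __init__(self, dim: int, squeeze: int = 16):
        super().__init__()
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, dim // squeeze, 1), nn.ReLU(inplace=True),
            nn.Conv2d(dim // squeeze, dim, 1), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, C, H, W)
        return x * self.attn(x)


class HybridAttentionSketch(nn.Module):
    """Window self-attention plus a parallel conv/channel-attention branch,
    merged into the residual stream with a small weight alpha."""

    def __init__(self, dim: int = 64, window: int = 8, heads: int = 4, alpha: float = 0.01):
        super().__init__()
        self.window, self.alpha = window, alpha
        self.norm = nn.LayerNorm(dim)
        self.wsa = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cab = nn.Sequential(                      # conv block + channel attention
            nn.Conv2d(dim, dim, 3, padding=1), nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1), ChannelAttention(dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W), with H and W divisible by the window size
        B, C, H, W = x.shape
        w = self.window
        # partition into non-overlapping w x w windows -> (B * num_windows, w*w, C)
        t = x.reshape(B, C, H // w, w, W // w, w)
        t = t.permute(0, 2, 4, 3, 5, 1).reshape(-1, w * w, C)
        t = self.norm(t)
        attn, _ = self.wsa(t, t, t)                    # self-attention inside each window
        # merge the windows back to (B, C, H, W)
        attn = attn.reshape(B, H // w, W // w, w, w, C)
        attn = attn.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
        # channel-attention branch on the same input, weighted by alpha
        return x + attn + self.alpha * self.cab(x)


if __name__ == "__main__":
    y = HybridAttentionSketch()(torch.randn(1, 64, 32, 32))
    print(y.shape)  # torch.Size([1, 64, 32, 32])
```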

<div align=center>
    <img src="./doc/method.png"/>
</div>

## Environment Setup
Adjust the `-v` mount paths, `docker_name`, and `imageID` in the commands below to match your actual environment.

### Docker (Option 1)

```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash

cd /your_code_path/hat_pytorch
pip install wheel cython
pip install -r requirements.txt
python setup.py develop
```

### Dockerfile (Option 2)

```bash
cd ./docker

docker build --no-cache -t hat:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash

cd /your_code_path/hat_pytorch
pip install wheel cython
pip install -r requirements.txt
python setup.py develop
```

### Anaconda (Option 3)

1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community: https://developer.sourcefind.cn/tool/

```bash
DTK software stack: dtk24.04.1
python: 3.10
torch: 2.1
torchvision: 0.16.0
```

`Tip: the versions of the DTK software stack, python, torch, and the other DCU-related tools listed above must match each other exactly.`

2. The remaining, non-DCU-specific libraries can be installed directly from requirements.txt:
```bash
pip install wheel cython
pip install -r requirements.txt
python setup.py develop
```

## Datasets

- Training:
[ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php)

[DIV2K Train Data (HR images)](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip)

[Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar)

> For training data preparation, please refer to [BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md).

- Testing:
`Classical SR Testing`

The data preparation steps are as follows:

1. Place the data in the `datasets` directory.

2. For `BSD100` and `urban100`, create two new directories, `GTmod4` and `LRbicx4`, under each dataset directory, put the original images into `GTmod4`, and then run the following two commands from the `datasets` directory (a sketch of the underlying downscaling step follows the commands):
```bash
python gen_LRbicx4.py --file_name ./BSD100
python gen_LRbicx4.py --file_name ./urban100
```
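Conceptually, generating `LRbicx4` from `GTmod4` is just a bicubic 4x downscale of each ground-truth image after cropping it to a multiple of 4. Below is a minimal Pillow sketch of that idea (hypothetical helper; the provided `gen_LRbicx4.py` may use a different bicubic implementation, so prefer the script for actual evaluation):

```python
from pathlib import Path

from PIL import Image


def make_lrbicx4(dataset_dir: str) -> None:
    """Bicubic-downscale every image in <dataset>/GTmod4 by 4x into <dataset>/LRbicx4."""
    gt_dir = Path(dataset_dir) / "GTmod4"
    lr_dir = Path(dataset_dir) / "LRbicx4"
    lr_dir.mkdir(exist_ok=True)
    for path in sorted(gt_dir.glob("*.png")):
        img = Image.open(path).convert("RGB")
        w, h = img.size
        img = img.crop((0, 0, w - w % 4, h - h % 4))   # make width/height divisible by 4
        lr = img.resize((img.width // 4, img.height // 4), Image.BICUBIC)
        lr.save(lr_dir / path.name)


make_lrbicx4("./BSD100")
make_lrbicx4("./urban100")
```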

3. The dataset directory structure is as follows:

`DF2K`: the combined HR data of DIV2K and Flickr2K.
```
├── DF2K
│   ├── DF2K_HR # HR data
│   ├── DF2K_HR_sub # generated
│   ├── DF2K_bicx4 # train_LR_bicubic_X4 data
│   ├── DF2K_bicx4_sub # generated
├── Set5
│   ├── GTmod12
│   ├── LRbicx2
│   ├── LRbicx3
│   ├── LRbicx4
├── Set14
│   ├── GTmod12
│   ├── LRbicx2
│   ├── LRbicx3
│   ├── LRbicx4
├── BSDS100
│   ├── GTmod4 # original images
│   ├── LRbicx4
├── urban100
│   ├── GTmod4 # original images
│   ├── LRbicx4
```

> The project provides `tiny_datasets` for a quick start. If you use `tiny_datasets`, replace the paths used in the code below; the defaults point to the full datasets.

4. The `DF2K` images are 2K resolution, but training usually does not need patches that large (128x128 or 192x192 training patches are common). We therefore first crop the 2K images into overlapping 480x480 sub-images, from which the `dataloader` then randomly crops 128x128 or 192x192 training patches (see the sketch after the command below for the cropping logic):

```bash
python extract_subimages.py # crop the images into sub-images
```
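Conceptually, the sub-image extraction slides a 480x480 window over each HR image with a step smaller than the window size, so neighboring tiles overlap. A minimal sketch is shown below (crop size, step, naming scheme, and input/output paths are assumptions; `extract_subimages.py` in the repository is the authoritative version):

```python
from pathlib import Path

from PIL import Image


def extract_subimages(src_dir: str, dst_dir: str, crop: int = 480, step: int = 240) -> None:
    """Cut each HR image into overlapping crop x crop tiles, sliding by `step` pixels."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in sorted(Path(src_dir).glob("*.png")):
        img = Image.open(path)
        w, h = img.size                       # assumes w >= crop and h >= crop
        xs = list(range(0, w - crop + 1, step))
        ys = list(range(0, h - crop + 1, step))
        if xs[-1] != w - crop:
            xs.append(w - crop)               # cover the right border
        if ys[-1] != h - crop:
            ys.append(h - crop)               # cover the bottom border
        for i, y in enumerate(ys):
            for j, x in enumerate(xs):
                tile = img.crop((x, y, x + crop, y + crop))
                tile.save(dst / f"{path.stem}_s{i:02d}{j:02d}.png")


extract_subimages("./DF2K/DF2K_HR", "./DF2K/DF2K_HR_sub")
```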

5. Generate the `meta_info_file` (a sketch of the expected contents follows the command):

```bash
python generate_meta_info.py
```
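The meta info file is a plain-text index of the training sub-images that the `dataloader` reads. A hedged sketch of producing one in the BasicSR-style `name (h,w,c)` format is shown below; check `generate_meta_info.py` for the exact format and output path expected by the training config:

```python
from pathlib import Path

from PIL import Image


def generate_meta_info(img_dir: str, meta_path: str) -> None:
    """Write one line per sub-image in the form '<file name> (<h>,<w>,<channels>)'."""
    lines = []
    for path in sorted(Path(img_dir).glob("*.png")):
        with Image.open(path) as img:
            w, h = img.size
            channels = len(img.getbands())
        lines.append(f"{path.name} ({h},{w},{channels})")
    Path(meta_path).write_text("\n".join(lines) + "\n")


generate_meta_info("./DF2K/DF2K_HR_sub", "./meta_info_DF2K_HR_sub.txt")
```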

## Training

Training logs and weights are saved in the `experiments` directory. The pre-trained models can be downloaded from [Pre-trained Weights](#pre-trained-weights).

### Single-node, multi-card

```bash
# default: train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash train.sh
```

### Multi-node, multi-card

When using multiple nodes, write the nodes to be used into the hostfile, one node per line, for example: `c1xxxxxx slots=4`.

1. Modify `line 18` of [run_train_multi.sh](/run_train_multi.sh) to the path of the required virtual environment;

2. Modify the path of the training `yaml` file on line 22 of [single_process.sh](./single_process.sh); if it matches the default, no change is needed.

Run the following command; training logs are saved in the `logs` directory.

```bash
# default: train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash run_train_multi.sh
```

## Inference

The pre-trained models can be downloaded from [Pre-trained Weights](#pre-trained-weights); test results are saved under `./results`. [HAT_SRx4_ImageNet-LR.yml](options/test/HAT_SRx4_ImageNet-LR.yml) is intended for inference without `ground truth` images.
```bash
# default: HAT_SRx4_ImageNet-pretrain.yml
bash val.sh
```

## Results

Sample test results based on `Real_HAT_GAN_SRx4_sharper.pth`:

<div align=center>
    <img src="./doc/Visual_Results.png"/>
</div>

### Accuracy

The test results are shown in the table below; accelerator card used: Z100L.

| DEVICE | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 |
| :------: | :------: | :------: | :------: |:------: | :------: | :------: |
| Z100L | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 |

## Application Scenarios
### Algorithm Category
Image reconstruction

### Key Application Industries
Transportation, public security, manufacturing

## Pre-trained Weights
Pre-trained model download: [pretrained models](https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0)

## Source Repository and Issue Feedback
- https://developer.sourcefind.cn/codes/modelzoo/hat_pytorch

## References
- https://github.com/XPixelGroup/HAT