Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
HAT_pytorch
Commits
a7d973fa
Commit
a7d973fa
authored
Jul 09, 2024
by
Rayyyyy
Browse files
Add icon and SCNet, update dtk
parent
dd68669d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
44 deletions
+35
-44
README.md
README.md
+34
-40
docker/Dockerfile
docker/Dockerfile
+1
-4
icon.png
icon.png
+0
-0
No files found.
README.md
View file @
a7d973fa
# HAT
# HAT
## 论文
## 论文
[
HAT: Hybrid Attention Transformer for Image Restoration
](
https://arxiv.org/abs/2309.05239
)
`HAT: Hybrid Attention Transformer for Image Restoration`
-
https://arxiv.org/abs/2309.05239
## 模型结构
## 模型结构
HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。
HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。
...
@@ -10,7 +11,7 @@ HAT包括三个部分,包括浅层特征提取、深层特征提取和图像
...
@@ -10,7 +11,7 @@ HAT包括三个部分,包括浅层特征提取、深层特征提取和图像
</div>
</div>
## 算法原理
## 算法原理
HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互
,
更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。
HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互
,
更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。
<div
align=
center
>
<div
align=
center
>
<img
src=
"./doc/method.png"
/>
<img
src=
"./doc/method.png"
/>
...
@@ -20,11 +21,9 @@ HAT方法结合了通道注意力和基于窗口的自注意力方案,利用
...
@@ -20,11 +21,9 @@ HAT方法结合了通道注意力和基于窗口的自注意力方案,利用
-v 路径、docker_name和imageID根据实际情况修改
-v 路径、docker_name和imageID根据实际情况修改
### Docker(方法一)
### Docker(方法一)
```
bash
```
bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
docker run
-it
-v
/path/your_code_data/:/path/your_code_data/
-v
/opt/hyhal/:/opt/hyhal/:ro
--shm-size
=
32G
--privileged
=
true
--device
=
/dev/kfd
--device
=
/dev/dri/
--group-add
video
--name
docker_name imageID bash
docker run
-it
-v
/path/your_code_data/:/path/your_code_data/
-v
/opt/hyhal/:/opt/hyhal/
--shm-size
=
32G
--privileged
=
true
--device
=
/dev/kfd
--device
=
/dev/dri/
--group-add
video
--name
docker_name imageID bash
cd
/your_code_path/hat_pytorch
cd
/your_code_path/hat_pytorch
pip
install
-r
requirements.txt
pip
install
-r
requirements.txt
...
@@ -32,20 +31,17 @@ python setup.py develop
...
@@ -32,20 +31,17 @@ python setup.py develop
```
```
### Dockerfile(方法二)
### Dockerfile(方法二)
```
bash
```
bash
cd
./docker
cd
./docker
cp
../requirements.txt requirements.txt
docker build
--no-cache
-t
hat:latest
.
docker build
--no-cache
-t
hat:latest
.
docker run
-it
-v
/path/your_code_data/:/path/your_code_data/
-v
/opt/hyhal/:/opt/hyhal/
--shm-size
=
32G
--privileged
=
true
--device
=
/dev/kfd
--device
=
/dev/dri/
--group-add
video
--name
docker_name imageID bash
docker run
-it
-v
/path/your_code_data/:/path/your_code_data/
-v
/opt/hyhal/:/opt/hyhal/
:ro
--shm-size
=
32G
--privileged
=
true
--device
=
/dev/kfd
--device
=
/dev/dri/
--group-add
video
--name
docker_name imageID bash
cd
/your_code_path/hat_pytorch
cd
/your_code_path/hat_pytorch
python setup.py develop
python setup.py develop
```
```
### Anaconda(方法三)
### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
```
bash
```
bash
...
@@ -55,39 +51,38 @@ torch:1.13.1
...
@@ -55,39 +51,38 @@ torch:1.13.1
torchvision:0.14.1
torchvision:0.14.1
```
```
Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应
`
Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应
`
2、其他非特殊库直接按照requirements.txt安装
2、其他非特殊库直接按照requirements.txt安装
```
bash
```
pip
install
-r
requirements.txt
pip
install
-r
requirements.txt
python setup.py develop
python setup.py develop
```
```
## 数据集
## 数据集
训练:
-
训练:
[
ImageNet dataset
](
https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
)
[
ImageNet dataset
](
https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
)
[
DIV2K
](
https://data.vision.ee.ethz.ch/cvl/DIV2K/
)
[
Flickr2K
](
https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
)
Tips: DF2K: DIV2K 和 Flickr2 数据的整合
[
DIV2K Train Data (HR images)
](
http://113.200.138.88:18080/aidatasets/project-dependency/div2k/-/blob/master/DIV2K_train_HR.zip
)
训练数据处理请参考
[
BasicSR
](
https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md
)
测试:
[
Flickr2K
](
http://113.200.138.88:18080/aidatasets/project-dependency/flickr2k/-/blob/master/Flickr2K.tar
)
[
Classical SR Testing
](
https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing
)
训练数据处理请参考
[
BasicSR
](
https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md
)
-
测试:
[
Classical SR Testing
](
http://113.200.138.88:18080/aidatasets/project-dependency/classical-sr
)
数据准备具体步骤如下:
数据准备具体步骤如下:
1.
将数据存放在datasets目录下;
1.
将数据存放在
`
datasets
`
目录下;
2.
BSD100
和
urban100需要再各自目录下新GTmod4
、
LRbicx4两个新目录,并把原始图片存放进GTmod4目录下,然后再datasets目录下分别执行下面两条命令:
2.
`
BSD100
`
和
`
urban100
`
需要再各自目录下新
`
GTmod4
`
、
`
LRbicx4
`
两个新目录,并把原始图片存放进
`
GTmod4
`
目录下,然后再
`
datasets
`
目录下分别执行下面两条命令:
```
bash
```
bash
python gen_LRbicx4.py
--file_name
./BSD100
python gen_LRbicx4.py
--file_name
./BSD100
python gen_LRbicx4.py
--file_name
./urban100
python gen_LRbicx4.py
--file_name
./urban100
```
```
3.
建数据集的目录结构如下:
3.
建数据集的目录结构如下:
`DF2K`
:DIV2K和Flickr2的HR数据的整合
```
```
├── DF2K
├── DF2K
│ ├── DF2K_HR # HR 数据
│ ├── DF2K_HR # HR 数据
...
@@ -112,63 +107,59 @@ python gen_LRbicx4.py --file_name ./urban100
...
@@ -112,63 +107,59 @@ python gen_LRbicx4.py --file_name ./urban100
│ ├── LRbicx4
│ ├── LRbicx4
```
```
Tips: 项目提供了tiny_datasets用于快速上手学习
, 如果实用
tiny_datasets
,
需要对下面的代码内的地址进行替换
,
当前默认完整数据集的处理地址。
Tips: 项目提供了
`
tiny_datasets
`
用于快速上手学习
,如果使用
`
tiny_datasets
`
,
需要对下面的代码内的地址进行替换
,
当前默认完整数据集的处理地址。
2.
因
为
DF2K
数据集是
2K
分辨率
的 (比如: 2048x1080),
而我们在训练的时候往往并不要那么大
(常见的是
128x128
或者
192x192
的训练patch)
.
因此我们可以先把2K的图片裁剪成有overlap的
480x480
的子图像块
.
然后再由
dataloader
从这个
480x480
的子图像块中随机crop出
128x128
或者
192x192
的训练patch
.
2.
因
`
DF2K
`
数据集是2K分辨率
,
而我们在训练的时候往往并不要那么大(常见的是128x128或者192x192的训练patch)
。
因此我们可以先把2K的图片裁剪成有overlap的480x480的子图像块
,
然后再由
`
dataloader
`
从这个480x480的子图像块中随机crop出128x128
或
192x192的训练patch
。
```
bash
```
bash
python extract_subimages.py
# 将图片进行sub
python extract_subimages.py
# 将图片进行sub
```
```
3.
生成
meta_info_file
3.
生成
`
meta_info_file
`
```
bash
```
bash
python generate_meta_info.py
python generate_meta_info.py
```
```
## 训练
## 训练
预训练模型下载地址:
[
Google Drive
](
https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0?usp=sharing
)
or
[
百度网盘
](
https://pan.baidu.com/s/1u2r4Lc2_EEeQqra2-w85Xg
)
(
access
code: qyrl)。
训练日志及权重保存在
`/experiments`
文件中,预训练模型下载地址
[
预训练权重
](
#预训练权重
)
。
训练日志及weights保存在./experiments文件中
### 单机多卡
### 单机多卡
```
bash
```
bash
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash train.sh
bash train.sh
```
```
### 多机多卡
### 多机多卡
使用多节点的情况下,需要将使用节点写入hostfile文件
,
多节点每个节点一行
,
例如: c1xxxxxx slots=4。
使用多节点的情况下,需要将使用节点写入hostfile文件
,
多节点每个节点一行
,
例如: c1xxxxxx slots=4。
1.
run_train_multi.sh
中
18行所需虚拟环境变量地址;
1.
[
run_train_multi.sh
](
/run_train_multi.sh
)
中
`
18行
`
所需虚拟环境变量地址;
2.
修改single_process.sh中22行所需训练的yaml文件地址,如与默认一致,可不修改。
2.
修改
[
single_process.sh
](
./single_process.sh
)
中22行所需训练的
`
yaml文件
`
地址,如与默认一致,可不修改。
执行命令如下
,
训练日志保存在logs文件夹下
执行命令如下
,
训练日志保存在logs文件夹下
```
bash
```
bash
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash run_train_multi.sh
bash run_train_multi.sh
```
```
## 推理
## 推理
预训练模型下载地址:
[
Google Drive
](
https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0?usp=sharing
)
or
[
百度网盘
](
https://pan.baidu.com/s/1u2r4Lc2_EEeQqra2-w85Xg
)
(
access
code: qyrl)。
预训练模型下载地址
[
预训练权重
](
#预训练权重
)
,测试结果将保存到
`./results`
路径下。
[
HAT_SRx4_ImageNet-LR.yml
](
options/test/HAT_SRx4_ImageNet-LR.yml
)
适用于不使用
`ground truth image`
的推理过程。
测试结果将保存到 ./results 路径下。options/test/HAT_SRx4_ImageNet-LR.yml 适用于不使用 ground truth image 的推理过程。
```
bash
```
bash
# 默认 HAT_SRx4_ImageNet-pretrain.yml
# 默认 HAT_SRx4_ImageNet-pretrain.yml
bash val.sh
bash val.sh
```
```
## result
## result
基于
Real_HAT_GAN_SRx4_sharper.pth
的测试结果展示
基于
`
Real_HAT_GAN_SRx4_sharper.pth
`
的测试结果展示
<div
align=
center
>
<div
align=
center
>
<img
src=
"./doc/Visual_Results.png"
/>
<img
src=
"./doc/Visual_Results.png"
/>
</div>
</div>
### 精度
### 精度
HAT
测试数据如下表中所示,使用的加速卡:Z100L。
| Model | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 |
| DEVICE | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 |
| :------: | :------: | :------: | :------: |:------: | :------: | :------: |
| :------: | :------: | :------: | :------: |:------: | :------: | :------: |
| Z100L | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 |
| Z100L | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 |
...
@@ -179,8 +170,11 @@ HAT
...
@@ -179,8 +170,11 @@ HAT
### 热点应用行业
### 热点应用行业
交通,公安,制造
交通,公安,制造
## 预训练权重
预训练模型下载地址:
[
SCNet-AIModels
](
http://113.200.138.88:18080/aimodels/findsource-dependency/hat_pytorch
)
## 源码仓库及问题反馈
## 源码仓库及问题反馈
-
https://developer.hpccube.com/codes/modelzoo/hat_pytorch
-
https://developer.hpccube.com/codes/modelzoo/hat_pytorch
## 参考资料
## 参考资料
-
https://github.com/XPixelGroup/HAT
?tab=readme-ov-file
-
https://github.com/XPixelGroup/HAT
docker/Dockerfile
View file @
a7d973fa
FROM
image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
FROM
image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38
RUN
source
/opt/dtk/env.sh
\ No newline at end of file
COPY
requirements.txt requirements.txt
RUN
pip3
install
-r
requirements.txt
\ No newline at end of file
icon.png
0 → 100644
View file @
a7d973fa
59.3 KB
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment