# HAT ## 论文 `HAT: Hybrid Attention Transformer for Image Restoration` - https://arxiv.org/abs/2309.05239 ## 模型结构 HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。
## 算法原理 HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互,更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。
## 环境配置 -v 路径、docker_name和imageID根据实际情况修改 ### Docker(方法一) ```bash docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash cd /your_code_path/hat_pytorch pip install wheel cython pip install -r requirements.txt python setup.py develop ``` ### Dockerfile(方法二) ```bash cd ./docker docker build --no-cache -t hat:latest . docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash cd /your_code_path/hat_pytorch pip install wheel cython pip install -r requirements.txt python setup.py develop ``` ### Anaconda(方法三) 1. 关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.sourcefind.cn/tool/ ```bash DTK软件栈:dtk24.04.1 python:3.10 torch:2.1 torchvision:0.16.0 ``` `Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应` 2. 其他非特殊库直接按照requirements.txt安装 ```bash pip install wheel cython pip install -r requirements.txt python setup.py develop ``` ## 数据集 - 训练: [ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) [DIV2K Train Data (HR images)](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip) [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) > 训练数据处理请参考[BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md) - 测试: `Classical SR Testing` 数据准备具体步骤如下: 1. 将数据存放在`datasets`目录下; 2. `BSD100` 和 `urban100` 需要再各自目录下新建 `GTmod4`、`LRbicx4` 两个新目录,并把原始图片存放进 `GTmod4` 目录下,然后在 `datasets` 目录下分别执行下面两条命令: ```bash python gen_LRbicx4.py --file_name ./BSD100 python gen_LRbicx4.py --file_name ./urban100 ``` 3. 建数据集的目录结构如下: `DF2K`:DIV2K和Flickr2的HR数据的整合 ``` ├── DF2K │ ├── DF2K_HR # HR 数据 │ ├── DF2K_HR_sub # 生成的 │ ├── DF2K_bicx4 # train_LR_bicubic_X4 数据 │ ├── DF2K_bicx4_sub # 生成的 ├── Set5 │ ├── GTmod12 │ ├── LRbicx2 │ ├── LRbicx3 │ ├── LRbicx4 ├── Set14 │ ├── GTmod12 │ ├── LRbicx2 │ ├── LRbicx3 │ ├── LRbicx4 ├── BSDS100 │ ├── GTmod4 # 原始图像 │ ├── LRbicx4 ├── urban100 │ ├── GTmod4 # 原始图像 │ ├── LRbicx4 ``` > 项目提供了`tiny_datasets`用于快速上手学习,如果使用`tiny_datasets`,需要对下面的代码内的地址进行替换,当前默认完整数据集的处理地址。 2. 因`DF2K`数据集是2K分辨率,而我们在训练的时候往往并不要那么大(常见的是128x128或者192x192的训练patch)。 因此我们可以先把2K的图片裁剪成有overlap的480x480的子图像块,然后再由`dataloader`从这个480x480的子图像块中随机crop出128x128或192x192的训练patch。 ```bash python extract_subimages.py # 将图片进行sub ``` 3. 生成`meta_info_file` ```bash python generate_meta_info.py ``` ## 训练 训练日志及权重保存在`experiments`文件中,预训练模型下载地址[预训练权重](#预训练权重)。 ### 单机多卡 ```bash # 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml bash train.sh ``` ### 多机多卡 使用多节点的情况下,需要将使用节点写入hostfile文件,多节点每个节点一行,例如: c1xxxxxx slots=4。 1. [run_train_multi.sh](/run_train_multi.sh)中`18行`所需虚拟环境变量地址; 2. 修改[single_process.sh](./single_process.sh)中22行所需训练的`yaml文件`地址,如与默认一致,可不修改。 执行命令如下,训练日志保存在logs文件夹下 ```bash # 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml bash run_train_multi.sh ``` ## 推理 预训练模型下载地址[预训练权重](#预训练权重),测试结果将保存到`./results`路径下。[HAT_SRx4_ImageNet-LR.yml](options/test/HAT_SRx4_ImageNet-LR.yml)适用于不使用`ground truth image`的推理过程。 ```bash # 默认 HAT_SRx4_ImageNet-pretrain.yml bash val.sh ``` ## result 基于`Real_HAT_GAN_SRx4_sharper.pth`的测试结果展示
### 精度 测试数据如下表中所示,使用的加速卡:Z100L。 | DEVICE | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 | | :------: | :------: | :------: | :------: |:------: | :------: | :------: | | Z100L | 20.8 | 102.4 | 33.1486 | 29.3587 | 25.4074 | 21.2687 | ## 应用场景 ### 算法类别 图像重建 ### 热点应用行业 交通,公安,制造 ## 预训练权重 预训练模型下载地址:[pretrained models](https://drive.google.com/drive/folders/1HpmReFfoUqUbnAOQ7rvOeNU3uf_m69w0) ## 源码仓库及问题反馈 - https://developer.sourcefind.cn/codes/modelzoo/hat_pytorch ## 参考资料 - https://github.com/XPixelGroup/HAT