Commit 97451368 authored by chenych's avatar chenych
Browse files

Update to dtk24.04.1

parent a7d973fa
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
- https://arxiv.org/abs/2309.05239 - https://arxiv.org/abs/2309.05239
## 模型结构 ## 模型结构
HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。 HAT包括三个部分,包括浅层特征提取、深层特征提取和图像重建。
<div align=center> <div align=center>
...@@ -11,6 +12,7 @@ HAT包括三个部分,包括浅层特征提取、深层特征提取和图像 ...@@ -11,6 +12,7 @@ HAT包括三个部分,包括浅层特征提取、深层特征提取和图像
</div> </div>
## 算法原理 ## 算法原理
HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互,更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。 HAT方法结合了通道注意力和基于窗口的自注意力方案,利用两者的互补优势。此外,引入了重叠的跨注意力模块来增强相邻窗口特征之间的交互,更好地聚合跨窗口信息。在训练阶段,HAT还采用了相同的任务预训练策略,以进一步挖掘模型的潜力进行进一步改进。得益于这些设计,HAT可以激活更多的像素进行重建,从而显著提高性能。
<div align=center> <div align=center>
...@@ -21,16 +23,19 @@ HAT方法结合了通道注意力和基于窗口的自注意力方案,利用 ...@@ -21,16 +23,19 @@ HAT方法结合了通道注意力和基于窗口的自注意力方案,利用
-v 路径、docker_name和imageID根据实际情况修改 -v 路径、docker_name和imageID根据实际情况修改
### Docker(方法一) ### Docker(方法一)
```bash ```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38 docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/hat_pytorch cd /your_code_path/hat_pytorch
pip install wheel cython
pip install -r requirements.txt pip install -r requirements.txt
python setup.py develop python setup.py develop
``` ```
### Dockerfile(方法二) ### Dockerfile(方法二)
```bash ```bash
cd ./docker cd ./docker
...@@ -38,28 +43,33 @@ docker build --no-cache -t hat:latest . ...@@ -38,28 +43,33 @@ docker build --no-cache -t hat:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/hat_pytorch cd /your_code_path/hat_pytorch
pip install wheel cython
pip install -r requirements.txt
python setup.py develop python setup.py develop
``` ```
### Anaconda(方法三) ### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
1. 关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
```bash ```bash
DTK软件栈:dtk23.10 DTK软件栈:dtk24.04.1
python:python3.8 python:3.10
torch:1.13.1 torch:2.1
torchvision:0.14.1 torchvision:0.16.0
``` ```
`Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应` `Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应`
2其他非特殊库直接按照requirements.txt安装 2. 其他非特殊库直接按照requirements.txt安装
```bash ```bash
pip install wheel cython
pip install -r requirements.txt pip install -r requirements.txt
python setup.py develop python setup.py develop
``` ```
## 数据集 ## 数据集
- 训练: - 训练:
[ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) [ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php)
...@@ -67,21 +77,23 @@ python setup.py develop ...@@ -67,21 +77,23 @@ python setup.py develop
[Flickr2K](http://113.200.138.88:18080/aidatasets/project-dependency/flickr2k/-/blob/master/Flickr2K.tar) [Flickr2K](http://113.200.138.88:18080/aidatasets/project-dependency/flickr2k/-/blob/master/Flickr2K.tar)
训练数据处理请参考[BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md) > 训练数据处理请参考[BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md)
- 测试: - 测试:
[Classical SR Testing](http://113.200.138.88:18080/aidatasets/project-dependency/classical-sr) [Classical SR Testing](http://113.200.138.88:18080/aidatasets/project-dependency/classical-sr)
数据准备具体步骤如下: 数据准备具体步骤如下:
1. 将数据存放在`datasets`目录下; 1. 将数据存放在`datasets`目录下;
2. `BSD100``urban100`需要再各自目录下新`GTmod4``LRbicx4`两个新目录,并把原始图片存放进`GTmod4`目录下,然后`datasets`目录下分别执行下面两条命令: 2. `BSD100``urban100` 需要再各自目录下新`GTmod4``LRbicx4` 两个新目录,并把原始图片存放进 `GTmod4` 目录下,然后`datasets` 目录下分别执行下面两条命令:
```bash ```bash
python gen_LRbicx4.py --file_name ./BSD100 python gen_LRbicx4.py --file_name ./BSD100
python gen_LRbicx4.py --file_name ./urban100 python gen_LRbicx4.py --file_name ./urban100
``` ```
3. 建数据集的目录结构如下: 3. 建数据集的目录结构如下:
`DF2K`:DIV2K和Flickr2的HR数据的整合 `DF2K`:DIV2K和Flickr2的HR数据的整合
``` ```
├── DF2K ├── DF2K
...@@ -107,29 +119,34 @@ python gen_LRbicx4.py --file_name ./urban100 ...@@ -107,29 +119,34 @@ python gen_LRbicx4.py --file_name ./urban100
│ ├── LRbicx4 │ ├── LRbicx4
``` ```
Tips: 项目提供了`tiny_datasets`用于快速上手学习,如果使用`tiny_datasets`,需要对下面的代码内的地址进行替换,当前默认完整数据集的处理地址。 > 项目提供了`tiny_datasets`用于快速上手学习,如果使用`tiny_datasets`,需要对下面的代码内的地址进行替换,当前默认完整数据集的处理地址。
2.`DF2K`数据集是2K分辨率,而我们在训练的时候往往并不要那么大(常见的是128x128或者192x192的训练patch)。因此我们可以先把2K的图片裁剪成有overlap的480x480的子图像块,然后再由`dataloader`从这个480x480的子图像块中随机crop出128x128或192x192的训练patch。 2.`DF2K`数据集是2K分辨率,而我们在训练的时候往往并不要那么大(常见的是128x128或者192x192的训练patch)。
因此我们可以先把2K的图片裁剪成有overlap的480x480的子图像块,然后再由`dataloader`从这个480x480的子图像块中随机crop出128x128或192x192的训练patch。
```bash ```bash
python extract_subimages.py # 将图片进行sub python extract_subimages.py # 将图片进行sub
``` ```
3. 生成`meta_info_file` 3. 生成`meta_info_file`
```bash ```bash
python generate_meta_info.py python generate_meta_info.py
``` ```
## 训练 ## 训练
训练日志及权重保存在`/experiments`文件中,预训练模型下载地址[预训练权重](#预训练权重)
训练日志及权重保存在`experiments`文件中,预训练模型下载地址[预训练权重](#预训练权重)
### 单机多卡 ### 单机多卡
```bash ```bash
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml # 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash train.sh bash train.sh
``` ```
### 多机多卡 ### 多机多卡
使用多节点的情况下,需要将使用节点写入hostfile文件,多节点每个节点一行,例如: c1xxxxxx slots=4。 使用多节点的情况下,需要将使用节点写入hostfile文件,多节点每个节点一行,例如: c1xxxxxx slots=4。
1. [run_train_multi.sh](/run_train_multi.sh)`18行`所需虚拟环境变量地址; 1. [run_train_multi.sh](/run_train_multi.sh)`18行`所需虚拟环境变量地址;
...@@ -137,12 +154,14 @@ bash train.sh ...@@ -137,12 +154,14 @@ bash train.sh
2. 修改[single_process.sh](./single_process.sh)中22行所需训练的`yaml文件`地址,如与默认一致,可不修改。 2. 修改[single_process.sh](./single_process.sh)中22行所需训练的`yaml文件`地址,如与默认一致,可不修改。
执行命令如下,训练日志保存在logs文件夹下 执行命令如下,训练日志保存在logs文件夹下
```bash ```bash
# 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml # 默认 train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml
bash run_train_multi.sh bash run_train_multi.sh
``` ```
## 推理 ## 推理
预训练模型下载地址[预训练权重](#预训练权重),测试结果将保存到`./results`路径下。[HAT_SRx4_ImageNet-LR.yml](options/test/HAT_SRx4_ImageNet-LR.yml)适用于不使用`ground truth image`的推理过程。 预训练模型下载地址[预训练权重](#预训练权重),测试结果将保存到`./results`路径下。[HAT_SRx4_ImageNet-LR.yml](options/test/HAT_SRx4_ImageNet-LR.yml)适用于不使用`ground truth image`的推理过程。
```bash ```bash
# 默认 HAT_SRx4_ImageNet-pretrain.yml # 默认 HAT_SRx4_ImageNet-pretrain.yml
...@@ -150,6 +169,7 @@ bash val.sh ...@@ -150,6 +169,7 @@ bash val.sh
``` ```
## result ## result
基于`Real_HAT_GAN_SRx4_sharper.pth`的测试结果展示 基于`Real_HAT_GAN_SRx4_sharper.pth`的测试结果展示
<div align=center> <div align=center>
...@@ -157,6 +177,7 @@ bash val.sh ...@@ -157,6 +177,7 @@ bash val.sh
</div> </div>
### 精度 ### 精度
测试数据如下表中所示,使用的加速卡:Z100L。 测试数据如下表中所示,使用的加速卡:Z100L。
| DEVICE | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 | | DEVICE | Params(M) | Multi-Adds(G) | Set5 | Set14 | BSD100 | Urban100 |
......
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk23.10-py38 FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
\ No newline at end of file \ No newline at end of file
...@@ -12,4 +12,4 @@ echo "Training start ..." ...@@ -12,4 +12,4 @@ echo "Training start ..."
# 9行 datasets: 请确认数据地址正确 # 9行 datasets: 请确认数据地址正确
# 76行 pretrain_network_g: 请确认预训练模型地址正确 # 76行 pretrain_network_g: 请确认预训练模型地址正确
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 hat/train.py -opt options/train/train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml --launcher pytorch torchrun --nproc_per_node=8 hat/train.py -opt options/train/train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml --launcher pytorch
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment