Commit 8f9687f5 authored by mashun1

ridcp

FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest
COPY . /tmp/
WORKDIR /tmp
# RUN mkdir -p /root/.cache/torch/hub/checkpoints/
RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
# RUN cp -r ./extra_models/* /root/.cache/torch/hub/checkpoints/
## creative commons
# Attribution-NonCommercial 4.0 International
Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
### Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
## Creative Commons Attribution-NonCommercial 4.0 International Public License
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
### Section 1 – Definitions.
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
### Section 2 – Scope.
a. ___License grant.___
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
3. __Term.__ The term of this Public License is specified in Section 6(a).
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
5. __Downstream recipients.__
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
b. ___Other rights.___
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this Public License.
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
### Section 3 – License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
a. ___Attribution.___
1. If You Share the Licensed Material (including in modified form), You must:
A. retain the following if it is supplied by the Licensor with the Licensed Material:
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of warranties;
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
### Section 4 – Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
### Section 6 – Term and Termination.
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
### Section 7 – Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
### Section 8 – Interpretation.
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
>
> Creative Commons may be contacted at creativecommons.org
Copyright (c) 2022 MCG-NKU
# RIDCP
## Paper
**RIDCP: Revitalizing Real Image Dehazing via High-Quality Codebook Priors**
* https://arxiv.org/pdf/2304.03994.pdf
## Model Architecture
RIDCP first uses the Encoder to extract a feature Z from the input. Z flows into both $`G_{vq}`$ and G. The Z that flows into $`G_{vq}`$ is replaced (via the CHM method) with the codebook feature $`Z_{vq}`$; $`Z_{vq}`$ then passes through $`G_{vq}`$ to yield the per-layer features $`F^i_{vq}`$. The features $`F^i`$ produced by G are fused (via the NFA method) with $`F^i_{vq}`$ to form new features $`F^i`$, layer by layer, until the last layer of G produces the target output.
![img.png](imgs/img.png)
![img_2.png](imgs/img_2.png)
## Algorithm Principle
Purpose: the algorithm performs image dehazing by effectively exploiting prior knowledge (a codebook).
Principle:
1. Fuse features (NFA) to eliminate the image distortion caused by using codebook features alone;
2. Use CHM (a distance computation scheme) to match better codebook features.
**NFA**
![Alt text](imgs/nfa1.png)
![Alt text](imgs/nfa2.png)
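In the code below (`MultiScaleDecoder.forward`), this fusion appears as a mean-rescaled addition of the warped codebook-decoder feature $`x_{vq}`$ onto the current decoder feature $`x`$:

```math
x \leftarrow x + x_{vq} \cdot \frac{\operatorname{mean}(x)}{\operatorname{mean}(x_{vq})}
```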
**CHM**
![Alt text](imgs/chm1.png)
![Alt text](imgs/chm2.png)
![Alt text](imgs/chm3.png)
Here $`\hat{z}`$ denotes the feature extracted by the Encoder, $`z_k`$ denotes the k-th codebook feature, $`\alpha`$ has an empirically best value of 21.25, and $`f^k_h`$ and $`f^k_c`$ are statistics.
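In code terms (see the `VectorQuantizer` implementation further below), CHM re-weights the nearest-neighbour distance before taking the argmin; a minimal sketch of that matching step, with $`w_k`$ standing for the pre-computed weight loaded from `pretrained_models/weight_for_matching_dehazing_Flickr.pth` and derived from the statistics above, is:

```math
k^{*} = \arg\min_k \; \lVert \hat{z} - z_k \rVert^2 \cdot e^{\alpha\, w_k}
```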
## Environment Setup
### Docker (Method 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest
docker run --shm-size 10g --network=host --name=ridcp --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to this project>:/home/ -it <your IMAGE ID> bash
pip install -r requirements.txt
BASICSR_EXT=True python setup.py develop
### Dockerfile (Method 2)
# Run in the directory containing the Dockerfile
docker build -t <IMAGE_NAME>:<TAG> .
# Replace <your IMAGE ID> with the ID of the image obtained above
docker run -it --shm-size 10g --network=host --name=ridcp --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined <your IMAGE ID> bash
BASICSR_EXT=True python setup.py develop
### Anaconda (Method 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community:
https://developer.hpccube.com/tool/
DTK driver: dtk23.04
python: python3.8
torch: 1.13.1
torchvision: 0.14.1
torchaudio: 0.13.1
deepspeed: 0.9.2
apex: 0.1
Tip: the DTK driver, python, torch, and other DCU-related tool versions above must match each other exactly.
2. Install the other (non-special) libraries according to requirements.txt:
pip install -r requirements.txt
3. Other:
BASICSR_EXT=True python setup.py develop
Tip: use the Aliyun pip mirror.
## Dataset
Training set download: https://pan.baidu.com/s/1oX3AZkVlEa7S1sSO12r47Q (extraction code: qqqo)
datasets
|- clear_images_no_haze_no_dark_500
|- xxx.jpg
|- ...
|- depth_500
|- xxx.npy
|- ...
Test set download: https://sites.google.com/view/reside-dehaze-datasets/reside-%CE%B2
RTTS
|- Annotations
|- xxx.xml
|- ...
|- ImageSets
|- Main
|- xxx.txt
|- ...
|- JPEGImages
|- xxx.png
|- ...
## Training
Pretrained model download: https://pan.baidu.com/s/1ps9dPmerWyXILxb6lkHihQ (HQPs, CHM; extraction code: huea)
# X denotes a device ID, e.g. 0,1,2,3
CUDA_VISIBLE_DEVICES=X,X,X,X python basicsr/train.py -opt options/RIDCP.yml
## Inference
Model download: https://pan.baidu.com/s/1ps9dPmerWyXILxb6lkHihQ (RIDCP; extraction code: huea)
python inference_ridcp.py -i examples -w pretrained_models/pretrained_RIDCP.pth -o results --use_weight --alpha -21.25
Arguments (a programmatic sketch follows this list):
1. -i <path>: input image or folder
2. -w <path>: model weights
3. -o <path>: output directory for the results
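For use from Python rather than the command line, the following minimal sketch mirrors what `inference_ridcp.py` does, built only on the `build_network`/`VQWeightDehazeNet` code included in this repository. The option values (`codebook_params`, `weight_alpha`) and the checkpoint key `'params'` are assumptions; take the real values from `options/RIDCP.yml` and the released checkpoint.

```python
# Minimal programmatic-inference sketch (not the official script).
# Assumptions: option values are illustrative; pretrained_models/weight_for_matching_dehazing_Flickr.pth
# must exist because use_weight=True makes VectorQuantizer load it.
import cv2
import torch
from basicsr.archs import build_network

opt = {'type': 'VQWeightDehazeNet', 'codebook_params': [[64, 1024, 512]],
       'LQ_stage': True, 'use_weight': True, 'weight_alpha': -21.25}
net = build_network(opt).cuda()
net.load_state_dict(torch.load('pretrained_models/pretrained_RIDCP.pth')['params'], strict=False)
net.eval()

img = cv2.imread('examples/haze.jpg')[:, :, ::-1] / 255.0              # BGR -> RGB, [0, 1]
img = torch.FloatTensor(img).permute(2, 0, 1).unsqueeze(0).cuda()
with torch.no_grad():
    dehazed, _ = net.test(img)                                          # dehazed image, code indices

out = (dehazed[0].clamp(0, 1).permute(1, 2, 0).cpu().numpy()[:, :, ::-1] * 255).round()
cv2.imwrite('results/dehazed.png', out.astype('uint8'))
```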
## Results
The upper image is the input; the lower image is the output.
![img_6.png](imgs/img_6.png)
### Accuracy
|platform|FADE|BRISQUE|NIMA|
|:---|:---:|:---:|:---:|
|DCU|0.845|17.295|5.0|
|GPU|0.944|18.782|4.4267|
Note: lower FADE and BRISQUE values are better; higher NIMA values are better.
## Application Scenarios
### Algorithm Category
`Image dehazing`
### Key Application Industries
`Transportation, healthcare, environmental protection, meteorology`
## Source Repository and Issue Reporting
https://developer.hpccube.com/codes/modelzoo/ridcp_pytorch
## References
* https://github.com/RQ-Wu/RIDCP_dehazing
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .models import *
from .test import *
from .train import *
from .utils import *
from basicsr.archs.femasr_arch import FeMaSRNet
from basicsr.archs.dehaze_vq_warp_arch import VQWarpDehazeNet
from basicsr.archs import build_network
from collections import Counter
import torch
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
haze_path = '../dataset/RTTS/JPEGImages/'
clear_path = '../dataset/OTS/clear_images_no_haze_no_dark/'
hq_opt = {
'gt_resolution': 256,
'norm_type': 'gn',
'act_type': 'silu',
'scale_factor': 1,
'codebook_params': [[32, 1024, 512]],
'LQ_stage': True,
'use_quantilize': True,
'use_semantic_loss': False
}
dict_total = []
if __name__ == '__main__':
ckpt_path = 'experiments/vq_warp_dehaze_residual/models/net_g_19000.pth'
# print(hq_opt['codebook_params'].size())
net_vq = VQWarpDehazeNet(**hq_opt).cuda()
net_vq.load_state_dict(torch.load(ckpt_path)['params'])
i = 0
for filename in os.listdir(haze_path):
filepath = os.path.join(haze_path, filename)
image = cv2.imread(filepath)[:, :, ::-1] / 255.0
image = torch.FloatTensor(image).unsqueeze(0).cuda().permute(0, 3, 1, 2)
_, index_list = net_vq.test(image)
dict_total += list(index_list[0].flatten(0).cpu().numpy())
        i += 1
        print(i)
if i > 100:
break
result = Counter(dict_total)
result_image = np.zeros((32, 32))
print(len(result))
    for k in result.keys():
        k_index = int(k)
        result_image[k_index // 32, k_index % 32] = result[k]
# result_image = (result_image - result_image.min()) / (result_image.max() - result_image.min())
# result_image = np.log(result_image + 1)
plt.imshow(result_image)
plt.savefig('visual_code_dehaze_exp.png')
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY
__all__ = ['build_network']
# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
def build_network(opt):
opt = deepcopy(opt)
network_type = opt.pop('type')
net = ARCH_REGISTRY.get(network_type)(**opt)
logger = get_root_logger()
logger.info(f'Network [{net.__class__.__name__}] is created.')
return net
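# Usage sketch (illustrative): any class decorated with @ARCH_REGISTRY.register() in a
# *_arch.py file can be built from a plain option dict, e.g.
#   net = build_network({'type': 'VQWeightDehazeNet', 'codebook_params': [[64, 1024, 512]],
#                        'LQ_stage': True})
# 'type' selects the registered class; the remaining keys are passed as keyword arguments.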
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
class ResidualBlockNoBN(nn.Module):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
num_feat (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
"""
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.relu = nn.ReLU(inplace=True)
if not pytorch_init:
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
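# Example (sketch): stack several ResidualBlockNoBN blocks with make_layer.
#   body = make_layer(ResidualBlockNoBN, 4, num_feat=64)
#   y = body(torch.randn(1, 64, 32, 32))   # output has the same shape as the input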
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
"""Warp an image or feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
padding_mode (str): 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Before pytorch 1.3, the default value is
align_corners=True. After pytorch 1.3, the default value is
align_corners=False. Here, we use the True as default.
Returns:
Tensor: Warped image or feature map.
"""
assert x.size()[-2:] == flow.size()[1:3]
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
# TODO, what if align_corners=False
return output
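# Example (sketch): warping with an all-zero flow returns the input essentially unchanged.
#   feat = torch.randn(1, 64, 32, 32)
#   flow = torch.zeros(1, 32, 32, 2)        # (n, h, w, 2) displacement in pixels
#   warped = flow_warp(feat, flow)          # same shape as feat, values unchanged up to interpolation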
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
"""Resize a flow according to ratio or shape.
Args:
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
size_type (str): 'ratio' or 'shape'.
sizes (list[int | float]): the ratio for resizing or the final output
shape.
1) The order of ratio should be [ratio_h, ratio_w]. For
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
ratio > 1.0).
2) The order of output_size should be [out_h, out_w].
interp_mode (str): The mode of interpolation for resizing.
Default: 'bilinear'.
align_corners (bool): Whether align corners. Default: False.
Returns:
Tensor: Resized flow.
"""
_, _, flow_h, flow_w = flow.size()
if size_type == 'ratio':
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
elif size_type == 'shape':
output_h, output_w = sizes[0], sizes[1]
else:
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
input_flow = flow.clone()
ratio_h = output_h / flow_h
ratio_w = output_w / flow_w
input_flow[:, 0, :, :] *= ratio_w
input_flow[:, 1, :, :] *= ratio_h
resized_flow = F.interpolate(
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
return resized_flow
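# Example (sketch): upsample a flow field by 2x; displacement values are rescaled accordingly.
#   flow = torch.randn(1, 2, 32, 32)
#   flow_x2 = resize_flow(flow, 'ratio', [2.0, 2.0])   # -> shape (1, 2, 64, 64)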
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
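# Example (sketch): pixel_unshuffle trades spatial resolution for channels.
#   x = torch.randn(2, 3, 64, 64)
#   y = pixel_unshuffle(x, scale=2)         # -> shape (2, 12, 32, 32)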
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
import torch
import torch.nn.functional as F
from torch import nn as nn
import numpy as np
import math
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from .network_swinir import RSTB
from .ridcp_utils import ResBlock, CombineQuantBlock
from .vgg_arch import VGGFeatureExtractor
class DCNv2Pack(ModulatedDeformConvPack):
"""Modulated deformable conv for deformable alignment.
Different from the official DCNv2Pack, which generates offsets and masks
from the preceding features, this DCNv2Pack takes another different
features to generate offsets and masks.
Ref:
Delving Deep into Deformable Alignment in Video Super-Resolution.
"""
def forward(self, x, feat):
out = self.conv_offset(feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_absmean = torch.mean(torch.abs(offset))
if offset_absmean > 50:
logger = get_root_logger()
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(self, n_e, e_dim, weight_path='pretrained_models/weight_for_matching_dehazing_Flickr.pth', beta=0.25, LQ_stage=False, use_weight=True, weight_alpha=1.0):
super().__init__()
self.n_e = int(n_e)
self.e_dim = int(e_dim)
self.LQ_stage = LQ_stage
self.beta = beta
self.use_weight = use_weight
self.weight_alpha = weight_alpha
if self.use_weight:
self.weight = nn.Parameter(torch.load(weight_path))
self.weight.requires_grad = False
self.embedding = nn.Embedding(self.n_e, self.e_dim)
def dist(self, x, y):
if x.shape == y.shape:
return (x - y) ** 2
else:
return torch.sum(x ** 2, dim=1, keepdim=True) + \
torch.sum(y**2, dim=1) - 2 * \
torch.matmul(x, y.t())
def gram_loss(self, x, y):
b, h, w, c = x.shape
x = x.reshape(b, h*w, c)
y = y.reshape(b, h*w, c)
gmx = x.transpose(1, 2) @ x / (h*w)
gmy = y.transpose(1, 2) @ y / (h*w)
return (gmx - gmy).square().mean()
def forward(self, z, gt_indices=None, current_iter=None, weight_alpha=None):
"""
Args:
z: input features to be quantized, z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
gt_indices: feature map of given indices, used for visualization.
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
codebook = self.embedding.weight
d = self.dist(z_flattened, codebook)
if self.use_weight and self.LQ_stage:
if weight_alpha is not None:
self.weight_alpha = weight_alpha
d = d * torch.exp(self.weight_alpha * self.weight)
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encodings = torch.zeros(min_encoding_indices.shape[0], codebook.shape[0]).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
if gt_indices is not None:
gt_indices = gt_indices.reshape(-1)
gt_min_indices = gt_indices.reshape_as(min_encoding_indices)
gt_min_onehot = torch.zeros(gt_min_indices.shape[0], codebook.shape[0]).to(z)
gt_min_onehot.scatter_(1, gt_min_indices, 1)
z_q_gt = torch.matmul(gt_min_onehot, codebook)
z_q_gt = z_q_gt.view(z.shape)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, codebook)
z_q = z_q.view(z.shape)
e_latent_loss = torch.mean((z_q.detach() - z)**2)
q_latent_loss = torch.mean((z_q - z.detach())**2)
if self.LQ_stage and gt_indices is not None:
# codebook_loss = self.dist(z_q, z_q_gt.detach()).mean() \
# + self.beta * self.dist(z_q_gt.detach(), z)
codebook_loss = self.beta * self.dist(z_q_gt.detach(), z)
texture_loss = self.gram_loss(z, z_q_gt.detach())
# print("codebook loss:", codebook_loss.mean(), "\ntexture_loss: ", texture_loss.mean())
codebook_loss = codebook_loss + texture_loss
else:
codebook_loss = q_latent_loss + e_latent_loss * self.beta
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, codebook_loss, min_encoding_indices.reshape(z_q.shape[0], 1, z_q.shape[2], z_q.shape[3])
def get_codebook_entry(self, indices):
b, _, h, w = indices.shape
indices = indices.flatten().to(self.embedding.weight.device)
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:,None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
z_q = z_q.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
return z_q
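# Usage sketch (illustrative): quantize a (b, c, h, w) feature map against the codebook.
# use_weight=False avoids loading the CHM weight file, which is only needed in the LQ stage.
#   vq = VectorQuantizer(n_e=1024, e_dim=512, use_weight=False)
#   z = torch.randn(1, 512, 32, 32)
#   z_q, codebook_loss, indices = vq(z)     # straight-through estimator keeps gradients w.r.t. z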
class SwinLayers(nn.Module):
def __init__(self, input_resolution=(32, 32), embed_dim=256,
blk_depth=6,
num_heads=8,
window_size=8,
**kwargs):
super().__init__()
self.swin_blks = nn.ModuleList()
for i in range(4):
layer = RSTB(embed_dim, input_resolution, blk_depth, num_heads, window_size, patch_size=1, **kwargs)
self.swin_blks.append(layer)
def forward(self, x):
b, c, h, w = x.shape
x = x.reshape(b, c, h*w).transpose(1, 2)
for m in self.swin_blks:
x = m(x, (h, w))
x = x.transpose(1, 2).reshape(b, c, h, w)
return x
class MultiScaleEncoder(nn.Module):
def __init__(self,
in_channel,
max_depth,
input_res=256,
channel_query_dict=None,
norm_type='gn',
act_type='leakyrelu',
LQ_stage=True,
**swin_opts,
):
super().__init__()
self.LQ_stage = LQ_stage
ksz = 3
self.in_conv = nn.Conv2d(in_channel, channel_query_dict[input_res], 4, padding=1)
self.blocks = nn.ModuleList()
self.up_blocks = nn.ModuleList()
self.max_depth = max_depth
res = input_res
for i in range(max_depth):
in_ch, out_ch = channel_query_dict[res], channel_query_dict[res // 2]
tmp_down_block = [
nn.Conv2d(in_ch, out_ch, ksz, stride=2, padding=1),
ResBlock(out_ch, out_ch, norm_type, act_type),
ResBlock(out_ch, out_ch, norm_type, act_type),
]
self.blocks.append(nn.Sequential(*tmp_down_block))
res = res // 2
if LQ_stage:
self.blocks.append(SwinLayers(**swin_opts))
def forward(self, input):
# input.requires_grad = True
x = self.in_conv(input)
# if self.LQ_stage:
# print('input: ', input.requires_grad)
# for p in self.in_conv.parameters():
# print('conv: ', p.requires_grad)
# print('first output:', x.requires_grad)
for idx, m in enumerate(self.blocks):
with torch.backends.cudnn.flags(enabled=False):
x = m(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, in_channel, out_channel, norm_type='gn', act_type='leakyrelu'):
super().__init__()
self.block = []
self.block += [
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channel, out_channel, 3, stride=1, padding=1),
ResBlock(out_channel, out_channel, norm_type, act_type),
ResBlock(out_channel, out_channel, norm_type, act_type),
]
self.block = nn.Sequential(*self.block)
def forward(self, input):
return self.block(input)
class WarpBlock(nn.Module):
def __init__(self, in_channel):
super().__init__()
self.offset = nn.Conv2d(in_channel * 2, in_channel, 3, stride=1, padding=1)
self.dcn = DCNv2Pack(in_channel, in_channel, 3, padding=1, deformable_groups=4)
def forward(self, x_vq, x_residual):
x_residual = self.offset(torch.cat([x_vq, x_residual], dim=1))
feat_after_warp = self.dcn(x_vq, x_residual)
return feat_after_warp
class MultiScaleDecoder(nn.Module):
def __init__(self,
in_channel,
max_depth,
input_res=256,
channel_query_dict=None,
norm_type='gn',
act_type='leakyrelu',
only_residual=False,
use_warp=True
):
super().__init__()
self.only_residual = only_residual
self.use_warp = use_warp
self.upsampler = nn.ModuleList()
self.warp = nn.ModuleList()
res = input_res // (2 ** max_depth)
for i in range(max_depth):
in_channel, out_channel = channel_query_dict[res], channel_query_dict[res * 2]
self.upsampler.append(nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channel, out_channel, 3, stride=1, padding=1),
ResBlock(out_channel, out_channel, norm_type, act_type),
ResBlock(out_channel, out_channel, norm_type, act_type),
)
)
self.warp.append(WarpBlock(out_channel))
res = res * 2
def forward(self, input, code_decoder_output):
x = input
for idx, m in enumerate(self.upsampler):
with torch.backends.cudnn.flags(enabled=False):
if not self.only_residual:
x = m(x)
if self.use_warp:
x_vq = self.warp[idx](code_decoder_output[idx], x)
# print(idx, x.mean(), x_vq.mean())
x = x + x_vq * (x.mean() / x_vq.mean())
else:
x = x + code_decoder_output[idx]
else:
x = m(x)
# print()
return x
@ARCH_REGISTRY.register()
class VQWeightDehazeNet(nn.Module):
def __init__(self,
*,
in_channel=3,
codebook_params=None,
gt_resolution=256,
LQ_stage=False,
norm_type='gn',
act_type='silu',
use_quantize=True,
use_semantic_loss=False,
use_residual=True,
only_residual=False,
use_weight=False,
use_warp=True,
weight_alpha=1.0,
**ignore_kwargs):
super().__init__()
codebook_params = np.array(codebook_params)
self.codebook_scale = codebook_params[:, 0]
codebook_emb_num = codebook_params[:, 1].astype(int)
codebook_emb_dim = codebook_params[:, 2].astype(int)
self.use_quantize = use_quantize
self.in_channel = in_channel
self.gt_res = gt_resolution
self.LQ_stage = LQ_stage
self.use_residual = use_residual
self.only_residual = only_residual
self.use_weight = use_weight
self.use_warp = use_warp
self.weight_alpha = weight_alpha
channel_query_dict = {
8: 256,
16: 256,
32: 256,
64: 256,
128: 128,
256: 64,
512: 32,
}
# build encoder
self.max_depth = int(np.log2(gt_resolution // self.codebook_scale[0]))
self.multiscale_encoder = MultiScaleEncoder(
in_channel,
self.max_depth,
self.gt_res,
channel_query_dict,
norm_type, act_type, LQ_stage
)
if self.LQ_stage and self.use_residual:
self.multiscale_decoder = MultiScaleDecoder(
in_channel,
self.max_depth,
self.gt_res,
channel_query_dict,
norm_type, act_type, only_residual, use_warp=self.use_warp
)
# build decoder
self.decoder_group = nn.ModuleList()
for i in range(self.max_depth):
res = gt_resolution // 2**self.max_depth * 2**i
in_ch, out_ch = channel_query_dict[res], channel_query_dict[res * 2]
self.decoder_group.append(DecoderBlock(in_ch, out_ch, norm_type, act_type))
self.out_conv = nn.Conv2d(out_ch, 3, 3, 1, 1)
self.residual_conv = nn.Conv2d(out_ch, 3, 3, 1, 1)
# build multi-scale vector quantizers
self.quantize_group = nn.ModuleList()
self.before_quant_group = nn.ModuleList()
self.after_quant_group = nn.ModuleList()
for scale in range(0, codebook_params.shape[0]):
quantize = VectorQuantizer(
codebook_emb_num[scale],
codebook_emb_dim[scale],
LQ_stage=self.LQ_stage,
use_weight=self.use_weight,
weight_alpha=self.weight_alpha
)
self.quantize_group.append(quantize)
scale_in_ch = channel_query_dict[self.codebook_scale[scale]]
if scale == 0:
quant_conv_in_ch = scale_in_ch
comb_quant_in_ch1 = codebook_emb_dim[scale]
comb_quant_in_ch2 = 0
else:
quant_conv_in_ch = scale_in_ch * 2
comb_quant_in_ch1 = codebook_emb_dim[scale - 1]
comb_quant_in_ch2 = codebook_emb_dim[scale]
self.before_quant_group.append(nn.Conv2d(quant_conv_in_ch, codebook_emb_dim[scale], 1))
self.after_quant_group.append(CombineQuantBlock(comb_quant_in_ch1, comb_quant_in_ch2, scale_in_ch))
# semantic loss for HQ pretrain stage
self.use_semantic_loss = use_semantic_loss
if use_semantic_loss:
self.conv_semantic = nn.Sequential(
nn.Conv2d(512, 512, 1, 1, 0),
nn.ReLU(),
)
self.vgg_feat_layer = 'relu4_4'
self.vgg_feat_extractor = VGGFeatureExtractor([self.vgg_feat_layer])
def encode_and_decode(self, input, gt_indices=None, current_iter=None, weight_alpha=None):
# if self.training:
# for p in self.multiscale_encoder.parameters():
# p.requires_grad = True
enc_feats = self.multiscale_encoder(input)
if self.use_semantic_loss:
with torch.no_grad():
vgg_feat = self.vgg_feat_extractor(input)[self.vgg_feat_layer]
codebook_loss_list = []
indices_list = []
semantic_loss_list = []
code_decoder_output = []
quant_idx = 0
prev_dec_feat = None
prev_quant_feat = None
out_img = None
out_img_residual = None
x = enc_feats
for i in range(self.max_depth):
cur_res = self.gt_res // 2**self.max_depth * 2**i
if cur_res in self.codebook_scale: # needs to perform quantize
if prev_dec_feat is not None:
before_quant_feat = torch.cat((x, prev_dec_feat), dim=1)
else:
before_quant_feat = x
feat_to_quant = self.before_quant_group[quant_idx](before_quant_feat)
if weight_alpha is not None:
self.weight_alpha = weight_alpha
if gt_indices is not None:
z_quant, codebook_loss, indices = self.quantize_group[quant_idx](feat_to_quant, gt_indices[quant_idx], weight_alpha=self.weight_alpha)
else:
z_quant, codebook_loss, indices = self.quantize_group[quant_idx](feat_to_quant, weight_alpha=self.weight_alpha)
if self.use_semantic_loss:
semantic_z_quant = self.conv_semantic(z_quant)
semantic_loss = F.mse_loss(semantic_z_quant, vgg_feat)
semantic_loss_list.append(semantic_loss)
if not self.use_quantize:
z_quant = feat_to_quant
after_quant_feat = self.after_quant_group[quant_idx](z_quant, prev_quant_feat)
codebook_loss_list.append(codebook_loss)
indices_list.append(indices)
quant_idx += 1
prev_quant_feat = z_quant
x = after_quant_feat
x = self.decoder_group[i](x)
code_decoder_output.append(x)
prev_dec_feat = x
out_img = self.out_conv(x)
if self.LQ_stage and self.use_residual:
if self.only_residual:
residual_feature = self.multiscale_decoder(enc_feats, code_decoder_output)
else:
residual_feature = self.multiscale_decoder(enc_feats.detach(), code_decoder_output)
out_img_residual = self.residual_conv(residual_feature)
if len(codebook_loss_list) > 0:
codebook_loss = sum(codebook_loss_list)
else:
codebook_loss = 0
semantic_loss = sum(semantic_loss_list) if len(semantic_loss_list) else codebook_loss * 0
return out_img, out_img_residual, codebook_loss, semantic_loss, feat_to_quant, z_quant, indices_list
def decode_indices(self, indices):
assert len(indices.shape) == 4, f'shape of indices must be (b, 1, h, w), but got {indices.shape}'
z_quant = self.quantize_group[0].get_codebook_entry(indices)
x = self.after_quant_group[0](z_quant)
for m in self.decoder_group:
x = m(x)
out_img = self.out_conv(x)
return out_img
@torch.no_grad()
def test_tile(self, input, tile_size=240, tile_pad=16):
# return self.test(input)
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
"""
batch, channel, height, width = input.shape
output_height = height
output_width = width
output_shape = (batch, channel, output_height, output_width)
# start with black image
output = input.new_zeros(output_shape)
tiles_x = math.ceil(width / tile_size)
tiles_y = math.ceil(height / tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * tile_size
ofs_y = y * tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - tile_pad, 0)
input_end_x_pad = min(input_end_x + tile_pad, width)
input_start_y_pad = max(input_start_y - tile_pad, 0)
input_end_y_pad = min(input_end_y + tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = input[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
output_tile = self.test(input_tile)
# output tile area on total image
output_start_x = input_start_x
output_end_x = input_end_x
output_start_y = input_start_y
output_end_y = input_end_y
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad)
output_end_x_tile = output_start_x_tile + input_tile_width
output_start_y_tile = (input_start_y - input_start_y_pad)
output_end_y_tile = output_start_y_tile + input_tile_height
# put tile into output image
output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
return output
@torch.no_grad()
def test(self, input, weight_alpha=None):
org_use_semantic_loss = self.use_semantic_loss
self.use_semantic_loss = False
# padding to multiple of window_size * 8
wsz = 32
_, _, h_old, w_old = input.shape
h_pad = (h_old // wsz + 1) * wsz - h_old
w_pad = (w_old // wsz + 1) * wsz - w_old
input = torch.cat([input, torch.flip(input, [2])], 2)[:, :, :h_old + h_pad, :]
input = torch.cat([input, torch.flip(input, [3])], 3)[:, :, :, :w_old + w_pad]
output_vq, output, _, _, _, after_quant, index = self.encode_and_decode(input, None, None, weight_alpha=weight_alpha)
if output is not None:
output = output[..., :h_old, :w_old]
if output_vq is not None:
output_vq = output_vq[..., :h_old, :w_old]
self.use_semantic_loss = org_use_semantic_loss
return output, index
def forward(self, input, gt_indices=None, weight_alpha=None):
if gt_indices is not None:
# in LQ training stage, need to pass GT indices for supervise.
dec, dec_residual, codebook_loss, semantic_loss, quant_before_feature, quant_after_feature, indices = self.encode_and_decode(input, gt_indices, weight_alpha=weight_alpha)
else:
# in HQ stage, or LQ test stage, no GT indices needed.
dec, dec_residual, codebook_loss, semantic_loss, quant_before_feature, quant_after_feature, indices = self.encode_and_decode(input, weight_alpha=weight_alpha)
return dec, dec_residual, codebook_loss, semantic_loss, quant_before_feature, quant_after_feature, indices
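# Usage sketch (illustrative): full-image dehazing with the LQ-stage network.
#   net = VQWeightDehazeNet(codebook_params=[[64, 1024, 512]], LQ_stage=True, use_weight=False).cuda()
#   hazy = torch.rand(1, 3, 480, 640).cuda()
#   dehazed, indices = net.test(hazy)       # test() pads to a multiple of 32, then crops back
#   # For large inputs, test_tile() processes the image in padded tiles to limit memory use.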
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN)
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
Arg:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection
norm = spectral_norm
# the first convolution
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
# downsample
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
# upsample
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra convolutions
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x):
# downsample
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
# upsample
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
if self.skip_connection:
x4 = x4 + x2
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
if self.skip_connection:
x5 = x5 + x1
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
if self.skip_connection:
x6 = x6 + x0
# extra convolutions
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out)
return out
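# Usage sketch (illustrative): per-pixel real/fake logits for a 3-channel image.
#   d = UNetDiscriminatorSN(num_in_ch=3)
#   logits = d(torch.randn(1, 3, 256, 256))   # -> shape (1, 1, 256, 256)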
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
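# Example (sketch): window_partition and window_reverse are inverses of each other.
#   x = torch.randn(2, 64, 64, 96)                    # (B, H, W, C)
#   wins = window_partition(x, window_size=8)         # -> (2*8*8, 8, 8, 96)
#   x_back = window_reverse(wins, 8, 64, 64)          # equal to x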
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
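# Illustrative shape check (a minimal sketch, assuming dim=96, a 7x7 window and
# 3 heads; each window is flattened to N = 7*7 = 49 tokens):
#     attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#     x = torch.randn(8, 49, 96)             # (num_windows*B, N, C)
#     out = attn(x)                          # same shape: (8, 49, 96)
# The relative_position_bias_table holds (2*7-1)*(2*7-1) = 169 rows, one bias
# per relative (dh, dw) offset and per head, gathered via relative_position_index.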
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (recompute the attention mask when the test image size differs from the training resolution)
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
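# Illustrative usage sketch: the SW-MSA mask built in calculate_mask assigns -100
# to token pairs that share a window only because of the cyclic shift, so the
# softmax effectively zeroes their attention. A minimal sketch, assuming an 8x8
# input with window_size=4 and shift_size=2:
#     blk = SwinTransformerBlock(dim=96, input_resolution=(8, 8), num_heads=3,
#                                window_size=4, shift_size=2)
#     x = torch.randn(1, 8 * 8, 96)          # (B, H*W, C)
#     out = blk(x, (8, 8))                   # same shape as x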
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
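# Illustrative usage sketch: PatchMerging halves the spatial resolution and
# doubles the channels; the four 2x2-neighbouring tokens are concatenated (4*C)
# and projected to 2*C. A minimal sketch, assuming an 8x8 token grid with C=96:
#     merge = PatchMerging(input_resolution=(8, 8), dim=96)
#     x = torch.randn(1, 8 * 8, 96)
#     y = merge(x)                           # (1, 4*4, 192)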
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
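# Illustrative usage sketch: BasicLayer stacks `depth` SwinTransformerBlocks,
# alternating regular W-MSA (shift_size=0, even blocks) and shifted SW-MSA
# (shift_size=window_size//2, odd blocks), so information crosses window
# boundaries every other block. A minimal sketch, assuming depth=2:
#     layer = BasicLayer(dim=96, input_resolution=(8, 8), depth=2,
#                        num_heads=3, window_size=4)
#     x = torch.randn(1, 8 * 8, 96)
#     y = layer(x, (8, 8))                   # same shape as x (no downsample)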
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
# with torch.backends.cudnn.flags(enabled=False):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
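# Note: RSTB wraps a BasicLayer between patch_unembed -> conv -> patch_embed and
# adds a residual connection, i.e.
#     out = patch_embed(conv(patch_unembed(residual_group(x, x_size), x_size))) + x
# so the transformer body operates on (B, H*W, C) tokens while the 3x3 conv in
# between operates on the (B, C, H, W) feature map.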
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
return x
def flops(self):
flops = 0
return flops
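# Illustrative usage sketch: with patch_size=1 (as used by SwinIR below),
# PatchEmbed and PatchUnEmbed are simple reshapes between feature maps and
# token sequences. A minimal sketch, assuming embed_dim=96:
#     embed = PatchEmbed(img_size=64, patch_size=1, embed_dim=96)
#     unembed = PatchUnEmbed(img_size=64, patch_size=1, embed_dim=96)
#     feat = torch.randn(1, 96, 64, 64)      # (B, C, H, W)
#     tokens = embed(feat)                   # (1, 64*64, 96)
#     back = unembed(tokens, (64, 64))       # (1, 96, 64, 64)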
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
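# Illustrative usage sketch: Upsample enlarges a feature map by pixel shuffle;
# each 2x stage expands channels by 4x (9x for scale 3) and rearranges them into
# larger spatial dimensions. A minimal sketch, assuming num_feat=64:
#     up = Upsample(scale=4, num_feat=64)    # two conv + PixelShuffle(2) stages
#     x = torch.randn(1, 64, 16, 16)
#     y = up(x)                              # (1, 64, 64, 64)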
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r""" SwinIR
A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default: 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only x4 is supported now.'
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
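    # Note: check_image_size reflect-pads the input so that H and W become
    # multiples of window_size; forward() later crops the output back to
    # (H*upscale, W*upscale). For example, with window_size=8 a 1x3x70x45
    # input is padded to 1x3x72x48 before entering the network.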
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(upscale=2, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
import torch
from torch.nn import functional as F
from torch import nn as nn
class NormLayer(nn.Module):
"""Normalization Layers.
------------
# Arguments
- channels: input channels, for batch norm and instance norm.
- input_size: input shape without batch size, for layer norm.
"""
def __init__(self, channels, norm_type='bn'):
super(NormLayer, self).__init__()
norm_type = norm_type.lower()
self.norm_type = norm_type
self.channels = channels
if norm_type == 'bn':
self.norm = nn.BatchNorm2d(channels, affine=True)
elif norm_type == 'in':
self.norm = nn.InstanceNorm2d(channels, affine=False)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
elif norm_type == 'none':
self.norm = lambda x: x*1.0
else:
raise NotImplementedError(f'Norm type {norm_type} is not supported.')
def forward(self, x):
return self.norm(x)
class ActLayer(nn.Module):
"""activation layer.
------------
# Arguments
- relu type: type of relu layer, candidates are
- ReLU
- LeakyReLU: default relu slope 0.2
- PRelu
- SELU
- none: direct pass
"""
def __init__(self, channels, relu_type='leakyrelu'):
super(ActLayer, self).__init__()
relu_type = relu_type.lower()
if relu_type == 'relu':
self.func = nn.ReLU(True)
elif relu_type == 'leakyrelu':
self.func = nn.LeakyReLU(0.2, inplace=True)
elif relu_type == 'prelu':
self.func = nn.PReLU(channels)
elif relu_type == 'none':
self.func = lambda x: x*1.0
elif relu_type == 'silu':
self.func = nn.SiLU(True)
elif relu_type == 'gelu':
self.func = nn.GELU()
else:
raise NotImplementedError(f'Activation type {relu_type} is not supported.')
def forward(self, x):
return self.func(x)
class ResBlock(nn.Module):
"""
Use preactivation version of residual block, the same as taming
"""
def __init__(self, in_channel, out_channel, norm_type='gn', act_type='leakyrelu'):
super(ResBlock, self).__init__()
self.conv = nn.Sequential(
NormLayer(in_channel, norm_type),
ActLayer(in_channel, act_type),
nn.Conv2d(in_channel, out_channel, 3, stride=1, padding=1),
NormLayer(out_channel, norm_type),
ActLayer(out_channel, act_type),
nn.Conv2d(out_channel, out_channel, 3, stride=1, padding=1),
)
def forward(self, input):
with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
res = self.conv(input)
out = res + input
return out
class CombineQuantBlock(nn.Module):
def __init__(self, in_ch1, in_ch2, out_channel):
super().__init__()
self.conv = nn.Conv2d(in_ch1 + in_ch2, out_channel, 3, 1, 1)
def forward(self, input1, input2=None):
if input2 is not None:
input2 = F.interpolate(input2, input1.shape[2:])
input = torch.cat((input1, input2), dim=1)
else:
input = input1
out = self.conv(input)
return out
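# Illustrative usage sketch: CombineQuantBlock fuses two feature maps; the
# optional second input is resized to the first input's spatial size, the two
# are concatenated along channels, and a 3x3 conv maps them to out_channel.
# A minimal sketch with hypothetical channel counts:
#     block = CombineQuantBlock(in_ch1=64, in_ch2=32, out_channel=64)
#     a = torch.randn(1, 64, 32, 32)
#     b = torch.randn(1, 32, 16, 16)         # resized to 32x32 internally
#     out = block(a, b)                      # (1, 64, 32, 32)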