Commit a8ada82f authored by chenych's avatar chenych
Browse files

First commit

parent 537691da
720p_240fps_1 100 (720,1280,3) 00000
720p_240fps_3 100 (720,1280,3) 00000
720p_240fps_5 100 (720,1280,3) 00000
720p_240fps_6 100 (720,1280,3) 00000
C0036 63 (720,1280,3) 00000
C0041 229 (720,1280,3) 00000
GOPR9633 35 (720,1280,3) 00000
GOPR9636 48 (720,1280,3) 00000
GOPR9641 137 (720,1280,3) 00000
GOPR9642 83 (720,1280,3) 00000
GOPR9643 97 (720,1280,3) 00000
GOPR9645 65 (720,1280,3) 00000
GOPR9648 88 (720,1280,3) 00000
GOPR9652 71 (720,1280,3) 00000
IMG_0001 100 (720,1280,3) 00000
IMG_0004 100 (720,1280,3) 00000
IMG_0005 100 (720,1280,3) 00000
IMG_0008 100 (720,1280,3) 00000
IMG_0009 41 (720,1280,3) 00030
IMG_0010 100 (720,1280,3) 00000
IMG_0011 100 (720,1280,3) 00000
IMG_0012 100 (720,1280,3) 00000
IMG_0013 100 (720,1280,3) 00000
IMG_0014 100 (720,1280,3) 00000
IMG_0015 100 (720,1280,3) 00000
IMG_0017 100 (720,1280,3) 00000
IMG_0019 100 (720,1280,3) 00000
IMG_0022 100 (720,1280,3) 00000
IMG_0023 100 (720,1280,3) 00000
IMG_0024 100 (720,1280,3) 00000
IMG_0025 100 (720,1280,3) 00000
IMG_0026 100 (720,1280,3) 00000
IMG_0028 100 (720,1280,3) 00000
IMG_0029 100 (720,1280,3) 00000
IMG_0034 100 (720,1280,3) 00000
IMG_0035 100 (720,1280,3) 00000
IMG_0036 100 (720,1280,3) 00000
IMG_0038 100 (720,1280,3) 00000
IMG_0040 100 (720,1280,3) 00000
IMG_0041 100 (720,1280,3) 00000
IMG_0042 100 (720,1280,3) 00000
IMG_0043 100 (720,1280,3) 00000
IMG_0044 100 (720,1280,3) 00000
IMG_0045 61 (720,1280,3) 00000
IMG_0046 100 (720,1280,3) 00000
IMG_0047 100 (720,1280,3) 00000
IMG_0051 100 (720,1280,3) 00000
IMG_0053 100 (720,1280,3) 00000
IMG_0054 100 (720,1280,3) 00000
IMG_0055 100 (720,1280,3) 00000
IMG_0056 100 (720,1280,3) 00000
IMG_0058 100 (720,1280,3) 00000
IMG_0150 50 (720,1280,3) 00000
IMG_0151 60 (720,1280,3) 00000
IMG_0153 120 (720,1280,3) 00000
IMG_0155 78 (720,1280,3) 00000
IMG_0164 46 (720,1280,3) 00000
IMG_0173 70 (720,1280,3) 00000
IMG_0180 50 (720,1280,3) 00000
IMG_0183 113 (720,1280,3) 00000
IMG_0200 103 (720,1280,3) 00000
GOPR0372_07_00 100 (720,1280,3) 000047
GOPR0372_07_01 75 (720,1280,3) 000601
GOPR0374_11_00 150 (720,1280,3) 000001
GOPR0374_11_01 80 (720,1280,3) 000203
GOPR0374_11_02 100 (720,1280,3) 000541
GOPR0374_11_03 48 (720,1280,3) 002481
GOPR0378_13_00 110 (720,1280,3) 000041
GOPR0379_11_00 100 (720,1280,3) 000188
GOPR0380_11_00 60 (720,1280,3) 000134
GOPR0384_11_01 100 (720,1280,3) 000951
GOPR0384_11_02 100 (720,1280,3) 001301
GOPR0384_11_03 100 (720,1280,3) 002101
GOPR0384_11_04 100 (720,1280,3) 002801
GOPR0385_11_00 100 (720,1280,3) 000101
GOPR0386_11_00 100 (720,1280,3) 000247
GOPR0477_11_00 80 (720,1280,3) 000001
GOPR0857_11_00 100 (720,1280,3) 000001
GOPR0868_11_01 100 (720,1280,3) 000221
GOPR0868_11_02 100 (720,1280,3) 000681
GOPR0871_11_01 100 (720,1280,3) 000181
GOPR0881_11_00 100 (720,1280,3) 000001
GOPR0884_11_00 100 (720,1280,3) 000186
000 100 (720,1280,3) 00000000
001 100 (720,1280,3) 00000000
002 100 (720,1280,3) 00000000
003 100 (720,1280,3) 00000000
004 100 (720,1280,3) 00000000
005 100 (720,1280,3) 00000000
006 100 (720,1280,3) 00000000
007 100 (720,1280,3) 00000000
008 100 (720,1280,3) 00000000
009 100 (720,1280,3) 00000000
010 100 (720,1280,3) 00000000
011 100 (720,1280,3) 00000000
012 100 (720,1280,3) 00000000
013 100 (720,1280,3) 00000000
014 100 (720,1280,3) 00000000
015 100 (720,1280,3) 00000000
016 100 (720,1280,3) 00000000
017 100 (720,1280,3) 00000000
018 100 (720,1280,3) 00000000
019 100 (720,1280,3) 00000000
020 100 (720,1280,3) 00000000
021 100 (720,1280,3) 00000000
022 100 (720,1280,3) 00000000
023 100 (720,1280,3) 00000000
024 100 (720,1280,3) 00000000
025 100 (720,1280,3) 00000000
026 100 (720,1280,3) 00000000
027 100 (720,1280,3) 00000000
028 100 (720,1280,3) 00000000
029 100 (720,1280,3) 00000000
030 100 (720,1280,3) 00000000
031 100 (720,1280,3) 00000000
032 100 (720,1280,3) 00000000
033 100 (720,1280,3) 00000000
034 100 (720,1280,3) 00000000
035 100 (720,1280,3) 00000000
036 100 (720,1280,3) 00000000
037 100 (720,1280,3) 00000000
038 100 (720,1280,3) 00000000
039 100 (720,1280,3) 00000000
040 100 (720,1280,3) 00000000
041 100 (720,1280,3) 00000000
042 100 (720,1280,3) 00000000
043 100 (720,1280,3) 00000000
044 100 (720,1280,3) 00000000
045 100 (720,1280,3) 00000000
046 100 (720,1280,3) 00000000
047 100 (720,1280,3) 00000000
048 100 (720,1280,3) 00000000
049 100 (720,1280,3) 00000000
050 100 (720,1280,3) 00000000
051 100 (720,1280,3) 00000000
052 100 (720,1280,3) 00000000
053 100 (720,1280,3) 00000000
054 100 (720,1280,3) 00000000
055 100 (720,1280,3) 00000000
056 100 (720,1280,3) 00000000
057 100 (720,1280,3) 00000000
058 100 (720,1280,3) 00000000
059 100 (720,1280,3) 00000000
060 100 (720,1280,3) 00000000
061 100 (720,1280,3) 00000000
062 100 (720,1280,3) 00000000
063 100 (720,1280,3) 00000000
064 100 (720,1280,3) 00000000
065 100 (720,1280,3) 00000000
066 100 (720,1280,3) 00000000
067 100 (720,1280,3) 00000000
068 100 (720,1280,3) 00000000
069 100 (720,1280,3) 00000000
070 100 (720,1280,3) 00000000
071 100 (720,1280,3) 00000000
072 100 (720,1280,3) 00000000
073 100 (720,1280,3) 00000000
074 100 (720,1280,3) 00000000
075 100 (720,1280,3) 00000000
076 100 (720,1280,3) 00000000
077 100 (720,1280,3) 00000000
078 100 (720,1280,3) 00000000
079 100 (720,1280,3) 00000000
080 100 (720,1280,3) 00000000
081 100 (720,1280,3) 00000000
082 100 (720,1280,3) 00000000
083 100 (720,1280,3) 00000000
084 100 (720,1280,3) 00000000
085 100 (720,1280,3) 00000000
086 100 (720,1280,3) 00000000
087 100 (720,1280,3) 00000000
088 100 (720,1280,3) 00000000
089 100 (720,1280,3) 00000000
090 100 (720,1280,3) 00000000
091 100 (720,1280,3) 00000000
092 100 (720,1280,3) 00000000
093 100 (720,1280,3) 00000000
094 100 (720,1280,3) 00000000
095 100 (720,1280,3) 00000000
096 100 (720,1280,3) 00000000
097 100 (720,1280,3) 00000000
098 100 (720,1280,3) 00000000
099 100 (720,1280,3) 00000000
100 100 (720,1280,3) 00000000
101 100 (720,1280,3) 00000000
102 100 (720,1280,3) 00000000
103 100 (720,1280,3) 00000000
104 100 (720,1280,3) 00000000
105 100 (720,1280,3) 00000000
106 100 (720,1280,3) 00000000
107 100 (720,1280,3) 00000000
108 100 (720,1280,3) 00000000
109 100 (720,1280,3) 00000000
110 100 (720,1280,3) 00000000
111 100 (720,1280,3) 00000000
112 100 (720,1280,3) 00000000
113 100 (720,1280,3) 00000000
114 100 (720,1280,3) 00000000
115 100 (720,1280,3) 00000000
116 100 (720,1280,3) 00000000
117 100 (720,1280,3) 00000000
118 100 (720,1280,3) 00000000
119 100 (720,1280,3) 00000000
120 100 (720,1280,3) 00000000
121 100 (720,1280,3) 00000000
122 100 (720,1280,3) 00000000
123 100 (720,1280,3) 00000000
124 100 (720,1280,3) 00000000
125 100 (720,1280,3) 00000000
126 100 (720,1280,3) 00000000
127 100 (720,1280,3) 00000000
128 100 (720,1280,3) 00000000
129 100 (720,1280,3) 00000000
130 100 (720,1280,3) 00000000
131 100 (720,1280,3) 00000000
132 100 (720,1280,3) 00000000
133 100 (720,1280,3) 00000000
134 100 (720,1280,3) 00000000
135 100 (720,1280,3) 00000000
136 100 (720,1280,3) 00000000
137 100 (720,1280,3) 00000000
138 100 (720,1280,3) 00000000
139 100 (720,1280,3) 00000000
140 100 (720,1280,3) 00000000
141 100 (720,1280,3) 00000000
142 100 (720,1280,3) 00000000
143 100 (720,1280,3) 00000000
144 100 (720,1280,3) 00000000
145 100 (720,1280,3) 00000000
146 100 (720,1280,3) 00000000
147 100 (720,1280,3) 00000000
148 100 (720,1280,3) 00000000
149 100 (720,1280,3) 00000000
150 100 (720,1280,3) 00000000
151 100 (720,1280,3) 00000000
152 100 (720,1280,3) 00000000
153 100 (720,1280,3) 00000000
154 100 (720,1280,3) 00000000
155 100 (720,1280,3) 00000000
156 100 (720,1280,3) 00000000
157 100 (720,1280,3) 00000000
158 100 (720,1280,3) 00000000
159 100 (720,1280,3) 00000000
160 100 (720,1280,3) 00000000
161 100 (720,1280,3) 00000000
162 100 (720,1280,3) 00000000
163 100 (720,1280,3) 00000000
164 100 (720,1280,3) 00000000
165 100 (720,1280,3) 00000000
166 100 (720,1280,3) 00000000
167 100 (720,1280,3) 00000000
168 100 (720,1280,3) 00000000
169 100 (720,1280,3) 00000000
170 100 (720,1280,3) 00000000
171 100 (720,1280,3) 00000000
172 100 (720,1280,3) 00000000
173 100 (720,1280,3) 00000000
174 100 (720,1280,3) 00000000
175 100 (720,1280,3) 00000000
176 100 (720,1280,3) 00000000
177 100 (720,1280,3) 00000000
178 100 (720,1280,3) 00000000
179 100 (720,1280,3) 00000000
180 100 (720,1280,3) 00000000
181 100 (720,1280,3) 00000000
182 100 (720,1280,3) 00000000
183 100 (720,1280,3) 00000000
184 100 (720,1280,3) 00000000
185 100 (720,1280,3) 00000000
186 100 (720,1280,3) 00000000
187 100 (720,1280,3) 00000000
188 100 (720,1280,3) 00000000
189 100 (720,1280,3) 00000000
190 100 (720,1280,3) 00000000
191 100 (720,1280,3) 00000000
192 100 (720,1280,3) 00000000
193 100 (720,1280,3) 00000000
194 100 (720,1280,3) 00000000
195 100 (720,1280,3) 00000000
196 100 (720,1280,3) 00000000
197 100 (720,1280,3) 00000000
198 100 (720,1280,3) 00000000
199 100 (720,1280,3) 00000000
200 100 (720,1280,3) 00000000
201 100 (720,1280,3) 00000000
202 100 (720,1280,3) 00000000
203 100 (720,1280,3) 00000000
204 100 (720,1280,3) 00000000
205 100 (720,1280,3) 00000000
206 100 (720,1280,3) 00000000
207 100 (720,1280,3) 00000000
208 100 (720,1280,3) 00000000
209 100 (720,1280,3) 00000000
210 100 (720,1280,3) 00000000
211 100 (720,1280,3) 00000000
212 100 (720,1280,3) 00000000
213 100 (720,1280,3) 00000000
214 100 (720,1280,3) 00000000
215 100 (720,1280,3) 00000000
216 100 (720,1280,3) 00000000
217 100 (720,1280,3) 00000000
218 100 (720,1280,3) 00000000
219 100 (720,1280,3) 00000000
220 100 (720,1280,3) 00000000
221 100 (720,1280,3) 00000000
222 100 (720,1280,3) 00000000
223 100 (720,1280,3) 00000000
224 100 (720,1280,3) 00000000
225 100 (720,1280,3) 00000000
226 100 (720,1280,3) 00000000
227 100 (720,1280,3) 00000000
228 100 (720,1280,3) 00000000
229 100 (720,1280,3) 00000000
230 100 (720,1280,3) 00000000
231 100 (720,1280,3) 00000000
232 100 (720,1280,3) 00000000
233 100 (720,1280,3) 00000000
234 100 (720,1280,3) 00000000
235 100 (720,1280,3) 00000000
236 100 (720,1280,3) 00000000
237 100 (720,1280,3) 00000000
238 100 (720,1280,3) 00000000
239 100 (720,1280,3) 00000000
240 100 (720,1280,3) 00000000
241 100 (720,1280,3) 00000000
242 100 (720,1280,3) 00000000
243 100 (720,1280,3) 00000000
244 100 (720,1280,3) 00000000
245 100 (720,1280,3) 00000000
246 100 (720,1280,3) 00000000
247 100 (720,1280,3) 00000000
248 100 (720,1280,3) 00000000
249 100 (720,1280,3) 00000000
250 100 (720,1280,3) 00000000
251 100 (720,1280,3) 00000000
252 100 (720,1280,3) 00000000
253 100 (720,1280,3) 00000000
254 100 (720,1280,3) 00000000
255 100 (720,1280,3) 00000000
256 100 (720,1280,3) 00000000
257 100 (720,1280,3) 00000000
258 100 (720,1280,3) 00000000
259 100 (720,1280,3) 00000000
260 100 (720,1280,3) 00000000
261 100 (720,1280,3) 00000000
262 100 (720,1280,3) 00000000
263 100 (720,1280,3) 00000000
264 100 (720,1280,3) 00000000
265 100 (720,1280,3) 00000000
266 100 (720,1280,3) 00000000
267 100 (720,1280,3) 00000000
268 100 (720,1280,3) 00000000
269 100 (720,1280,3) 00000000
This diff is collapsed.
This diff is collapsed.
'''
# --------------------------------------------
# select dataset
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# --------------------------------------------
'''
def define_Dataset(dataset_opt):
    """Instantiate the dataset selected by ``dataset_opt['dataset_type']``.

    The type string is lower-cased and looked up in an alias registry; the
    matching dataset class is imported lazily from the ``data`` package and
    constructed with the full ``dataset_opt`` dict.

    Raises:
        NotImplementedError: if the dataset type has no registered alias.
    """
    # alias -> (module under data/, class name); aliases mirror the original
    # if/elif chain, grouped by restoration task.
    registry = {
        # low-quality input only
        'l': ('data.dataset_l', 'DatasetL'),
        'low-quality': ('data.dataset_l', 'DatasetL'),
        'input-only': ('data.dataset_l', 'DatasetL'),
        # denoising
        'masked_denoising': ('data.dataset_masked_denoising', 'DatasetMaskedDenoising'),
        'dncnn': ('data.dataset_dncnn', 'DatasetDnCNN'),
        'denoising': ('data.dataset_dncnn', 'DatasetDnCNN'),
        'dnpatch': ('data.dataset_dnpatch', 'DatasetDnPatch'),
        'ffdnet': ('data.dataset_ffdnet', 'DatasetFFDNet'),
        'denoising-noiselevel': ('data.dataset_ffdnet', 'DatasetFFDNet'),
        'fdncnn': ('data.dataset_fdncnn', 'DatasetFDnCNN'),
        'denoising-noiselevelmap': ('data.dataset_fdncnn', 'DatasetFDnCNN'),
        # super-resolution
        'sr': ('data.dataset_sr', 'DatasetSR'),
        'super-resolution': ('data.dataset_sr', 'DatasetSR'),
        'srmd': ('data.dataset_srmd', 'DatasetSRMD'),
        'dpsr': ('data.dataset_dpsr', 'DatasetDPSR'),
        'dnsr': ('data.dataset_dpsr', 'DatasetDPSR'),
        'usrnet': ('data.dataset_usrnet', 'DatasetUSRNet'),
        'usrgan': ('data.dataset_usrnet', 'DatasetUSRNet'),
        'bsrnet': ('data.dataset_blindsr', 'DatasetBlindSR'),
        'bsrgan': ('data.dataset_blindsr', 'DatasetBlindSR'),
        'blindsr': ('data.dataset_blindsr', 'DatasetBlindSR'),
        # JPEG compression artifact reduction (deblocking)
        'jpeg': ('data.dataset_jpeg', 'DatasetJPEG'),
        # video restoration
        'videorecurrenttraindataset': ('data.dataset_video_train', 'VideoRecurrentTrainDataset'),
        'videorecurrenttrainnonblinddenoisingdataset': ('data.dataset_video_train', 'VideoRecurrentTrainNonblindDenoisingDataset'),
        'videorecurrenttrainvimeodataset': ('data.dataset_video_train', 'VideoRecurrentTrainVimeoDataset'),
        'videorecurrenttestdataset': ('data.dataset_video_test', 'VideoRecurrentTestDataset'),
        'singlevideorecurrenttestdataset': ('data.dataset_video_test', 'SingleVideoRecurrentTestDataset'),
        'videotestvimeo90kdataset': ('data.dataset_video_test', 'VideoTestVimeo90KDataset'),
        # common
        'plain': ('data.dataset_plain', 'DatasetPlain'),
        'plainpatch': ('data.dataset_plainpatch', 'DatasetPlainPatch'),
    }

    dataset_type = dataset_opt['dataset_type'].lower()
    if dataset_type not in registry:
        raise NotImplementedError('Dataset [{:s}] is not found.'.format(dataset_type))

    module_name, class_name = registry[dataset_type]
    # __import__ with fromlist returns the leaf module, not the top package.
    D = getattr(__import__(module_name, fromlist=[class_name]), class_name)

    dataset = D(dataset_opt)
    print('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, dataset_opt['name']))
    return dataset
# Base image: DCU (Hygon) PyTorch 1.13.1 build, CentOS 7.6, DTK 23.04.1, Python 3.8.
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04.1-py38-latest
# NOTE(review): each RUN executes in its own shell, so environment variables set
# by sourcing env.sh here do not persist into later layers or the running
# container — confirm whether ENV instructions (or sourcing at container start)
# were intended instead.
RUN source /opt/dtk/env.sh
# Install Python dependencies listed in requirement.txt.
COPY requirement.txt requirement.txt
RUN pip3 install -r requirement.txt
import os
from utils.utils_image import split_imageset

if __name__ == "__main__":
    # Split every image under ./trainset into overlapping training patches
    # written to ./trainsets/trainH.
    ROOT_PATH = os.getcwd()
    print("The root is", ROOT_PATH)
    ori_img_path = os.path.join(ROOT_PATH, 'trainset/')
    if not os.path.exists(ori_img_path):
        # Abort instead of falling through: previously the script printed this
        # warning and still called split_imageset on the missing directory.
        raise SystemExit("the ori_img_path {} is not exists.".format(ori_img_path))
    save_img_path = os.path.join(ROOT_PATH, 'trainsets/trainH')
    # exist_ok avoids a race/rerun failure when the output dir already exists.
    os.makedirs(save_img_path, exist_ok=True)
    # Patch geometry: 512px patches with 96px overlap; p_max=800 presumably
    # limits which images get split — TODO confirm against split_imageset.
    split_imageset(ori_img_path, save_img_path, n_channels=3, p_size=512, p_overlap=96, p_max=800)
    print("split {} to {} finished.".format(ori_img_path, save_img_path))
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests
from models.network_swinir import SwinIR as net
from utils import utils_image as util
from utils import utils_option as option
import lpips
import torch
def main():
    """Test a masked-denoising SwinIR model over a folder of images.

    Workflow: parse CLI + JSON options, download the checkpoint if it is not
    on disk, restore every image found via setup() (padding each image to a
    multiple of the window size), save the outputs, and report per-image and
    average PSNR / SSIM / LPIPS (plus Y-channel metrics for RGB inputs).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='masked_denoising')
    parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8') # 1 for dn and jpeg car
    parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--jpeg', type=int, default=40, help='scale factor: 10, 20, 30, 40')
    parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training SwinIR. '
                        'Just used to differentiate two different settings in Table 2 of the paper. '
                        'Images are NOT tested patch by patch.')
    parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
    parser.add_argument('--model_path', type=str,
                        default='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth')
    parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
    parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
    parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
    parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
    parser.add_argument('--opt', type=str, help='Path to option JSON file.')
    parser.add_argument('--name', type=str, default="test", help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # Network hyper-parameters come from the JSON option file; define_model()
    # reads them through this module-level global.
    global opt_net
    opt_net = opt['netG']
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model: download the official SwinIR release checkpoint when the
    # local file is missing.
    if os.path.exists(args.model_path):
        print(f'loading model from {args.model_path}')
    else:
        os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
        url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(os.path.basename(args.model_path))
        r = requests.get(url, allow_redirects=True)
        print(f'downloading model {args.model_path}')
        # NOTE(review): the file handle is never closed explicitly; relies on GC.
        open(args.model_path, 'wb').write(r.content)

    model = define_model(args)
    model.eval()
    model = model.to(device)

    # setup folder and path
    folder, save_dir, border, window_size = setup(args)
    print(folder)
    os.makedirs(save_dir, exist_ok=True)
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []
    test_results['psnr_b'] = []
    test_results['lpips'] = []
    psnr, ssim, psnr_y, ssim_y, psnr_b, lpips_ = 0, 0, 0, 0, 0, 0
    # LPIPS with AlexNet backbone; .cuda() assumes a GPU is present even though
    # the model itself falls back to CPU above — NOTE(review): confirm.
    loss_fn_alex = lpips.LPIPS(net='alex').cuda() # best forward scores

    for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
        # print(1)
        # read image
        imgname, img_lq, img_gt = get_image_pair(args, path)  # image to HWC-BGR, float32
        # BGR -> RGB channel reorder (skipped for single-channel), then HWC -> CHW.
        img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HCW-BGR to CHW-RGB
        img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # pad input image to be a multiple of window_size
            _, _, h_old, w_old = img_lq.size()
            # NOTE(review): this always pads at least one full window, even when
            # h_old/w_old are already multiples of window_size; harmless because
            # the padding is cropped off after inference.
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            # Mirror-extend via flip along each axis, then crop to padded size.
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
            output = test(img_lq, model, args, window_size)
            output = output[..., :h_old * args.scale, :w_old * args.scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_SwinIR.png', output)

        # evaluate psnr/ssim/psnr_b
        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, border=border)
            ssim = util.calculate_ssim(output, img_gt, border=border)
            lpips_ = loss_fn_alex(im2tensor(output).cuda(), im2tensor(img_gt).cuda()).item()
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['lpips'].append(lpips_)
            if img_gt.ndim == 3:  # RGB image
                # Y-channel metrics (standard convention in SR evaluation).
                output_y = util.bgr2ycbcr(output.astype(np.float32) / 255.) * 255.
                img_gt_y = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255.
                psnr_y = util.calculate_psnr(output_y, img_gt_y, border=border)
                ssim_y = util.calculate_ssim(output_y, img_gt_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
            print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; '
                  'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; '
                  'LPIPS: {:.4f}'.
                  format(idx, imgname, psnr, ssim, psnr_y, ssim_y, lpips_))
        else:
            print('Testing {:d} {:20s}'.format(idx, imgname))

    # summarize psnr/ssim
    # NOTE(review): img_gt here is whatever the last loop iteration produced;
    # all images are assumed to share GT availability and channel count.
    if img_gt is not None:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_lpips = sum(test_results['lpips']) / len(test_results['lpips'])
        print('\n{} \n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(save_dir, ave_psnr, ave_ssim))
        if img_gt.ndim == 3:
            ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            print('-- Average PSNR_Y/SSIM_Y/LPIPS: {:.2f}/{:.4f}/{:.4f}'.format(ave_psnr_y, ave_ssim_y, ave_lpips))
def define_model(args):
    """Build the SwinIR network for ``args.task`` and load its checkpoint.

    For 'masked_denoising' the architecture is taken entirely from the global
    ``opt_net`` dict (parsed from the JSON option file); the other tasks use
    the fixed configurations of the official SwinIR release.

    Returns:
        The network with pretrained weights loaded (strict key matching).
    Raises:
        ValueError: for an unsupported task.
    """
    if args.task == 'masked_denoising':
        global opt_net
        model = net(upscale=opt_net['upscale'],
                    in_chans=opt_net['in_chans'],
                    img_size=opt_net['img_size'],
                    window_size=opt_net['window_size'],
                    img_range=opt_net['img_range'],
                    depths=opt_net['depths'],
                    embed_dim=opt_net['embed_dim'],
                    num_heads=opt_net['num_heads'],
                    mlp_ratio=opt_net['mlp_ratio'],
                    upsampler=opt_net['upsampler'],
                    resi_connection=opt_net['resi_connection'],
                    talking_heads=opt_net['talking_heads'],
                    use_attn_fn=opt_net['attn_fn'],
                    head_scale=opt_net['head_scale'],
                    on_attn=opt_net['on_attn'],
                    use_mask=opt_net['use_mask'],
                    mask_ratio1=opt_net['mask_ratio1'],
                    mask_ratio2=opt_net['mask_ratio2'],
                    mask_is_diff=opt_net['mask_is_diff'],
                    type=opt_net['type'],
                    opt=opt_net,
                    )
        param_key_g = 'params'

    # real-world image sr
    elif args.task == 'real_sr':
        if not args.large_model:
            # use 'nearest+conv' to avoid block artifacts
            model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
        else:
            # larger model size; use '3conv' to save parameters and memory; use ema for GAN training
            model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
        param_key_g = 'params_ema'

    # grayscale image denoising
    elif args.task == 'gray_dn':
        model = net(upscale=1, in_chans=1, img_size=128, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # color image denoising
    elif args.task == 'color_dn':
        model = net(upscale=1, in_chans=3, img_size=128, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's sligtly better than 1
    elif args.task == 'jpeg_car':
        model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    else:
        # Previously an unknown task fell through to a NameError on param_key_g.
        raise ValueError(f'unknown task: {args.task}')

    # map_location='cpu' so loading also works on CPU-only hosts; main() moves
    # the model to the selected device afterwards.
    pretrained_model = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

    return model
def setup(args):
# 001 classical image sr/ 002 lightweight image sr
if args.task in ['masked_denoising', 'classical_sr', 'lightweight_sr']:
save_dir = f'results/{args.name}'
folder = args.folder_gt
border = args.scale
window_size = 8
# 003 real-world image sr
elif args.task in ['real_sr']:
save_dir = f'results/swinir_{args.task}_x{args.scale}'
if args.large_model:
save_dir += '_large'
folder = args.folder_lq
border = 0
window_size = 8
# 004 grayscale image denoising/ 005 color image denoising
elif args.task in ['gray_dn', 'color_dn']:
save_dir = f'results/swinir_{args.task}_noise{args.noise}'
folder = args.folder_gt
border = 0
window_size = 8
# 006 JPEG compression artifact reduction
elif args.task in ['jpeg_car']:
save_dir = f'results/swinir_{args.task}_jpeg{args.jpeg}'
folder = args.folder_gt
border = 0
window_size = 7
return folder, save_dir, border, window_size
def get_image_pair(args, path):
    """Load the (low-quality, ground-truth) image pair for one test path.

    Depending on the task, the LQ image is either read from ``args.folder_lq``
    or synthesized from the ground truth (additive noise / JPEG round-trip).
    Images are float32 HWC in [0, 1] (cv2 BGR channel order); 'jpeg_car'
    operates on single-channel grayscale.

    Returns:
        (imgname, img_lq, img_gt) — img_gt is None for 'real_sr'.
    Raises:
        ValueError: for an unsupported task (previously this fell through to
        an UnboundLocalError on the return statement).
    """
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    # 001 classical image sr/ 002 lightweight image sr (load lq-gt image pairs)
    if args.task in ['masked_denoising', 'classical_sr', 'lightweight_sr']:
        img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # img_lq = cv2.imread(f'{args.folder_lq}/{imgname}_x{args.scale}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # img_lq = cv2.imread(f'{args.folder_lq}/{imgname}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # The LQ file may be .png or .tif. cv2.imread returns None for a
        # missing file, so .astype raises AttributeError and we fall back to
        # the other extension (the original used a bare except here).
        try:
            imgext = '.png'
            img_lq = cv2.imread(f'{args.folder_lq}/{imgname}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.
        except AttributeError:
            imgext = '.tif'
            img_lq = cv2.imread(f'{args.folder_lq}/{imgname}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.

    # 003 real-world image sr (load lq image only)
    elif args.task in ['real_sr']:
        img_gt = None
        img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.

    # 004 grayscale image denoising (load gt image and generate lq image on-the-fly)
    elif args.task in ['gray_dn']:
        img_gt = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        np.random.seed(seed=0)  # deterministic noise => reproducible metrics
        img_lq = img_gt + np.random.normal(0, args.noise / 255., img_gt.shape)
        img_gt = np.expand_dims(img_gt, axis=2)
        img_lq = np.expand_dims(img_lq, axis=2)

    # 005 color image denoising (load gt image and generate lq image on-the-fly)
    elif args.task in ['color_dn']:
        img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        np.random.seed(seed=0)  # deterministic noise => reproducible metrics
        img_lq = img_gt + np.random.normal(0, args.noise / 255., img_gt.shape)

    # 006 JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
    elif args.task in ['jpeg_car']:
        img_gt = cv2.imread(path, 0)
        result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        img_lq = cv2.imdecode(encimg, 0)
        img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255.
        img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255.

    else:
        raise ValueError(f'unknown task: {args.task}')

    return imgname, img_lq, img_gt
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert an HxWxC image array into a 1xCxHxW float tensor.

    Values are rescaled as ``image / factor - cent`` (maps [0, 255] to
    roughly [-1, 1]), the input convention expected by LPIPS. ``imtype`` is
    accepted for signature compatibility but is not used.
    """
    scaled = image / factor - cent            # [0, 255] -> [-1, 1]
    batched = scaled[np.newaxis, :, :, :]     # (H, W, C) -> (1, H, W, C)
    chw = batched.transpose((0, 3, 1, 2))     # -> (1, C, H, W)
    return torch.Tensor(chw)
def test(img_lq, model, args, window_size):
    """Run the model on img_lq, either whole-image or tile-by-tile.

    With args.tile set, the image is processed in overlapping tiles of size
    min(args.tile, h, w) with stride tile - args.tile_overlap; outputs in
    overlapping regions are averaged via the accumulators E (sum of patch
    outputs) and W (per-pixel hit count).
    """
    if args.tile is None:
        # test the image as a whole
        output = model(img_lq)
    else:
        # test the image tile by tile
        b, c, h, w = img_lq.size()
        tile = min(args.tile, h, w)
        assert tile % window_size == 0, "tile size should be a multiple of window_size"
        tile_overlap = args.tile_overlap
        sf = args.scale
        stride = tile - tile_overlap
        # Tile origins; the appended final index guarantees the right/bottom
        # border is always covered even when stride does not divide evenly.
        h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
        w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
        E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)  # output accumulator
        W = torch.zeros_like(E)                            # per-pixel weight
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)
                E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
                W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
        output = E.div_(W)  # average the overlapping contributions
    return output
# Script entry point.
if __name__ == '__main__':
    main()
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests
from models.network_swinir import SwinIR as net
from utils import utils_image as util
from utils import utils_option as option
import lpips
import torch
def transform(v, op):
    """Apply one self-ensemble augmentation to a 4-D (N, C, H, W) tensor.

    Args:
        v: input tensor; detached to a numpy array on CPU for the transform.
        op: 'v' reverses the last (width) axis, 'h' reverses the height axis,
            't' swaps the two spatial axes.

    Returns:
        The transformed tensor, placed on the default CUDA device.
    Raises:
        ValueError: for an unknown op (previously this fell through to a
        NameError on tfnp).
    """
    # if self.precision != 'single': v = v.float()
    v2np = v.data.cpu().numpy()
    if op == 'v':
        tfnp = v2np[:, :, :, ::-1].copy()
    elif op == 'h':
        tfnp = v2np[:, :, ::-1, :].copy()
    elif op == 't':
        tfnp = v2np.transpose((0, 1, 3, 2)).copy()
    else:
        raise ValueError("op must be one of 'v', 'h', 't', got {!r}".format(op))

    ret = torch.Tensor(tfnp).cuda()
    # if self.precision == 'half': ret = ret.half()

    return ret
def main():
    """SwinIR test driver with x8 self-ensemble test-time augmentation.

    Same pipeline as the plain test script (parse options, fetch checkpoint,
    restore each image, report PSNR / SSIM / LPIPS and, for 'jpeg_car',
    PSNR_B), but each image is inferred 8 times under flip/transpose
    augmentations built by transform(); the inverse-transformed outputs are
    averaged before metrics are computed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='lightweight_sr', help='classical_sr, lightweight_sr, real_sr, '
                        'gray_dn, color_dn, jpeg_car')
    parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8') # 1 for dn and jpeg car
    parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--jpeg', type=int, default=40, help='scale factor: 10, 20, 30, 40')
    parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training SwinIR. '
                        'Just used to differentiate two different settings in Table 2 of the paper. '
                        'Images are NOT tested patch by patch.')
    parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
    parser.add_argument('--model_path', type=str,
                        default='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth')
    parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
    parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
    parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
    parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
    parser.add_argument('--opt', type=str, help='Path to option JSON file.')
    parser.add_argument('--name', type=str, default="test", help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # Network hyper-parameters come from the JSON option file; define_model()
    # reads them through this module-level global.
    global opt_net
    opt_net = opt['netG']
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model: download the official SwinIR release checkpoint when the
    # local file is missing.
    if os.path.exists(args.model_path):
        print(f'loading model from {args.model_path}')
    else:
        os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
        url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(os.path.basename(args.model_path))
        r = requests.get(url, allow_redirects=True)
        print(f'downloading model {args.model_path}')
        # NOTE(review): the file handle is never closed explicitly; relies on GC.
        open(args.model_path, 'wb').write(r.content)

    model = define_model(args)
    model.eval()
    model = model.to(device)

    # setup folder and path
    folder, save_dir, border, window_size = setup(args)
    # print(folder)
    print(args.folder_lq)
    os.makedirs(save_dir, exist_ok=True)
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []
    test_results['psnr_b'] = []
    test_results['lpips'] = []
    psnr, ssim, psnr_y, ssim_y, psnr_b, lpips_ = 0, 0, 0, 0, 0, 0
    # LPIPS with AlexNet backbone; .cuda() assumes a GPU is present even though
    # the model itself falls back to CPU above — NOTE(review): confirm.
    loss_fn_alex = lpips.LPIPS(net='alex').cuda() # best forward scores

    for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
        # print(1)
        # read image
        imgname, img_lq, img_gt = get_image_pair(args, path)  # image to HWC-BGR, float32
        # BGR -> RGB channel reorder (skipped for single-channel), then HWC -> CHW.
        img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HCW-BGR to CHW-RGB
        img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # pad input image to be a multiple of window_size
            _, _, h_old, w_old = img_lq.size()
            # NOTE(review): always pads at least one full window, even when the
            # size is already a multiple; harmless since it is cropped below.
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            # Mirror-extend via flip along each axis, then crop to padded size.
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]

            # --- x8 self-ensemble -------------------------------------------
            # Build the 8 flip/transpose variants of the input: each pass over
            # 'v', 'h', 't' doubles the list (1 -> 2 -> 4 -> 8).
            list_x = []
            x = [img_lq]
            for tf in 'v', 'h', 't': x.extend([transform(_x, tf) for _x in x])

            list_x.append(x)
            list_y = []
            # zip(*list_x) yields one tuple per augmented variant; each is run
            # through the (possibly tiled) forward pass.
            for x in zip(*list_x):
                # print(len(x))
                # y = forward_function(*x)
                y = test(x[0], model, args, window_size)
                if not isinstance(y, list): y = [y]
                if not list_y:
                    list_y = [[_y] for _y in y]
                else:
                    for _list_y, _y in zip(list_y, y): _list_y.append(_y)
            # Undo the augmentations on each output (inverse of build order):
            # index i encodes which of 't', 'h', 'v' were applied.
            for _list_y in list_y:
                for i in range(len(_list_y)):
                    if i > 3:
                        _list_y[i] = transform(_list_y[i], 't')
                    if i % 4 > 1:
                        _list_y[i] = transform(_list_y[i], 'h')
                    if (i % 4) % 2 == 1:
                        _list_y[i] = transform(_list_y[i], 'v')
            # Average the 8 aligned outputs.
            y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
            if len(y) == 1: y = y[0]
            output = y
            # output = test(img_lq, model, args, window_size)
            output = output[..., :h_old * args.scale, :w_old * args.scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_SwinIR.png', output)

        # evaluate psnr/ssim/psnr_b
        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, border=border)
            ssim = util.calculate_ssim(output, img_gt, border=border)
            lpips_ = loss_fn_alex(im2tensor(output).cuda(), im2tensor(img_gt).cuda()).item()
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['lpips'].append(lpips_)
            if img_gt.ndim == 3:  # RGB image
                # Y-channel metrics (standard convention in SR evaluation).
                output_y = util.bgr2ycbcr(output.astype(np.float32) / 255.) * 255.
                img_gt_y = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255.
                psnr_y = util.calculate_psnr(output_y, img_gt_y, border=border)
                ssim_y = util.calculate_ssim(output_y, img_gt_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
            if args.task in ['jpeg_car']:
                # PSNR_B additionally penalizes blocking artifacts.
                psnr_b = util.calculate_psnrb(output, img_gt, border=border)
                test_results['psnr_b'].append(psnr_b)
            # print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; '
            #       'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; '
            #       'PSNR_B: {:.2f} dB; LPIPS: {:.4f}'.
            #       format(idx, imgname, psnr, ssim, psnr_y, ssim_y, psnr_b, lpips_))
        else:
            print('Testing {:d} {:20s}'.format(idx, imgname))

    # summarize psnr/ssim
    # NOTE(review): img_gt here is whatever the last loop iteration produced;
    # all images are assumed to share GT availability and channel count.
    if img_gt is not None:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_lpips = sum(test_results['lpips']) / len(test_results['lpips'])
        print('\n{} \n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(save_dir, ave_psnr, ave_ssim))
        if img_gt.ndim == 3:
            ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            print('-- Average PSNR_Y/SSIM_Y/LPIPS: {:.2f}/{:.4f}/{:.4f}'.format(ave_psnr_y, ave_ssim_y, ave_lpips))
        if args.task in ['jpeg_car']:
            ave_psnr_b = sum(test_results['psnr_b']) / len(test_results['psnr_b'])
            print('-- Average PSNR_B: {:.2f} dB'.format(ave_psnr_b))
def define_model(args):
    """Build the network for the given task and load its pretrained weights.

    Args:
        args: parsed CLI namespace; must provide ``task`` and ``model_path``
            and, depending on the task, ``scale``, ``training_patch_size``
            and ``large_model``.

    Returns:
        The instantiated network with the pretrained state dict loaded
        (``strict=True``).

    Raises:
        ValueError: if ``args.task`` is not one of the supported tasks
            (previously this fell through and crashed later with a
            NameError on ``param_key_g``).
    """
    # 001 classical image sr
    if args.task == 'classical_sr':
        model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
        param_key_g = 'params'

    # 002 lightweight image sr
    # use 'pixelshuffledirect' to save parameters; the architecture is driven
    # entirely by the module-level option dict ``opt_net``
    elif args.task == 'lightweight_sr':
        # model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
        #             img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
        #             mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
        global opt_net
        model = net(upscale=opt_net['upscale'],
                    in_chans=opt_net['in_chans'],
                    img_size=opt_net['img_size'],
                    window_size=opt_net['window_size'],
                    img_range=opt_net['img_range'],
                    depths=opt_net['depths'],
                    embed_dim=opt_net['embed_dim'],
                    num_heads=opt_net['num_heads'],
                    mlp_ratio=opt_net['mlp_ratio'],
                    upsampler=opt_net['upsampler'],
                    resi_connection=opt_net['resi_connection'],
                    talking_heads=opt_net['talking_heads'],
                    use_attn_fn=opt_net['attn_fn'],
                    head_scale=opt_net['head_scale'],
                    on_attn=opt_net['on_attn'],
                    use_mask=opt_net['use_mask'],
                    mask_ratio1=opt_net['mask_ratio1'],
                    mask_ratio2=opt_net['mask_ratio2'],
                    mask_is_diff=opt_net['mask_is_diff'],
                    type=opt_net['type'],
                    opt=opt_net,
                    )
        param_key_g = 'params'

    # 003 real-world image sr
    elif args.task == 'real_sr':
        if not args.large_model:
            # use 'nearest+conv' to avoid block artifacts
            model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
        else:
            # larger model size; use '3conv' to save parameters and memory; use ema for GAN training
            model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
        param_key_g = 'params_ema'

    # 004 grayscale image denoising
    elif args.task == 'gray_dn':
        model = net(upscale=1, in_chans=1, img_size=128, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # 005 color image denoising
    elif args.task == 'color_dn':
        model = net(upscale=1, in_chans=3, img_size=128, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # 006 JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's slightly better than 1
    elif args.task == 'jpeg_car':
        model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    else:
        raise ValueError('unknown task: {}'.format(args.task))

    # map_location='cpu' so checkpoints saved on GPU also load on CPU-only
    # hosts; load_state_dict copies onto the model's own device afterwards.
    pretrained_model = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model else pretrained_model,
                          strict=True)
    return model
def setup(args):
# 001 classical image sr/ 002 lightweight image sr
if args.task in ['classical_sr', 'lightweight_sr']:
save_dir = f'results/{args.name}'
folder = args.folder_gt
border = args.scale
window_size = 8
# 003 real-world image sr
elif args.task in ['real_sr']:
save_dir = f'results/swinir_{args.task}_x{args.scale}'
if args.large_model:
save_dir += '_large'
folder = args.folder_lq
border = 0
window_size = 8
# 004 grayscale image denoising/ 005 color image denoising
elif args.task in ['gray_dn', 'color_dn']:
save_dir = f'results/swinir_{args.task}_noise{args.noise}'
folder = args.folder_gt
border = 0
window_size = 8
# 006 JPEG compression artifact reduction
elif args.task in ['jpeg_car']:
save_dir = f'results/swinir_{args.task}_jpeg{args.jpeg}'
folder = args.folder_gt
border = 0
window_size = 7
return folder, save_dir, border, window_size
def get_image_pair(args, path):
    """Return ``(imgname, img_lq, img_gt)`` for one test image.

    Depending on ``args.task`` the low-quality input is either read from
    disk or synthesized from the ground truth (noise / in-memory JPEG).
    For 'real_sr' there is no ground truth and ``img_gt`` is None.
    """
    imgname, imgext = os.path.splitext(os.path.basename(path))

    if args.task in ['classical_sr', 'lightweight_sr']:
        # 001/002: paired lq-gt images both live on disk
        gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        # lq = cv2.imread(f'{args.folder_lq}/{imgname}_x{args.scale}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.
        lq = cv2.imread(f'{args.folder_lq}/{imgname}{imgext}', cv2.IMREAD_COLOR).astype(np.float32) / 255.

    elif args.task in ['real_sr']:
        # 003: only the lq image exists
        gt = None
        lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.

    elif args.task in ['gray_dn']:
        # 004: synthesize lq by adding Gaussian noise to the grayscale gt
        gt = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        np.random.seed(seed=0)  # deterministic noise across runs
        lq = gt + np.random.normal(0, args.noise / 255., gt.shape)
        gt = np.expand_dims(gt, axis=2)
        lq = np.expand_dims(lq, axis=2)

    elif args.task in ['color_dn']:
        # 005: synthesize lq by adding Gaussian noise to the color gt
        gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        np.random.seed(seed=0)  # deterministic noise across runs
        lq = gt + np.random.normal(0, args.noise / 255., gt.shape)

    elif args.task in ['jpeg_car']:
        # 006: synthesize lq by JPEG-compressing the gt in memory
        gt = cv2.imread(path, 0)
        result, encoded = cv2.imencode('.jpg', gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        lq = cv2.imdecode(encoded, 0)
        gt = np.expand_dims(gt, axis=2).astype(np.float32) / 255.
        lq = np.expand_dims(lq, axis=2).astype(np.float32) / 255.

    return imgname, lq, gt
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert an HxWxC image array in [0, 255] to a 1xCxHxW torch.Tensor.

    Values are mapped to roughly [-1, 1] via ``image / factor - cent``.
    ``imtype`` is unused in the body; kept for interface compatibility.
    """
    scaled = image / factor - cent                  # [0, 255] -> [-1, 1]
    batched = scaled[:, :, :, np.newaxis]           # H x W x C x 1
    return torch.Tensor(batched.transpose((3, 2, 0, 1)))  # -> 1 x C x H x W
def test(img_lq, model, args, window_size):
if args.tile is None:
# test the image as a whole
output = model(img_lq)
else:
# test the image tile by tile
b, c, h, w = img_lq.size()
tile = min(args.tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
tile_overlap = args.tile_overlap
sf = args.scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)
W = torch.zeros_like(E)
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
output = E.div_(W)
return output
# Script entry point for the evaluation script above.
if __name__ == '__main__':
    main()
import os.path
import math
import argparse
import time
import random
import numpy as np
from collections import OrderedDict
import logging
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch
from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from utils.utils_dist import get_dist_info, init_dist
from data.select_dataset import define_Dataset
from models.select_model import define_Model
import lpips
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid
'''
# --------------------------------------------
# training code for MSRResNet
# --------------------------------------------
# Kai Zhang (cskaizhang@gmail.com)
# github: https://github.com/cszn/KAIR
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
# NOTE(review): disables cuDNN entirely for this process — presumably for
# reproducibility or to work around a cuDNN kernel issue; confirm intent,
# as this can slow down training considerably.
torch.backends.cudnn.enabled = False
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """Map an HxWxC image in [0, 255] to a 1xCxHxW tensor in roughly [-1, 1].

    ``imtype`` is accepted for interface compatibility but not used.
    """
    normalized = (image / factor - cent)[:, :, :, np.newaxis]  # H x W x C x 1
    return torch.Tensor(normalized.transpose((3, 2, 0, 1)))    # -> 1 x C x H x W
def main(json_path='options/masked_denoising/input_mask_80_90.json'):
    """Training entry point: parse options, build dataloaders and the model,
    then run the train loop with periodic logging, checkpointing and testing.

    Args:
        json_path: default path to the option JSON file; overridable with
            ``--opt`` on the command line.
    """
    '''
    # ----------------------------------------
    # Step--1 (prepare opt)
    # ----------------------------------------
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=1000000)
    parser.add_argument('--dist', default=False)
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt['dist'] = parser.parse_args().dist
    args = parser.parse_args()
    # one TensorBoard run directory per task
    writer = SummaryWriter('./runs/' + opt['task'])

    # ----------------------------------------
    # distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    # create output directories (rank 0 only), skipping pretrained-model paths
    if opt['rank'] == 0:
        util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

    # ----------------------------------------
    # update opt
    # ----------------------------------------
    # -->-->-->-->-->-->-->-->-->-->-->-->-->-
    # resume from the latest G / E / optimizer checkpoints, if any
    init_iter_G, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G')
    init_iter_E, init_path_E = option.find_last_checkpoint(opt['path']['models'], net_type='E')
    opt['path']['pretrained_netG'] = init_path_G
    opt['path']['pretrained_netE'] = init_path_E
    init_iter_optimizerG, init_path_optimizerG = option.find_last_checkpoint(opt['path']['models'], net_type='optimizerG')
    opt['path']['pretrained_optimizerG'] = init_path_optimizerG
    current_step = max(init_iter_G, init_iter_E, init_iter_optimizerG)
    # current_step = 0
    border = opt['scale']
    # --<--<--<--<--<--<--<--<--<--<--<--<--<-

    # ----------------------------------------
    # save opt to a '../option.json' file
    # ----------------------------------------
    if opt['rank'] == 0:
        option.save(opt)

    # ----------------------------------------
    # return None for missing key
    # ----------------------------------------
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # configure logger
    # ----------------------------------------
    if opt['rank'] == 0:
        logger_name = 'train'
        utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log'))
        logger = logging.getLogger(logger_name)
        logger.info(option.dict2str(opt))

    # ----------------------------------------
    # seed
    # ----------------------------------------
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    '''
    # ----------------------------------------
    # Step--2 (creat dataloader)
    # ----------------------------------------
    '''
    # ----------------------------------------
    # 1) create_dataset
    # 2) creat_dataloader for train and test
    # ----------------------------------------
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size']))
            if opt['rank'] == 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            if opt['dist']:
                # distributed: sampler shuffles; per-GPU batch/workers are the totals divided by num_gpu
                # train_sampler = DistributedSampler(train_set, shuffle=dataset_opt['dataloader_shuffle'], drop_last=True, seed=seed)
                train_sampler = DistributedSampler(train_set, shuffle=dataset_opt['dataloader_shuffle'])
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size']//opt['num_gpu'],
                                          shuffle=False,
                                          num_workers=dataset_opt['dataloader_num_workers']//opt['num_gpu'],
                                          drop_last=True,
                                          pin_memory=True,
                                          sampler=train_sampler)
            else:
                train_loader = DataLoader(train_set,
                                          batch_size=dataset_opt['dataloader_batch_size'],
                                          shuffle=dataset_opt['dataloader_shuffle'],
                                          num_workers=dataset_opt['dataloader_num_workers'],
                                          drop_last=True,
                                          pin_memory=True)
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    '''
    # ----------------------------------------
    # Step--3 (initialize model)
    # ----------------------------------------
    '''
    model = define_Model(opt)
    model.init_train()
    if opt['rank'] == 0:
        logger.info(model.info_network())
        logger.info(model.info_params())

    # ==================================================================
    # LPIPS metric (AlexNet backbone) used during periodic testing
    loss_fn_alex = lpips.LPIPS(net='alex').cuda()
    best_PSNRY = 0   # best average PSNR on the Y channel seen so far
    best_step = 0    # iteration at which best_PSNRY was reached
    # ==================================================================

    '''
    # ----------------------------------------
    # Step--4 (main training)
    # ----------------------------------------
    '''
    for epoch in range(args.epochs):  # keep running
        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for _, train_data in enumerate(train_loader):

            current_step += 1

            # -------------------------------
            # 1) update learning rate
            # -------------------------------
            model.update_learning_rate(current_step)

            # -------------------------------
            # 2) feed patch pairs
            # -------------------------------
            model.feed_data(train_data)

            # -------------------------------
            # 3) optimize parameters
            # -------------------------------
            model.optimize_parameters(current_step)

            # -------------------------------
            # 4) training information
            # -------------------------------
            if current_step % opt['train']['checkpoint_print'] == 0 and opt['rank'] == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step, model.current_learning_rate())
                for k, v in logs.items():  # merge log information into message
                    message += '{:s}: {:.3e} '.format(k, v)
                    # ----------------------------------------
                    # NOTE(review): every log entry is written under the single
                    # tag 'loss' at the same step, so later entries shadow
                    # earlier ones in TensorBoard — confirm intended.
                    writer.add_scalar('loss', v, global_step=current_step)
                    # ----------------------------------------
                logger.info(message)

            # -------------------------------
            # 5) save model
            # -------------------------------
            if current_step % opt['train']['checkpoint_save'] == 0 and opt['rank'] == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            # -------------------------------
            # 6) testing
            # -------------------------------
            if current_step % opt['train']['checkpoint_test'] == 0 and opt['rank'] == 0:

                avg_psnr = 0.0
                avg_ssim = 0.0
                avg_psnrY = 0.0
                avg_ssimY = 0.0
                avg_lpips = 0.0
                idx = 0
                save_list = []

                for test_data in test_loader:
                    idx += 1
                    image_name_ext = os.path.basename(test_data['L_path'][0])
                    img_name, ext = os.path.splitext(image_name_ext)

                    img_dir = os.path.join(opt['path']['images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(test_data)
                    model.test()

                    visuals = model.current_visuals()
                    E_img = util.tensor2uint(visuals['E'])
                    H_img = util.tensor2uint(visuals['H'])

                    # -----------------------
                    # save estimated image E
                    # -----------------------
                    save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.imsave(E_img, save_img_path)

                    # -----------------------
                    # calculate PSNR
                    # -----------------------
                    current_psnr = util.calculate_psnr(E_img, H_img, border=border)
                    # ==================================================================
                    # extra metrics: SSIM, LPIPS, and PSNR/SSIM on the Y channel
                    current_ssim = util.calculate_ssim(E_img, H_img, border=border)
                    current_lpips = loss_fn_alex(im2tensor(E_img).cuda(), im2tensor(H_img).cuda()).item()
                    output_y = util.bgr2ycbcr(E_img.astype(np.float32) / 255.) * 255.
                    img_gt_y = util.bgr2ycbcr(H_img.astype(np.float32) / 255.) * 255.
                    psnr_y = util.calculate_psnr(output_y, img_gt_y, border=border)
                    ssim_y = util.calculate_ssim(output_y, img_gt_y, border=border)
                    # ==================================================================

                    logger.info('{:->4d}--> {:>20s} | PSNR: {:<4.2f}, SSIM: {:<5.4f}, PSNRY: {:<4.2f}, SSIMY: {:<5.4f}, LPIPS: {:<5.4f},'.format(idx, image_name_ext, current_psnr, current_ssim, psnr_y, ssim_y, current_lpips))
                    # logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr))

                    avg_psnr += current_psnr
                    avg_ssim += current_ssim
                    avg_psnrY += psnr_y
                    avg_ssimY += ssim_y
                    avg_lpips += current_lpips
                    # collect selected crops for a TensorBoard image grid
                    if img_name in opt['train']['save_image']:
                        print(img_name)
                        save_list.append(util.uint2tensor3(E_img)[:, :512, :512])

                avg_psnr = avg_psnr / idx
                avg_ssim = avg_ssim / idx
                avg_psnrY = avg_psnrY / idx
                avg_ssimY = avg_ssimY / idx
                avg_lpips = avg_lpips / idx

                if len(save_list) > 0 and current_step % opt['train']['checkpoint_save'] == 0 and opt['rank'] == 0:
                    save_images = make_grid(save_list, nrow=len(save_list))
                    writer.add_image("test", save_images, global_step=current_step)

                # avg_psnr += current_psnr
                # avg_psnr = avg_psnr / idx

                # track the best PSNR-Y checkpoint
                if avg_psnrY >= best_PSNRY:
                    best_step = current_step
                    best_PSNRY = avg_psnrY

                # testing log
                # logger.info('<epoch:{:3d}, iter:{:8,d}, Average PSNR : {:<.2f}dB\n'.format(epoch, current_step, avg_psnr))
                logger.info('<epoch:{:3d}, iter:{:8,d}, Average: PSNR: {:<.2f}, SSIM: {:<.4f}, PSNRY: {:<.2f}, SSIMY: {:<.4f}, LPIPS: {:<.4f}'.format(epoch, current_step, avg_psnr, avg_ssim, avg_psnrY, avg_ssimY, avg_lpips))
                logger.info('--- best PSNRY ---> iter:{:8,d}, Average: PSNR: {:<.2f}\n'.format(best_step, best_PSNRY))
                writer.add_scalar('PSNRY', avg_psnrY, global_step=current_step)
                writer.add_scalar('SSIMY', avg_ssimY, global_step=current_step)
                writer.add_scalar('PSNR', avg_psnr, global_step=current_step)
                writer.add_scalar('SSIM', avg_ssim, global_step=current_step)
                writer.add_scalar('LPIPS', avg_lpips, global_step=current_step)
# Script entry point for the training script above.
if __name__ == '__main__':
    main()
function [psnr_cur, ssim_cur] = Cal_PSNRSSIM(A,B,row,col)
% Cal_PSNRSSIM  PSNR and SSIM between ground truth A and estimate B,
% after shaving `row` rows and `col` columns from every border.
[n,m,ch] = size(B);
rows = row+1:n-row;
cols = col+1:m-col;
A = double(A(rows, cols, :));   % ground-truth, cropped
B = double(B(rows, cols, :));   % estimate, cropped
residual = A(:) - B(:);
mse = mean(residual.^2);
psnr_cur = 10*log10(255^2/mse); % assumes 8-bit dynamic range (peak 255)
if ch == 1
    [ssim_cur, ~] = ssim_index(A, B);
else
    % color: average the per-channel SSIM values
    ssim_cur = (ssim_index(A(:,:,1), B(:,:,1)) + ssim_index(A(:,:,2), B(:,:,2)) + ssim_index(A(:,:,3), B(:,:,3)))/3;
end
function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
%========================================================================
%SSIM Index, Version 1.0
%Copyright(c) 2003 Zhou Wang
%All Rights Reserved.
%
%The author is with Howard Hughes Medical Institute, and Laboratory
%for Computational Vision at Center for Neural Science and Courant
%Institute of Mathematical Sciences, New York University.
%
%----------------------------------------------------------------------
%Permission to use, copy, or modify this software and its documentation
%for educational and research purposes only and without fee is hereby
%granted, provided that this copyright notice and the original authors'
%names appear on all copies and supporting documentation. This program
%shall not be used, rewritten, or adapted as the basis of a commercial
%software or hardware product without first obtaining permission of the
%authors. The authors make no representations about the suitability of
%this software for any purpose. It is provided "as is" without express
%or implied warranty.
%----------------------------------------------------------------------
%
%This is an implementation of the algorithm for calculating the
%Structural SIMilarity (SSIM) index between two images. Please refer
%to the following paper:
%
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
%quality assessment: From error measurement to structural similarity"
%IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
%
%Kindly report any suggestions or corrections to zhouwang@ieee.org
%
%----------------------------------------------------------------------
%
%Input : (1) img1: the first image being compared
%        (2) img2: the second image being compared
%        (3) K: constants in the SSIM index formula (see the above
%            reference). default value: K = [0.01 0.03]
%        (4) window: local window for statistics (see the above
%            reference). default window is Gaussian given by
%            window = fspecial('gaussian', 11, 1.5);
%        (5) L: dynamic range of the images. default: L = 255
%
%Output: (1) mssim: the mean SSIM index value between 2 images.
%            If one of the images being compared is regarded as
%            perfect quality, then mssim can be considered as the
%            quality measure of the other image.
%            If img1 = img2, then mssim = 1.
%        (2) ssim_map: the SSIM index map of the test image. The map
%            has a smaller size than the input images. The actual size:
%            size(img1) - size(window) + 1.
%
%Default Usage:
%   Given 2 test images img1 and img2, whose dynamic range is 0-255
%
%   [mssim ssim_map] = ssim_index(img1, img2);
%
%Advanced Usage:
%   User defined parameters. For example
%
%   K = [0.05 0.05];
%   window = ones(8);
%   L = 100;
%   [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
%
%See the results:
%
%   mssim                        %Gives the mssim value
%   imshow(max(0, ssim_map).^4)  %Shows the SSIM index map
%
%========================================================================

% BUGFIX: the error paths below originally assigned to a variable named
% 'ssim_index', which is not a declared output of this function; on any
% error path MATLAB would then fail with "Output argument 'mssim' ...
% not assigned". They now assign to the declared outputs.
if (nargin < 2 || nargin > 5)
   mssim = -Inf;
   ssim_map = -Inf;
   return;
end

% BUGFIX: 'if (size(img1) ~= size(img2))' is only true when EVERY
% dimension differs (MATLAB's if requires all elements nonzero);
% isequal rejects any size mismatch.
if (~isequal(size(img1), size(img2)))
   mssim = -Inf;
   ssim_map = -Inf;
   return;
end

[M N] = size(img1);

if (nargin == 2)
   if ((M < 11) || (N < 11))      % image too small for the 11x11 window
      mssim = -Inf;
      ssim_map = -Inf;
      return
   end
   window = fspecial('gaussian', 11, 1.5);   %
   K(1) = 0.01;                              % default settings
   K(2) = 0.03;                              %
   L = 255;                                  %
end

if (nargin == 3)
   if ((M < 11) || (N < 11))
      mssim = -Inf;
      ssim_map = -Inf;
      return
   end
   window = fspecial('gaussian', 11, 1.5);
   L = 255;
   if (length(K) == 2)
      if (K(1) < 0 || K(2) < 0)
         mssim = -Inf;
         ssim_map = -Inf;
         return;
      end
   else
      mssim = -Inf;
      ssim_map = -Inf;
      return;
   end
end

if (nargin == 4)
   [H W] = size(window);
   if ((H*W) < 4 || (H > M) || (W > N))
      mssim = -Inf;
      ssim_map = -Inf;
      return
   end
   L = 255;
   if (length(K) == 2)
      if (K(1) < 0 || K(2) < 0)
         mssim = -Inf;
         ssim_map = -Inf;
         return;
      end
   else
      mssim = -Inf;
      ssim_map = -Inf;
      return;
   end
end

if (nargin == 5)
   [H W] = size(window);
   if ((H*W) < 4 || (H > M) || (W > N))
      mssim = -Inf;
      ssim_map = -Inf;
      return
   end
   if (length(K) == 2)
      if (K(1) < 0 || K(2) < 0)
         mssim = -Inf;
         ssim_map = -Inf;
         return;
      end
   else
      mssim = -Inf;
      ssim_map = -Inf;
      return;
   end
end

C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));          % normalize window weights
img1 = double(img1);
img2 = double(img2);

% local means, variances and covariance via 'valid' filtering
mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;

if (C1 > 0 & C2 > 0)
   ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
else
   % degenerate stabilizers: guard against division by zero term-by-term
   numerator1 = 2*mu1_mu2 + C1;
   numerator2 = 2*sigma12 + C2;
   denominator1 = mu1_sq + mu2_sq + C1;
   denominator2 = sigma1_sq + sigma2_sq + C2;
   ssim_map = ones(size(mu1));
   index = (denominator1.*denominator2 > 0);
   ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
   index = (denominator1 ~= 0) & (denominator2 == 0);
   ssim_map(index) = numerator1(index)./denominator1(index);
end

mssim = mean2(ssim_map);

return
Run the MATLAB script [main_denoising_gray.m](https://github.com/cszn/KAIR/blob/master/matlab/main_denoising_gray.m) to render the images with a local zoom-in region, configured as follows:
```matlab
upperleft_pixel = [172, 218];
box = [35, 35];
zoomfactor = 3;
zoom_position = 'ur'; % 'ur' = 'upper-right'
nline = 2;
```
<img src="https://github.com/cszn/KAIR/blob/master/matlab/denoising_gray/05_drunet_2731.png" width="256px"/> <img src="https://github.com/cszn/KAIR/blob/master/matlab/denoising_gray_results/05_drunet_2731.png" width="256px"/>
function [im] = center_replace(im,im2)
% center_replace  Paste im2 into the center of im.
% im2 must be no larger than im in the first two dimensions.
[w,h,~] = size(im);
[a,b,~] = size(im2);
% BUGFIX: the original offsets w-a-(w-a)/2 and h-b-(h-b)/2 are
% non-integer whenever the size difference is odd, which makes the
% indexing below fail; floor() keeps them integral (even differences
% are unchanged).
c1 = floor((w-a)/2);
c2 = floor((h-b)/2);
im(c1+1:c1+a,c2+1:c2+b,:) = im2;
end
function imgs = modcrop(imgs, modulo)
% modcrop  Crop the spatial dimensions of imgs so both are divisible by modulo.
dims = size(imgs);
cropped = dims(1:2) - mod(dims(1:2), modulo);
if size(imgs,3)==1
    % single-channel image
    imgs = imgs(1:cropped(1), 1:cropped(2));
else
    % multi-channel image: keep all channels
    imgs = imgs(1:cropped(1), 1:cropped(2), :);
end
function I = shave(I, border)
% shave  Remove a border from the first two dimensions of I.
% border may be a 2-vector [rows cols] (original behavior) or, as a
% backward-compatible generalization, a scalar applied to both dims
% (the original errored on border(2) for scalar input).
if isscalar(border)
    border = [border, border];
end
I = I(1+border(1):end-border(1), ...
      1+border(2):end-border(2), :, :);
function [I]=zoom_function(I,upperleft_pixel,box,zoomfactor,zoomfactor_position_unused_doc,nline) %#ok<*NASGU> % (signature documented below)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment