# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
as texts.
"""
@abstractmethod
def get(self, filepath):
pass
@abstractmethod
def get_text(self, filepath):
pass
class MemcachedBackend(BaseStorageBackend):
"""Memcached storage backend.
Attributes:
server_list_cfg (str): Config file for memcached server list.
client_cfg (str): Config file for memcached client.
sys_path (str | None): Additional path to be appended to `sys.path`.
Default: None.
"""
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
sys.path.append(sys_path)
try:
import mc
except ImportError:
raise ImportError('Please install memcached to enable MemcachedBackend.')
self.server_list_cfg = server_list_cfg
self.client_cfg = client_cfg
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer that points to a memory cache
self._mc_buffer = mc.pyvector()
def get(self, filepath):
filepath = str(filepath)
import mc
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
def get(self, filepath):
filepath = str(filepath)
with open(filepath, 'rb') as f:
value_buf = f.read()
return value_buf
def get_text(self, filepath):
filepath = str(filepath)
with open(filepath, 'r') as f:
value_buf = f.read()
return value_buf
class LmdbBackend(BaseStorageBackend):
"""Lmdb storage backend.
Args:
db_paths (str | list[str]): Lmdb database paths.
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
readonly (bool, optional): Lmdb environment parameter. If True,
disallow any write operations. Default: True.
lock (bool, optional): Lmdb environment parameter. If False, when
concurrent access occurs, do not lock the database. Default: False.
readahead (bool, optional): Lmdb environment parameter. If False,
disable the OS filesystem readahead mechanism, which may improve
random read performance when a database is larger than RAM.
Default: False.
Attributes:
db_paths (list): Lmdb database path.
_client (list): A list of several lmdb envs.
"""
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
try:
import lmdb
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
if isinstance(client_keys, str):
client_keys = [client_keys]
if isinstance(db_paths, list):
self.db_paths = [str(v) for v in db_paths]
elif isinstance(db_paths, str):
self.db_paths = [str(db_paths)]
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
f'but received {len(client_keys)} and {len(self.db_paths)}.')
self._client = {}
for client, path in zip(client_keys, self.db_paths):
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
def get(self, filepath, client_key):
"""Get values according to the filepath from one lmdb named client_key.
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
client_key (str): Used for distinguishing different lmdb envs.
"""
filepath = str(filepath)
assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
client = self._client[client_key]
with client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class FileClient(object):
"""A general file client to access files in different backend.
    The client loads a file or text in a specified backend from its path
    and returns it as a binary file. It can also register other backend
    accessors with a given name and backend class.
Attributes:
backend (str): The storage backend type. Options are "disk",
"memcached" and "lmdb".
client (:obj:`BaseStorageBackend`): The backend object.
"""
_backends = {
'disk': HardDiskBackend,
'memcached': MemcachedBackend,
'lmdb': LmdbBackend,
}
def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
def get(self, filepath, client_key='default'):
# client_key is used only for lmdb, where different fileclients have
# different lmdb environments.
if self.backend == 'lmdb':
return self.client.get(filepath, client_key)
else:
return self.client.get(filepath)
def get_text(self, filepath):
return self.client.get_text(filepath)
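# Usage sketch (illustrative, not part of the original file): read raw bytes through
# the default disk backend, or through an lmdb backend keyed by `client_key`.
# The file and lmdb names below are placeholders.
#
#     file_client = FileClient(backend='disk')
#     img_bytes = file_client.get('path/to/image.png')
#
#     lmdb_client = FileClient(backend='lmdb', db_paths=['train_gt.lmdb'], client_keys=['gt'])
#     img_bytes = lmdb_client.get('000_00000000', client_key='gt')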
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
import cv2
import numpy as np
import os
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
"""Read an optical flow map.
Args:
flow_path (ndarray or str): Flow path.
quantize (bool): whether to read quantized pair, if set to True,
remaining args will be passed to :func:`dequantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
Returns:
ndarray: Optical flow represented as a (h, w, 2) numpy array
"""
if quantize:
assert concat_axis in [0, 1]
cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
if cat_flow.ndim != 2:
raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
assert cat_flow.shape[concat_axis] % 2 == 0
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
flow = dequantize_flow(dx, dy, *args, **kwargs)
else:
with open(flow_path, 'rb') as f:
try:
header = f.read(4).decode('utf-8')
except Exception:
raise IOError(f'Invalid flow file: {flow_path}')
else:
if header != 'PIEH':
raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
w = np.fromfile(f, np.int32, 1).squeeze()
h = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
will be concatenated horizontally into a single image if quantize is True.)
Args:
flow (ndarray): (h, w, 2) array of optical flow.
filename (str): Output filepath.
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
images. If set to True, remaining args will be passed to
:func:`quantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
"""
if not quantize:
with open(filename, 'wb') as f:
f.write('PIEH'.encode('utf-8'))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
f.flush()
else:
assert concat_axis in [0, 1]
dx, dy = quantize_flow(flow, *args, **kwargs)
dxdy = np.concatenate((dx, dy), axis=concat_axis)
os.makedirs(os.path.dirname(filename), exist_ok=True)
cv2.imwrite(filename, dxdy)
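# Round-trip sketch (illustrative): write a flow map to a .flo file and read it back.
# The output directory is assumed to exist for the non-quantized branch.
#
#     flow = np.zeros((240, 320, 2), dtype=np.float32)
#     flowwrite(flow, 'flow_example.flo')
#     flow_back = flowread('flow_example.flo')   # (240, 320, 2) float32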
def quantize_flow(flow, max_val=0.02, norm=True):
"""Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be
dumped as jpeg images.
Args:
flow (ndarray): (h, w, 2) array of optical flow.
max_val (float): Maximum value of flow, values beyond
[-max_val, max_val] will be truncated.
norm (bool): Whether to divide flow values by image width/height.
Returns:
tuple[ndarray]: Quantized dx and dy.
"""
h, w, _ = flow.shape
dx = flow[..., 0]
dy = flow[..., 1]
if norm:
dx = dx / w # avoid inplace operations
dy = dy / h
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
"""Recover from quantized flow.
Args:
dx (ndarray): Quantized dx.
dy (ndarray): Quantized dy.
max_val (float): Maximum value used when quantizing.
denorm (bool): Whether to multiply flow values with width/height.
Returns:
ndarray: Dequantized flow.
"""
assert dx.shape == dy.shape
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
if denorm:
dx *= dx.shape[1]
dy *= dx.shape[0]
flow = np.dstack((dx, dy))
return flow
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
"""Quantize an array of (-inf, inf) to [0, levels-1].
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the quantized array.
Returns:
        ndarray: Quantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
arr = np.clip(arr, min_val, max_val) - min_val
quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
"""Dequantize an array.
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array.
Returns:
        ndarray: Dequantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
return dequantized_arr
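# Round-trip sketch (illustrative): quantize a small flow field to uint8 and recover
# an approximation. The reconstruction error is bounded by one quantization step,
# i.e. about 2 * max_val / 255.
#
#     flow = np.random.uniform(-0.01, 0.01, size=(4, 4, 2)).astype(np.float32)
#     dx_q, dy_q = quantize_flow(flow, max_val=0.02, norm=False)        # uint8 in [0, 254]
#     flow_rec = dequantize_flow(dx_q, dy_q, max_val=0.02, denorm=False)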
import cv2
import numpy as np
import torch
from torch.nn import functional as F
def filter2D(img, kernel):
"""PyTorch version of cv2.filter2D
Args:
img (Tensor): (b, c, h, w)
kernel (Tensor): (b, k, k)
"""
k = kernel.size(-1)
b, c, h, w = img.size()
if k % 2 == 1:
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
else:
raise ValueError('Wrong kernel size')
ph, pw = img.size()[-2:]
if kernel.size(0) == 1:
# apply the same kernel to all batch images
img = img.view(b * c, 1, ph, pw)
kernel = kernel.view(1, 1, k, k)
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
else:
img = img.view(1, b * c, ph, pw)
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
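# Usage sketch (illustrative): blur a batch of images with one shared Gaussian kernel.
#
#     kernel = cv2.getGaussianKernel(21, 3)
#     kernel = torch.FloatTensor(kernel @ kernel.T).unsqueeze(0)   # (1, 21, 21)
#     out = filter2D(torch.rand(4, 3, 64, 64), kernel)             # (4, 3, 64, 64)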
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening.
Input image: I; Blurry image: B.
1. sharp = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * sharp + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Threshold on the residual (scaled to [0, 255]) used to
            build the sharpening mask. Default: 10.
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
sharp = img + weight * residual
sharp = np.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp(torch.nn.Module):
def __init__(self, radius=50, sigma=0):
super(USMSharp, self).__init__()
if radius % 2 == 0:
radius += 1
self.radius = radius
kernel = cv2.getGaussianKernel(radius, sigma)
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
self.register_buffer('kernel', kernel)
def forward(self, img, weight=0.5, threshold=10):
blur = filter2D(img, self.kernel)
residual = img - blur
mask = torch.abs(residual) * 255 > threshold
mask = mask.float()
soft_mask = filter2D(mask, self.kernel)
sharp = img + weight * residual
sharp = torch.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
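# Usage sketch (illustrative): sharpen a float32 BGR image in [0, 1] with the NumPy
# helper, or a (b, c, h, w) tensor with the module version. img_np and img_tensor
# are assumed to be given.
#
#     sharp_np = usm_sharp(img_np)          # img_np: HWC float32 in [0, 1]
#
#     usm = USMSharp()
#     sharp_t = usm(img_tensor)             # img_tensor: (b, c, h, w), values in [0, 1]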
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) or 2D ndarray
            of shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
            # Unlike MATLAB, np.uint8() will not round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1:
result = result[0]
return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
"""This implementation is slightly faster than tensor2img.
It now only supports torch tensor with shape (1, c, h, w).
Args:
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
"""
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
output = output.type(torch.uint8).cpu().numpy()
if rgb2bgr:
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
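# Usage sketch (illustrative): move an image into a tensor for the network and
# convert the network output back to a uint8 BGR image.
#
#     inp = img2tensor(img_np, bgr2rgb=True, float32=True).unsqueeze(0)   # img_np assumed HWC BGR in [0, 1]
#     out = torch.rand(1, 3, 256, 256)                                    # stand-in for a model output
#     out_img = tensor2img(out)                                           # (256, 256, 3) uint8, BGR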
def imfrombytes(content, flag='color', float32=False):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Flags specifying the color type of a loaded image,
candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to convert to float32. If True, the image will
            also be normalized to [0, 1]. Default: False.
Returns:
ndarray: Loaded image array.
"""
img_np = np.frombuffer(content, np.uint8)
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
img = img.astype(np.float32) / 255.
return img
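# Typical read path (illustrative sketch): fetch raw bytes with the FileClient helper
# defined earlier in this repo and decode them with imfrombytes.
#
#     file_client = FileClient('disk')
#     img_bytes = file_client.get('path/to/image.png')
#     img = imfrombytes(img_bytes, float32=True)   # HWC, BGR, float32 in [0, 1]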
import os, sys
import cv2
import numpy as np
from PIL import Image
from glob import glob
import yaml
# from dataset import DatasetRepeater, PersonalDataset, FramePairDataset
from tqdm import tqdm
def folder_to_concat_folder(folder_list, existing_dict=None, order_image=None, res=None, frame_num=-1, concat_dim=1):
image_name_dict = {}
image_list = []
    for f in folder_list:
        image_list = sorted(glob(f + "/*.png"))
        if frame_num > 0:
            image_list = image_list[0:frame_num]
        image_name_dict[f] = image_list
        print(f + ": " + str(len(image_list)))
    for f in folder_list:
        image_list = sorted(glob(f + "/*/*.png"))
        if frame_num > 0:
            image_list = image_list[0:frame_num]
        if image_list:  # only overwrite when the folder stores frames in sub-folders
            image_name_dict[f] = image_list
            print(f + ": " + str(len(image_list)))
if existing_dict:
image_name_dict.update(existing_dict)
if order_image is None:
order_image = sorted(image_name_dict.keys())
print("order of video: ", order_image)
    assert len(image_name_dict[order_image[0]]) > 0, f"number of frames at {order_image[0]} should be larger than zero"
first_image = cv2.cvtColor(cv2.imread(image_name_dict[order_image[0]][0]), cv2.COLOR_BGR2RGB)
concat_frame_list = []
if frame_num < 0:
frame_num = len(image_name_dict[order_image[0]])
for i in tqdm(np.arange(frame_num)):
image_list_i = []
for f in order_image:
img_if = cv2.cvtColor(cv2.imread(image_name_dict[f][i]), cv2.COLOR_BGR2RGB)
# final_img = cv2.cvtColor(cv2.imread(final_img_path), cv2.COLOR_BGR2RGB)
# gt_img = cv2.cvtColor(cv2.imread(gt_img_path), cv2.COLOR_BGR2RGB)
if res is not None:
if img_if.shape[0]!=res:
# print(image_name_dict[f][i])
img_if = cv2.resize(img_if, (res, res))
            elif img_if.shape[0] != first_image.shape[0]:
                # print(image_name_dict[f][i])
                # cv2.resize expects dsize as (width, height)
                img_if = cv2.resize(img_if, (first_image.shape[1], first_image.shape[0]))
image_list_i.append(img_if)
# concat_frame_list.append(np.concatenate(image_list_i, 1))
concat_frame_list.append(np.concatenate(image_list_i, concat_dim))
return concat_frame_list
def folder_to_video(img_list, output_path):
from moviepy.editor import ImageSequenceClip
imgseqclip = ImageSequenceClip(img_list, 23.98)
imgseqclip.write_videofile((output_path), logger=None)
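# Usage sketch (illustrative): stitch frames from two result folders side by side and
# export them as a video. The folder and file names are placeholders.
#
#     frames = folder_to_concat_folder(['results/ours', 'results/gt'], res=512)
#     folder_to_video(frames, 'results/compare.mp4')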
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
ok = cv2.imwrite(file_path, img, params)
if not ok:
raise IOError('Failed in writing images.')
def crop_border(imgs, crop_border):
"""Crop borders of images.
Args:
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
crop_border (int): Crop border for each end of height and weight.
Returns:
list[ndarray]: Cropped images.
"""
if crop_border == 0:
return imgs
else:
if isinstance(imgs, list):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
lmdb_path,
img_path_list,
keys,
batch=5000,
compress_level=1,
multiprocessing_read=False,
n_thread=40,
map_size=None):
"""Make lmdb from images.
Contents of lmdb. The file structure is:
example.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records 1)image name (with extension),
2)image shape, and 3)compression level, separated by a white space.
For example, the meta information could be:
`000_00000000.png (720,1280,3) 1`, which means:
1) image name (with extension): 000_00000000.png;
2) image shape: (720,1280,3);
3) compression level: 1
We use the image name without extension as the lmdb key.
If `multiprocessing_read` is True, it will read all the images to memory
using multiprocessing. Thus, your server needs to have enough memory.
Args:
data_path (str): Data path for reading images.
lmdb_path (str): Lmdb save path.
        img_path_list (list[str]): Image path list.
        keys (list[str]): Used for lmdb keys.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
multiprocessing_read (bool): Whether use multiprocessing to read all
the images to memory. Default: False.
n_thread (int): For multiprocessing.
map_size (int | None): Map size for lmdb env. If None, use the
estimated size from images. Default: None
"""
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
f'but got {len(img_path_list)} and {len(keys)}')
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
if multiprocessing_read:
# read all the images to memory (multiprocessing)
dataset = {} # use dict to keep the order for multiprocessing
shapes = {}
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
pbar = tqdm(total=len(img_path_list), unit='image')
def callback(arg):
"""get the image data and update pbar."""
key, dataset[key], shapes[key] = arg
pbar.update(1)
pbar.set_description(f'Read {key}')
pool = Pool(n_thread)
for path, key in zip(img_path_list, keys):
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
pool.close()
pool.join()
pbar.close()
print(f'Finish reading {len(img_path_list)} images.')
# create lmdb environment
if map_size is None:
# obtain data size for one image
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
data_size_per_img = img_byte.nbytes
print('Data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(img_path_list)
map_size = data_size * 10
env = lmdb.open(lmdb_path, map_size=map_size)
# write data to lmdb
pbar = tqdm(total=len(img_path_list), unit='chunk')
txn = env.begin(write=True)
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
pbar.update(1)
pbar.set_description(f'Write {key}')
key_byte = key.encode('ascii')
if multiprocessing_read:
img_byte = dataset[key]
h, w, c = shapes[key]
else:
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
h, w, c = img_shape
txn.put(key_byte, img_byte)
# write meta information
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
if idx % batch == 0:
txn.commit()
txn = env.begin(write=True)
pbar.close()
txn.commit()
env.close()
txt_file.close()
print('\nFinish writing lmdb.')
def read_img_worker(path, key, compress_level):
"""Read image worker.
Args:
path (str): Image path.
key (str): Image key.
compress_level (int): Compress level when encoding images.
Returns:
str: Image key.
byte: Image byte.
tuple[int]: Image shape.
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2:
h, w = img.shape
c = 1
else:
h, w, c = img.shape
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
return (key, img_byte, (h, w, c))
class LmdbMaker():
"""LMDB Maker.
Args:
lmdb_path (str): Lmdb save path.
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
"""
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
self.lmdb_path = lmdb_path
self.batch = batch
self.compress_level = compress_level
self.env = lmdb.open(lmdb_path, map_size=map_size)
self.txn = self.env.begin(write=True)
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
self.counter = 0
def put(self, img_byte, key, img_shape):
self.counter += 1
key_byte = key.encode('ascii')
self.txn.put(key_byte, img_byte)
# write meta information
h, w, c = img_shape
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
if self.counter % self.batch == 0:
self.txn.commit()
self.txn = self.env.begin(write=True)
def close(self):
self.txn.commit()
self.env.close()
self.txt_file.close()
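# Usage sketch (illustrative): build an lmdb incrementally with LmdbMaker, reusing
# read_img_worker to encode each image. img_paths and keys are assumed to be given.
#
#     maker = LmdbMaker('faces.lmdb')
#     for path, key in zip(img_paths, keys):
#         _, img_byte, shape = read_img_worker(path, key, compress_level=1)
#         maker.put(img_byte, key, shape)
#     maker.close()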
import datetime
import logging
import time
from .dist_util import get_dist_info, master_only
initialized_logger = {}
class AvgTimer():
def __init__(self, window=200):
self.window = window # average window
self.current_time = 0
self.total_time = 0
self.count = 0
self.avg_time = 0
self.start()
def start(self):
self.start_time = self.tic = time.time()
def record(self):
self.count += 1
self.toc = time.time()
self.current_time = self.toc - self.tic
self.total_time += self.current_time
# calculate average time
self.avg_time = self.total_time / self.count
# reset
if self.count > self.window:
self.count = 0
self.total_time = 0
self.tic = time.time()
def get_current_time(self):
return self.current_time
def get_avg_time(self):
return self.avg_time
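# Usage sketch (illustrative): time training iterations with AvgTimer. train_one_iter
# is a hypothetical training step.
#
#     timer = AvgTimer()
#     for _ in range(100):
#         train_one_iter()
#         timer.record()
#         print(timer.get_current_time(), timer.get_avg_time())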
class MessageLogger():
"""Message logger for printing.
Args:
opt (dict): Config. It contains the following keys:
name (str): Exp name.
logger (dict): Contains 'print_freq' (str) for logger interval.
train (dict): Contains 'total_iter' (int) for total iters.
use_tb_logger (bool): Use tensorboard logger.
start_iter (int): Start iter. Default: 1.
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
"""
def __init__(self, opt, start_iter=1, tb_logger=None):
self.exp_name = opt['name']
self.interval = opt['logger']['print_freq']
self.start_iter = start_iter
self.max_iters = opt['train']['total_iter']
self.use_tb_logger = opt['logger']['use_tb_logger']
self.tb_logger = tb_logger
self.start_time = time.time()
self.logger = get_root_logger()
def reset_start_time(self):
self.start_time = time.time()
@master_only
def __call__(self, log_vars):
"""Format logging message.
Args:
log_vars (dict): It contains the following keys:
epoch (int): Epoch number.
iter (int): Current iter.
lrs (list): List for learning rates.
time (float): Iter time.
data_time (float): Data time for each iter.
"""
# epoch, iter, learning rates
epoch = log_vars.pop('epoch')
current_iter = log_vars.pop('iter')
lrs = log_vars.pop('lrs')
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
for v in lrs:
message += f'{v:.3e},'
message += ')] '
# time and estimated time
if 'time' in log_vars.keys():
iter_time = log_vars.pop('time')
data_time = log_vars.pop('data_time')
total_time = time.time() - self.start_time
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
message += f'[eta: {eta_str}, '
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
# other items, especially losses
for k, v in log_vars.items():
message += f'{k}: {v:.4e} '
# tensorboard logger
if self.use_tb_logger and 'debug' not in self.exp_name:
if k.startswith('l_'):
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
else:
self.tb_logger.add_scalar(k, v, current_iter)
self.logger.info(message)
@master_only
def init_tb_logger(log_dir):
# from torch.utils.tensorboard import SummaryWriter # face-aml-b cluster use torch.utils.tensorboard
from tensorboardX import SummaryWriter # rr1 cluster use tensorboardX
tb_logger = SummaryWriter(log_dir=log_dir)
return tb_logger
@master_only
def init_wandb_logger(opt):
"""We now only use wandb to sync tensorboard log."""
import wandb
logger = get_root_logger()
project = opt['logger']['wandb']['project']
resume_id = opt['logger']['wandb'].get('resume_id')
if resume_id:
wandb_id = resume_id
resume = 'allow'
logger.warning(f'Resume wandb logger with id={wandb_id}.')
else:
wandb_id = wandb.util.generate_id()
resume = 'never'
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added.
Args:
logger_name (str): root logger name. Default: 'basicsr'.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = logging.getLogger(logger_name)
# if the logger has been initialized, just return it
if logger_name in initialized_logger:
return logger
format_str = '%(asctime)s %(levelname)s: %(message)s'
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(format_str))
logger.addHandler(stream_handler)
logger.propagate = False
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
elif log_file is not None:
logger.setLevel(log_level)
# add file handler
file_handler = logging.FileHandler(log_file, 'w')
file_handler.setFormatter(logging.Formatter(format_str))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)
initialized_logger[logger_name] = True
return logger
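# Usage sketch (illustrative): create the root logger once with a log file on rank 0,
# then fetch it anywhere with get_root_logger(). The log path is a placeholder.
#
#     logger = get_root_logger(log_file='experiments/example/train.log')
#     logger.info('Training started.')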
def get_env_info():
"""Get environment information.
Currently, only log the software version.
"""
import torch
import torchvision
from basicsr.version import __version__
msg = r"""
"""
msg += ('\nVersion Information: '
f'\n\tBasicSR: {__version__}'
f'\n\tPyTorch: {torch.__version__}'
f'\n\tTorchVision: {torchvision.__version__}')
return msg
import math
import numpy as np
import torch
def cubic(x):
"""cubic function used for calculate_weights_indices."""
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
(absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
"""Calculate weights and indices, used for imresize function.
Args:
in_length (int): Input length.
out_length (int): Output length.
scale (float): Scale factor.
        kernel (str): Kernel type. Only 'cubic' is used here.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
"""
if (scale < 1) and antialiasing:
# Use a modified kernel (larger kernel width) to simultaneously
# interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5 + scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
p = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
out_length, p)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
# apply cubic kernel
if (scale < 1) and antialiasing:
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, p)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, p - 2)
weights = weights.narrow(1, 1, p - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, p - 2)
weights = weights.narrow(1, 0, p - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
"""imresize function same as MATLAB.
It now only supports bicubic.
The same scale applies for both height and width.
Args:
img (Tensor | Numpy array):
Tensor: Input image with shape (c, h, w), [0, 1] range.
Numpy: Input image with shape (h, w, c), [0, 1] range.
scale (float): Scale factor. The same scale applies for both height
and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
Default: True.
Returns:
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
"""
squeeze_flag = False
if type(img).__module__ == np.__name__: # numpy type
numpy_type = True
if img.ndim == 2:
img = img[:, :, None]
squeeze_flag = True
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
else:
numpy_type = False
if img.ndim == 2:
img = img.unsqueeze(0)
squeeze_flag = True
in_c, in_h, in_w = img.size()
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
kernel_width = 4
kernel = 'cubic'
# get weights and indices
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
antialiasing)
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
sym_patch = img[:, :sym_len_hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_he:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_c, out_h, in_w)
kernel_width = weights_h.size(1)
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_we:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_c, out_h, out_w)
kernel_width = weights_w.size(1)
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
if squeeze_flag:
out_2 = out_2.squeeze(0)
if numpy_type:
out_2 = out_2.numpy()
if not squeeze_flag:
out_2 = out_2.transpose(1, 2, 0)
return out_2
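# Usage sketch (illustrative): MATLAB-style bicubic rescaling for numpy (h, w, c)
# images or torch (c, h, w) tensors in [0, 1].
#
#     img_np = np.random.rand(128, 128, 3).astype(np.float32)
#     img_down = imresize(img_np, 1 / 4)          # (32, 32, 3) numpy array
#     img_t = torch.rand(3, 64, 64)
#     img_up = imresize(img_t, 2)                 # (3, 128, 128) tensor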
import numpy as np
import os
import random
import time
import torch
from os import path as osp
import shutil
from .dist_util import master_only
def set_random_seed(seed):
"""Set random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def mkdir_and_rename(path):
"""mkdirs. If path exists, rename it with timestamp and create a new one.
Args:
path (str): Folder path.
"""
if osp.exists(path):
new_name = path + '_archived_' + get_time_str()
print(f'Path already exists. Rename it to {new_name}', flush=True)
# os.rename(path, new_name)
shutil.move(path, new_name)
os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
"""Make dirs for experiments."""
path_opt = opt['path'].copy()
if opt['is_train']:
mkdir_and_rename(path_opt.pop('experiments_root'))
else:
mkdir_and_rename(path_opt.pop('results_root'))
for key, path in path_opt.items():
if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
continue
else:
os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
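# Usage sketch (illustrative): collect all png files under a dataset folder.
# The folder path is a placeholder.
#
#     paths = sorted(scandir('datasets/DIV2K/DIV2K_train_HR_sub', suffix='.png', recursive=True))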
def check_resume(opt, resume_iter):
"""Check resume states and pretrain_network paths.
Args:
opt (dict): Options.
resume_iter (int): Resume iteration.
"""
if opt['path']['resume_state']:
# get all the networks
networks = [key for key in opt.keys() if key.startswith('network_')]
flag_pretrain = False
for network in networks:
if opt['path'].get(f'pretrain_{network}') is not None:
flag_pretrain = True
if flag_pretrain:
print('pretrain_network path will be ignored during resuming.')
# set pretrained model paths
for network in networks:
name = f'pretrain_{network}'
basename = network.replace('network_', '')
if opt['path'].get('ignore_resume_networks') is None or (network
not in opt['path']['ignore_resume_networks']):
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
print(f"Set {name} to {opt['path'][name]}")
# change param_key to params in resume
param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
for param_key in param_keys:
if opt['path'][param_key] == 'params_ema':
opt['path'][param_key] = 'params'
print(f'Set {param_key} to params')
def sizeof_fmt(size, suffix='B'):
"""Get human readable file size.
Args:
size (int): File size.
suffix (str): Suffix. Default: 'B'.
Return:
        str: Formatted file size.
"""
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(size) < 1024.0:
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'
import argparse
import random
import torch
import yaml
from collections import OrderedDict
from os import path as osp
from basicsr.utils import set_random_seed
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
def ordered_yaml():
"""Support OrderedDict for yaml.
Returns:
yaml Loader and Dumper.
"""
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Dumper, Loader
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
def dict_constructor(loader, node):
return OrderedDict(loader.construct_pairs(node))
Dumper.add_representer(OrderedDict, dict_representer)
Loader.add_constructor(_mapping_tag, dict_constructor)
return Loader, Dumper
def dict2str(opt, indent_level=1):
"""dict to string for printing options.
Args:
opt (dict): Option dict.
indent_level (int): Indent level. Default: 1.
Return:
(str): Option string for printing.
"""
msg = '\n'
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_level * 2) + k + ':['
msg += dict2str(v, indent_level + 1)
msg += ' ' * (indent_level * 2) + ']\n'
else:
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
return msg
def _postprocess_yml_value(value):
# None
if value == '~' or value.lower() == 'none':
return None
# bool
if value.lower() == 'true':
return True
elif value.lower() == 'false':
return False
# !!float number
if value.startswith('!!float'):
return float(value.replace('!!float', ''))
# number
if value.isdigit():
return int(value)
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
return float(value)
# list
if value.startswith('['):
return eval(value)
# str
return value
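# Behaviour sketch (illustrative) for _postprocess_yml_value, as used by the
# --force_yml option of parse_options below:
#
#     _postprocess_yml_value('~')              # -> None
#     _postprocess_yml_value('true')           # -> True
#     _postprocess_yml_value('!!float 1e-4')   # -> 0.0001
#     _postprocess_yml_value('128')            # -> 128
#     _postprocess_yml_value('[1, 2, 3]')      # -> [1, 2, 3]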
def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
args = parser.parse_args()
# parse yml to dict
with open(args.opt, mode='r') as f:
opt = yaml.load(f, Loader=ordered_yaml()[0])
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
# force to update yml options
if args.force_yml is not None:
for entry in args.force_yml:
# now do not support creating new keys
keys, value = entry.split('=')
keys, value = keys.strip(), value.strip()
value = _postprocess_yml_value(value)
eval_str = 'opt'
for key in keys.split(':'):
eval_str += f'["{key}"]'
eval_str += '=value'
# using exec function
exec(eval_str)
opt['auto_resume'] = args.auto_resume
opt['is_train'] = is_train
# debug setting
if args.debug and not opt['name'].startswith('debug'):
opt['name'] = 'debug_' + opt['name']
if opt['num_gpu'] == 'auto':
opt['num_gpu'] = torch.cuda.device_count()
# datasets
for phase, dataset in opt['datasets'].items():
# for multiple datasets, e.g., val_1, val_2; test_1, test_2
phase = phase.split('_')[0]
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale']
if dataset.get('dataroot_gt') is not None:
dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
if dataset.get('dataroot_lq') is not None:
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
# paths
for key, val in opt['path'].items():
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
opt['path'][key] = osp.expanduser(val)
if is_train:
experiments_root = osp.join(root_path, 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
opt['path']['log'] = experiments_root
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
# change some options for debug mode
if 'debug' in opt['name']:
if 'val' in opt:
opt['val']['val_freq'] = 8
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
else: # test
results_root = osp.join(root_path, 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
opt['path']['visualization'] = osp.join(results_root, 'visualization')
return opt, args
@master_only
def copy_opt_file(opt_file, experiments_root):
# copy the yml file to the experiment root
import sys
import time
from shutil import copyfile
cmd = ' '.join(sys.argv)
filename = osp.join(experiments_root, osp.basename(opt_file))
copyfile(opt_file, filename)
with open(filename, 'r+') as f:
lines = f.readlines()
lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
f.seek(0)
f.writelines(lines)
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {}
def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj
def register(self, obj=None, suffix=None):
"""
        Register the given object under the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class, suffix)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj, suffix)
def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
def __contains__(self, name):
return name in self._obj_map
def __iter__(self):
return iter(self._obj_map.items())
def keys(self):
return self._obj_map.keys()
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
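# Usage sketch (illustrative): registering a new architecture and retrieving it by
# name. MyArch is a hypothetical class.
#
#     @ARCH_REGISTRY.register()
#     class MyArch(torch.nn.Module):
#         ...
#
#     arch_cls = ARCH_REGISTRY.get('MyArch')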
# Colab
<a href="https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
To keep the BasicSR repo small, we do not include the original colab notebooks in this repo, but provide links to Google Colab.
| Face Restoration| |
| :--- | :---: |
|DFDNet | [BasicSR_inference_DFDNet.ipynb](https://colab.research.google.com/drive/1RoNDeipp9yPjI3EbpEbUhn66k5Uzg4n8?usp=sharing)|
| **Super-Resolution**| |
|ESRGAN |[BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing)|
| **Deblurring**| |
| **Denoise**| |
# Configuration
[English](Config.md) **|** [简体中文](Config_CN.md)
#### Contents
1. [Experiment Name Convention](#Experiment-Name-Convention)
1. [Configuration Explanation](#Configuration-Explanation)
1. [Training Configuration](#Training-Configuration)
1. [Testing Configuration](#Testing-Configuration)
## Experiment Name Convention
Taking `001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb` as an example:
- `001`: We usually use index for managing experiments
- `MSRResNet`: Model name, here is Modified SRResNet
- `x4_f64b16`: Important configuration parameters. The upsampling ratio is 4; the channel number of middle features is 64; and 16 residual blocks are used
- `DIV2K`: Training data is DIV2K
- `1000k`: Total training iteration is 1000k
- `B16G1`: Batch size is 16; one GPU is used for training
- `wandb`: Use wandb logger; the training process has been uploaded to the wandb server
**Note**: If `debug` is in the experiment name, it will enter the debug mode. That is, the program will log and validate more intensively and will not use `tensorboard logger` and `wandb logger`.
## Configuration Explanation
We use yaml files for configuration.
### Training Configuration
Taking [train_MSRResNet_x4.yml](../options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml) as an example:
```yml
####################################
# The following are general settings
####################################
# Experiment name, more details are in [Experiment Name Convention]. If debug is in the experiment name, it will enter debug mode
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb
# Model type. Usually the class name defined in the `models` folder
model_type: SRModel
# The scale of the output over the input. In SR, it is the upsampling ratio. If not defined, use 1
scale: 4
# The number of GPUs for training
num_gpu: 1 # set num_gpu: 0 for cpu mode
# Random seed
manual_seed: 0
########################################################
# The following are the dataset and data loader settings
########################################################
datasets:
# Training dataset settings
train:
# Dataset name
name: DIV2K
# Dataset type. Usually the class name defined in the `data` folder
type: PairedImageDataset
#### The following arguments are flexible and can be obtained in the corresponding doc
# GT (Ground-Truth) folder path
dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
# LQ (Low-Quality) folder path
dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    # Template for file names. Usually, LQ files have a suffix like `_x4`. It is used to handle file name mismatches
filename_tmpl: '{}'
# IO backend, more details are in [docs/DatasetPreparation.md]
io_backend:
# directly read from disk
type: disk
# Ground-Truth training patch size
gt_size: 128
# Whether to use horizontal flip. Here, flip is for horizontal flip
use_hflip: true
# Whether to rotate. Here for rotations with every 90 degree
use_rot: true
#### The following are data loader settings
# Whether to shuffle
use_shuffle: true
# Number of workers of reading data for each GPU
num_worker_per_gpu: 6
# Total training batch size
batch_size_per_gpu: 16
    # The ratio for enlarging the dataset. For example, a dataset with 15 images will be repeated 100 times,
    # so that 1500 images are read in one epoch. It is used to accelerate the data loader,
    # since it costs too much time at the start of a new epoch
dataset_enlarge_ratio: 100
# validation dataset settings
val:
# Dataset name
name: Set5
# Dataset type. Usually the class name defined in the `data` folder
type: PairedImageDataset
#### The following arguments are flexible and can be obtained in the corresponding doc
# GT (Ground-Truth) folder path
dataroot_gt: datasets/Set5/GTmod12
# LQ (Low-Quality) folder path
dataroot_lq: datasets/Set5/LRbicx4
# IO backend, more details are in [docs/DatasetPreparation.md]
io_backend:
# directly read from disk
type: disk
##################################################
# The following are the network structure settings
##################################################
# network g settings
network_g:
# Architecture type. Usually the class name defined in the `basicsr/archs` folder
type: MSRResNet
#### The following arguments are flexible and can be obtained in the corresponding doc
# Channel number of inputs
num_in_ch: 3
# Channel number of outputs
num_out_ch: 3
# Channel number of middle features
num_feat: 64
# block number
num_block: 16
# upsampling ratio
upscale: 4
#########################################################
# The following are path, pretraining and resume settings
#########################################################
path:
# Path for pretrained models, usually end with pth
pretrain_network_g: ~
# Whether to load pretrained models strictly, that is the corresponding parameter names should be the same
strict_load_g: true
# Path for resume state. Usually in the `experiments/exp_name/training_states` folder
# This argument will over-write the `pretrain_network_g`
resume_state: ~
#####################################
# The following are training settings
#####################################
train:
# Optimizer settings
optim_g:
# Optimizer type
type: Adam
#### The following arguments are flexible and can be obtained in the corresponding doc
# Learning rate
lr: !!float 2e-4
weight_decay: 0
# beta1 and beta2 for the Adam
betas: [0.9, 0.99]
# Learning rate scheduler settings
scheduler:
# Scheduler type
type: CosineAnnealingRestartLR
#### The following arguments are flexible and can be obtained in the corresponding doc
# Cosine Annealing periods
periods: [250000, 250000, 250000, 250000]
# Cosine Annealing restart weights
restart_weights: [1, 1, 1, 1]
# Cosine Annealing minimum learning rate
eta_min: !!float 1e-7
# Total iterations for training
total_iter: 1000000
# Warm up iterations. -1 indicates no warm up
warmup_iter: -1
#### The following are loss settings
# Pixel-wise loss options
pixel_opt:
# Loss type. Usually the class name defined in the `basicsr/models/losses` folder
type: L1Loss
# Loss weight
loss_weight: 1.0
# Loss reduction mode
reduction: mean
#######################################
# The following are validation settings
#######################################
val:
# validation frequency. Validate every 5000 iterations
val_freq: !!float 5e3
# Whether to save images during validation
save_img: false
# Metrics in validation
metrics:
# Metric name. It can be arbitrary
psnr:
# Metric type. Usually the function name defined in the`basicsr/metrics` folder
type: calculate_psnr
#### The following arguments are flexible and can be obtained in the corresponding doc
# Whether to crop border during validation
crop_border: 4
# Whether to convert to Y(CbCr) for validation
test_y_channel: false
########################################
# The following are the logging settings
########################################
logger:
# Logger frequency
print_freq: 100
# The frequency for saving checkpoints
save_checkpoint_freq: !!float 5e3
  # Whether to use tensorboard logger
use_tb_logger: true
  # Whether to use wandb logger. Currently, wandb only syncs the tensorboard log. So we should also turn on tensorboard when using wandb
wandb:
# wandb project name. Default is None, that is not using wandb.
# Here, we use the basicsr wandb project: https://app.wandb.ai/xintao/basicsr
project: basicsr
# If resuming, wandb id could automatically link previous logs
resume_id: ~
################################################
# The following are distributed training setting
# Only required for slurm training
################################################
dist_params:
backend: nccl
port: 29500
```
### Testing Configuration
Taking [test_MSRResNet_x4.yml](../options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml) as an example:
```yml
# Experiment name
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb
# Model type. Usually the class name defined in the `models` folder
model_type: SRModel
# The scale of the output over the input. In SR, it is the upsampling ratio. If not defined, use 1
scale: 4
# The number of GPUs for testing
num_gpu: 1 # set num_gpu: 0 for cpu mode
########################################################
# The following are the dataset and data loader settings
########################################################
datasets:
# Testing dataset settings. The first testing dataset
test_1:
# Dataset name
name: Set5
# Dataset type. Usually the class name defined in the `data` folder
type: PairedImageDataset
#### The following arguments are flexible and can be obtained in the corresponding doc
# GT (Ground-Truth) folder path
dataroot_gt: datasets/Set5/GTmod12
# LQ (Low-Quality) folder path
dataroot_lq: datasets/Set5/LRbicx4
# IO backend, more details are in [docs/DatasetPreparation.md]
io_backend:
# directly read from disk
type: disk
# Testing dataset settings. The second testing dataset
test_2:
name: Set14
type: PairedImageDataset
dataroot_gt: datasets/Set14/GTmod12
dataroot_lq: datasets/Set14/LRbicx4
io_backend:
type: disk
# Testing dataset settings. The third testing dataset
test_3:
name: DIV2K100
type: PairedImageDataset
dataroot_gt: datasets/DIV2K/DIV2K_valid_HR
dataroot_lq: datasets/DIV2K/DIV2K_valid_LR_bicubic/X4
filename_tmpl: '{}x4'
io_backend:
type: disk
##################################################
# The following are the network structure settings
##################################################
# network g settings
network_g:
# Architecture type. Usually the class name defined in the `basicsr/archs` folder
type: MSRResNet
#### The following arguments are flexible and can be obtained in the corresponding doc
# Channel number of inputs
num_in_ch: 3
# Channel number of outputs
num_out_ch: 3
# Channel number of middle features
num_feat: 64
# block number
num_block: 16
# upsampling ratio
  upscale: 4
#################################################
# The following are path and pretraining settings
#################################################
path:
# Path for pretrained models, usually ending with pth
pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
# Whether to load pretrained models strictly, i.e., the parameter names must match exactly
strict_load_g: true
##########################################################
# The following are validation settings (Also for testing)
##########################################################
val:
# Whether to save images during validation
save_img: true
# Suffix for saved images. If None, use exp name
suffix: ~
# Metrics in validation
metrics:
# Metric name. It can be arbitrary
psnr:
# Metric type. Usually the function name defined in the `basicsr/metrics` folder
type: calculate_psnr
#### The following arguments are flexible and can be obtained in the corresponding doc
# Whether to crop border during validation
crop_border: 4
# Whether to convert to Y(CbCr) for validation
test_y_channel: false
# Another metric
ssim:
type: calculate_ssim
crop_border: 4
test_y_channel: false
```
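The metric entries above (`calculate_psnr` with `crop_border` and `test_y_channel`) behave roughly as in the following minimal sketch of a PSNR computation. This is only an illustration, not the exact `basicsr/metrics` implementation; when `test_y_channel: true`, the images are first converted to the Y channel of YCbCr, which the sketch omits:
```python
import numpy as np

def psnr(img1, img2, crop_border=4):
    """Compute PSNR between two uint8 images (H, W, C), cropping a border first."""
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    if crop_border > 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20.0 * np.log10(255.0 / np.sqrt(mse))
```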
# Dataset Preparation
[English](DatasetPreparation.md) **|** [简体中文](DatasetPreparation_CN.md)
#### Contents
1. [Data Storage Format](#Data-Storage-Format)
1. [How to Use](#How-to-Use)
1. [How to Implement](#How-to-Implement)
1. [LMDB Description](#LMDB-Description)
1. [Data Pre-fetcher](#Data-Pre-fetcher)
1. [Image Super-Resolution](#Image-Super-Resolution)
1. [DIV2K](#DIV2K)
1. [Common Image SR Datasets](#Common-Image-SR-Datasets)
1. [Video Super-Resolution](#Video-Super-Resolution)
1. [REDS](#REDS)
1. [Vimeo90K](#Vimeo90K)
1. [StyleGAN2](#StyleGAN2)
1. [FFHQ](#FFHQ)
## Data Storage Format
At present, there are three types of data storage formats supported:
1. Store in `hard disk` directly in the format of images / video frames.
1. Make [LMDB](https://lmdb.readthedocs.io/en/release/), which could accelerate the IO and decompression speed during training.
1. [memcached](https://memcached.org/) is also supported, if it is installed (usually on clusters).
#### How to Use
At present, we can modify the configuration yaml file to support different data storage formats. Taking [PairedImageDataset](../basicsr/data/paired_image_dataset.py) as an example, we can modify the yaml file according to different requirements.
1. Directly read disk data.
```yaml
type: PairedImageDataset
dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
io_backend:
type: disk
```
1. Use LMDB.
We need to make LMDB before using it. Please refer to [LMDB description](#LMDB-Description). Note that we add meta information to the original LMDB, and the specific binary contents are also different. Therefore, LMDB files from other sources cannot be used directly.
```yaml
type: PairedImageDataset
dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
io_backend:
type: lmdb
```
1. Use Memcached
Your machine/cluster must support memcached before using it. The configuration file should be modified accordingly.
```yaml
type: PairedImageDataset
dataroot_gt: datasets/DIV2K_train_HR_sub
dataroot_lq: datasets/DIV2K_train_LR_bicubicX4_sub
io_backend:
type: memcached
server_list_cfg: /mnt/lustre/share/memcached_client/server_list.conf
client_cfg: /mnt/lustre/share/memcached_client/client.conf
sys_path: /mnt/lustre/share/pymc/py3
```
#### How to Implement
The implementation is to call the elegant fileclient design in [mmcv](https://github.com/open-mmlab/mmcv). In order to be compatible with BasicSR, we have made some changes to the interface (mainly to adapt to LMDB). See [file_client.py](../basicsr/utils/file_client.py) for details.
When we implement our own dataloader, we can easily call the interfaces to support different data storage forms. Please refer to [PairedImageDataset](../basicsr/data/paired_image_dataset.py) for more details.
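For illustration, a dataloader typically asks the file client for raw bytes and then decodes them with cv2. A rough sketch follows; details of the actual file client interface may differ:
```python
import cv2
import numpy as np

def bytes_to_img(content):
    """Decode raw image bytes (as returned by a storage backend) into a float32 BGR image."""
    img = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
    return img.astype(np.float32) / 255.0

# Inside a dataset's __getitem__ (sketch only):
#   content = self.file_client.get(gt_path)   # bytes from disk / lmdb / memcached
#   img_gt = bytes_to_img(content)
```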
#### LMDB Description
During training, we use LMDB to speed up the IO and CPU decompression. (During testing, usually the data is limited and it is generally not necessary to use LMDB). The acceleration depends on the configurations of the machine, and the following factors will affect the speed:
1. Some machines clean the cache regularly, and LMDB depends on the cache mechanism. Therefore, if the data fails to stay cached, you need to check it. With the command `free -h`, the cache occupied by LMDB is shown under the `buff/cache` entry.
1. Whether the machine's memory is large enough to hold the whole LMDB data. If not, the speed suffers because the cache needs to be constantly refreshed.
1. Caching the LMDB dataset for the first time may slow down training. So before training, you can enter the LMDB dataset directory and warm up the cache with: `cat data.mdb > /dev/null`.
In addition to the standard LMDB file (data.mdb and lock.mdb), we also add `meta_info.txt` to record additional information.
Here is an example:
**Folder Structure**
```txt
DIV2K_train_HR_sub.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
```
**meta information**
`meta_info.txt`: we use a plain txt file for readability. The contents are:
```txt
0001_s001.png (480,480,3) 1
0001_s002.png (480,480,3) 1
0001_s003.png (480,480,3) 1
0001_s004.png (480,480,3) 1
...
```
Each line records an image with three fields, which indicate:
- Image name (with suffix): 0001_s001.png
- Image size: (480,480,3) represents a 480x480x3 image
- Other parameters (BasicSR uses the cv2 PNG compression level): in restoration tasks we usually store PNG, so `1` means the PNG compression level `CV_IMWRITE_PNG_COMPRESSION` is 1. It can be an integer in [0, 9]; a larger value means stronger compression, i.e., smaller storage space but longer compression time.
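A line of `meta_info.txt` can be parsed back into these three fields with a few lines of Python (a sketch, assuming exactly the format shown above):
```python
def parse_meta_line(line):
    """Parse a meta_info.txt line like '0001_s001.png (480,480,3) 1'."""
    name, shape_str, level = line.strip().split()
    h, w, c = map(int, shape_str.strip('()').split(','))
    return name, (h, w, c), int(level)

print(parse_meta_line('0001_s001.png (480,480,3) 1'))
# ('0001_s001.png', (480, 480, 3), 1)
```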
**Binary Content**
For convenience, the binary content stored in the LMDB dataset is the image encoded by cv2: `cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])`. You can control the compression level via `compress_level`, balancing storage space and reading speed (including decompression).
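In other words, writing and reading an LMDB value is a plain cv2 encode/decode round trip. A minimal sketch (the image path is illustrative):
```python
import cv2
import numpy as np

img = cv2.imread('0001_s001.png', cv2.IMREAD_COLOR)  # path is illustrative
# Encode to PNG bytes with a chosen compression level (0-9); these bytes are what gets stored in LMDB.
ok, buf = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
value = buf.tobytes()

# Decoding the stored bytes recovers the original image losslessly (PNG).
img_back = cv2.imdecode(np.frombuffer(value, np.uint8), cv2.IMREAD_COLOR)
assert (img_back == img).all()
```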
**How to Make LMDB**
We provide a script to make LMDB. Before running the script, we need to modify the corresponding parameters accordingly. At present, we support DIV2K, REDS and Vimeo90K datasets; other datasets can also be made in a similar way.<br>
`python scripts/data_preparation/create_lmdb.py`
#### Data Pre-fetcher
Apart from using LMDB for speed-up, we could use a data pre-fetcher. Please refer to [prefetch_dataloader](../basicsr/data/prefetch_dataloader.py) for the implementation.<br>
It can be enabled by setting `prefetch_mode` in the configuration file. Currently, three modes are provided (a simplified CPU-prefetcher sketch follows this list):
1. None. No data pre-fetcher is used by default. If you already use LMDB or the IO is fast enough, you can keep it as None.
```yml
prefetch_mode: ~
```
1. `prefetch_mode: cuda`. Use the CUDA prefetcher. Please see [NVIDIA/apex](https://github.com/NVIDIA/apex/issues/304#) for more details. It will occupy more GPU memory. Note that in this mode, you must also set `pin_memory=True`.
```yml
prefetch_mode: cuda
pin_memory: true
```
1. `prefetch_mode: cpu`. Use the CPU prefetcher. Please see [IgorSusmelj/pytorch-styleguide](https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#) for more details. (In our tests, this mode did not bring noticeable acceleration.)
```yml
prefetch_mode: cpu
num_prefetch_queue: 1 # 1 by default
```
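Conceptually, the CPU prefetcher just keeps a small queue of batches filled by a background thread while the GPU consumes them. A simplified sketch of the idea (not the basicsr implementation):
```python
import queue
import threading

class SimpleCPUPrefetcher:
    """Wrap a dataloader and pre-load up to `num_prefetch_queue` batches in a background thread."""

    def __init__(self, loader, num_prefetch_queue=1):
        self.queue = queue.Queue(maxsize=num_prefetch_queue)
        self.loader = iter(loader)
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        for batch in self.loader:
            self.queue.put(batch)
        self.queue.put(None)  # sentinel: no more batches

    def __iter__(self):
        while True:
            batch = self.queue.get()
            if batch is None:
                break
            yield batch
```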
## Image Super-Resolution
It is recommended to symlink the dataset root to `datasets` with the command `ln -s xxx yyy`. If your folder structure is different, you may need to change the corresponding paths in config files.
### DIV2K
[DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) is a widely-used dataset in image super-resolution. In many research works, a MATLAB bicubic downsampling kernel is assumed. It may not be practical, because the MATLAB bicubic downsampling kernel is not a good approximation of the implicit degradation kernels in real-world scenarios; the topic of *blind restoration* deals with this gap.
**Preparation Steps**
1. Download the datasets from the [official DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/).<br>
1. Crop to sub-images: DIV2K provides 2K-resolution images (e.g., 2048x1080), but the training patches are usually small (e.g., 128x128 or 192x192), so it is wasteful to read a whole image and use only a very small part of it. In order to accelerate the IO during training, we crop the 2K images into sub-images (here, 480x480 sub-images); a cropping sketch is given after this list. <br>
Note that the sub-image size is different from the training patch size (`gt_size`) defined in the config file. Specifically, the 480x480 sub-images are stored on disk, and the dataloader further randomly crops them into `gt_size x gt_size` patches for training. <br/>
Run the script [extract_subimages.py](../scripts/data_preparation/extract_subimages.py):
```bash
python scripts/data_preparation/extract_subimages.py
```
Remember to modify the paths and configurations if you have different settings.
1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_paired_image_dataset.py`.
Remember to modify the paths and configurations accordingly.
1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/data_preparation/generate_meta_info.py` to generate the meta_info_file.
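To make the sub-image step above concrete, the sketch below crops one 2K image into overlapping 480x480 tiles with a fixed stride. The real `extract_subimages.py` also handles border positions, naming and multiprocessing; the values here are illustrative:
```python
import cv2

def crop_sub_images(img, crop_size=480, step=240):
    """Crop an image into overlapping crop_size x crop_size tiles with stride `step`.

    Only the regular grid is kept here for brevity; the real script also covers
    the last row/column so no border pixels are lost.
    """
    h, w = img.shape[:2]
    subs = []
    for y in range(0, h - crop_size + 1, step):
        for x in range(0, w - crop_size + 1, step):
            subs.append(img[y:y + crop_size, x:x + crop_size, ...])
    return subs

img = cv2.imread('datasets/DIV2K/DIV2K_train_HR/0001.png')  # path is illustrative
print(len(crop_sub_images(img)))
```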
### Common Image SR Datasets
We provide a list of common image super-resolution datasets.
<table>
<tr>
<th>Name</th>
<th>Datasets</th>
<th>Short Description</th>
<th>Download</th>
</tr>
<tr>
<td rowspan="3">Classical SR Training</td>
<td>T91</td>
<td><sub>91 images for training</sub></td>
<td rowspan="9"><a href="https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing">Google Drive</a> / <a href="https://pan.baidu.com/s/1q_1ERCMqALH0xFwjLM0pTg">Baidu Drive</a></td>
</tr>
<tr>
<td><a href="https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/">BSDS200</a></td>
<td><sub>A subset (train) of BSD500 for training</sub></td>
</tr>
<tr>
<td><a href="http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html">General100</a></td>
<td><sub>100 images for training</sub></td>
</tr>
<tr>
<td rowspan="6">Classical SR Testing</td>
<td>Set5</td>
<td><sub>Set5 test dataset</sub></td>
</tr>
<tr>
<td>Set14</td>
<td><sub>Set14 test dataset</sub></td>
</tr>
<tr>
<td><a href="https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/">BSDS100</a></td>
<td><sub>A subset (test) of BSD500 for testing</sub></td>
</tr>
<tr>
<td><a href="https://sites.google.com/site/jbhuang0604/publications/struct_sr">urban100</a></td>
<td><sub>100 building images for testing (regular structures)</sub></td>
</tr>
<tr>
<td><a href="http://www.manga109.org/en/">manga109</a></td>
<td><sub>109 images of Japanese manga for testing</sub></td>
</tr>
<tr>
<td>historical</td>
<td><sub>10 gray low-resolution images without the ground-truth</sub></td>
</tr>
<tr>
<td rowspan="3">2K Resolution</td>
<td><a href="https://data.vision.ee.ethz.ch/cvl/DIV2K/">DIV2K</a></td>
<td><sub>proposed in <a href="http://www.vision.ee.ethz.ch/ntire17/">NTIRE17</a> (800 train and 100 validation)</sub></td>
<td><a href="https://data.vision.ee.ethz.ch/cvl/DIV2K/">official website</a></td>
</tr>
<tr>
<td><a href="https://github.com/LimBee/NTIRE2017">Flickr2K</a></td>
<td><sub>2650 2K images from Flickr for training</sub></td>
<td><a href="https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar">official website</a></td>
</tr>
<tr>
<td>DF2K</td>
<td><sub>A merged training dataset of DIV2K and Flickr2K</sub></td>
<td>-</a></td>
</tr>
<tr>
<td rowspan="2">OST (Outdoor Scenes)</td>
<td>OST Training</td>
<td><sub>7 categories images with rich textures</sub></td>
<td rowspan="2"><a href="https://drive.google.com/drive/u/1/folders/1iZfzAxAwOpeutz27HC56_y5RNqnsPPKr">Google Drive</a> / <a href="https://pan.baidu.com/s/1neUq5tZ4yTnOEAntZpK_rQ#list/path=%2Fpublic%2FSFTGAN&parentPath=%2Fpublic">Baidu Drive</a></td>
</tr>
<tr>
<td>OST300</td>
<td><sub>300 test images of outdoor scenes</sub></td>
</tr>
<tr>
<td >PIRM</td>
<td>PIRM</td>
<td><sub>PIRM self-val, val, test datasets</sub></td>
<td rowspan="2"><a href="https://drive.google.com/drive/folders/17FmdXu5t8wlKwt8extb_nQAdjxUOrb1O?usp=sharing">Google Drive</a> / <a href="https://pan.baidu.com/s/1gYv4tSJk_RVCbCq4B6UxNQ">Baidu Drive</a></td>
</tr>
</table>
## Video Super-Resolution
It is recommended to symlink the dataset root to `datasets` with the command `ln -s xxx yyy`. If your folder structure is different, you may need to change the corresponding paths in config files.
### REDS
[Official website](https://seungjunnah.github.io/Datasets/reds.html).<br>
We regroup the training and validation datasets into one folder. The original training dataset has 240 clips (000 to 239), and we rename the validation clips to 240 through 269.
**Validation Partition**
The official validation partition and that used in EDVR for competition are different:
| name | clips | total number |
|:----------:|:----------:|:----------:|
| REDSOfficial | [240, 269] | 30 clips |
| REDS4 | 000, 011, 015, 020 clips from the *original training set* | 4 clips |
All the remaining clips are used for training. Note that it is not required to explicitly separate the training and validation datasets; the dataloader does that.
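For example, with the REDS4 partition the training list is simply all 270 regrouped clips minus the four validation clips; a tiny sketch:
```python
# REDS4 validation clips (taken from the original training set); the rest are used for training.
reds4_val = {'000', '011', '015', '020'}
all_clips = {f'{i:03d}' for i in range(270)}  # 270 clips after regrouping train + val
train_clips = sorted(all_clips - reds4_val)
print(len(train_clips))  # 266
```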
**Preparation Steps**
1. Download the datasets from the [official website](https://seungjunnah.github.io/Datasets/reds.html).
1. Regroup the training and validation datasets: `python scripts/data_preparation/regroup_reds_dataset.py`
1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_reds_dataset.py`.
Remember to modify the paths and configurations accordingly.
### Vimeo90K
[Official webpage](http://toflow.csail.mit.edu/)
1. Download the dataset: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip). This is the Ground-Truth (GT). The downloaded zip file contains a `sep_trainlist.txt` file listing the training samples.
1. Generate the low-resolution images (TODO)
The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images.
1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_vimeo90k_dataset.py`.
Remember to modify the paths and configurations accordingly.
## StyleGAN2
### FFHQ
Training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset).
1. Download the FFHQ dataset. We recommend downloading the tfrecords files from [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset).
1. Extract the tfrecords to images or LMDBs (TensorFlow is required to read tfrecords). For each resolution, a separate image folder or LMDB file is created.
```bash
python scripts/data_preparation/extract_images_from_tfrecords.py
```
# Datasets
[English](Datasets.md) **|** [简体中文](Datasets_CN.md)
## Supported Datasets
| Class | Task |Train/Test | Description |
| :------------- | :----------:| :----------: | :----------: |
| [PairedImageDataset](../basicsr/data/paired_image_dataset.py) | Image Super-Resolution | Train | Support paired data |
| [SingleImageDataset](../basicsr/data/single_image_dataset.py) | Image Super-Resolution | Test | Only read low quality images, used in tests without Ground-Truth |
| [REDSDataset](../basicsr/data/reds_dataset.py) | Video Super-Resolution | Train | REDS training dataset |
| [Vimeo90KDataset](../basicsr/data/vimeo90k_dataset.py) | Video Super-Resolution | Train | Vimeo90K training dataset |
| [VideoTestDataset](../basicsr/data/video_test_dataset.py) | Video Super-Resolution | Test | Base video test dataset, supporting Vid4, REDS testing datasets |
| [VideoTestVimeo90KDataset](../basicsr/data/video_test_dataset.py) | Video Super-Resolution | Test | Inherit `VideoTestDataset`, Vimeo90K testing dataset |
| [VideoTestDUFDataset](../basicsr/data/video_test_dataset.py) | Video Super-Resolution | Test | Inherit `VideoTestDataset`, testing dataset for method DUF, supporting Vid4 dataset |
| [FFHQDataset](../basicsr/data/ffhq_dataset.py) | Face Generation | Train | FFHQ training dataset |
1. Common transformations and functions are in [transforms.py](../basicsr/data/transforms.py) and [util.py](../basicsr/data/util.py), respectively
# Codebase Designs and Conventions
[English](DesignConvention.md) **|** [简体中文](DesignConvention_CN.md)
#### Contents
1. [Overall Framework](#Overall-Framework)
1. [Features](#Features)
1. [Dynamic Instantiation](#Dynamic-Instantiation)
1. [Conventions](#Conventions)
## Overall Framework
The `BasicSR` framework can be divided into the following parts: data, model, options/configs and training process. <br>
When we modify or add a new method, we often modify/add it from the above aspects. <br>
The figure below shows the overall framework.
![overall_structure](../assets/overall_structure.png)
## Features
### Dynamic Instantiation
When we add a new class or function, it can be used directly in the configuration file. The program will automatically scan, find and instantiate according to the class name or function name in the configuration file. This process is called dynamic instantiation.
Specifically, we implement it through `importlib` and `getattr`. Taking the data module as example, we follow the below steps in [`data/__init__.py`](../basicsr/data/__init__.py):
1. Scan all the files under the data folder with '_dataset' in file names
1. Import the classes or functions in these files through `importlib`
1. Instantiate through `getattr` according to the name in the configuration file
```python
# automatically scan and import dataset modules
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
if v.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = [
importlib.import_module(f'basicsr.data.{file_name}')
for file_name in dataset_filenames
]
...
# dynamic instantiation
for module in _dataset_modules:
dataset_cls = getattr(module, dataset_type, None)
if dataset_cls is not None:
break
```
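Once the class has been found, building the dataset is just a call with its options dict. Roughly (a sketch; the actual dataset creation code in basicsr differs in details):
```python
def build_dataset(dataset_opt, dataset_modules):
    """Find the class named by dataset_opt['type'] in the scanned modules and instantiate it.

    dataset_opt is the per-dataset options dict parsed from the yaml config, e.g.
    {'type': 'PairedImageDataset', 'dataroot_gt': '...', 'io_backend': {'type': 'disk'}, ...}
    """
    dataset_type = dataset_opt['type']
    dataset_cls = None
    for module in dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')
    return dataset_cls(dataset_opt)
```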
We use similar techniques for the following modules. Pay attention to the file suffix conventions when using them:
| Module | File Suffix | Example |
| :------------- | :----------: | :----------: |
| Data | `_dataset.py` | `data/paired_image_dataset.py` |
| Model | `_model.py` | `basicsr/models/sr_model.py` |
| Archs | `_arch.py` | `basicsr/archs/srresnet_arch.py`|
Note:
1. The above file suffixes are only used when necessary. Other file names should avoid using the above suffixes.
1. Note that the class name or function name cannot be repeated.
In addition, we also use `importlib` and `getattr` for `losses` and `metrics`. However, for losses and metrics, the number of files is smaller and they change less often, so we do not use the file-scanning strategy.
For these two modules, after adding new classes or functions, we need to add the corresponding class or function names to `__init__.py` (a hypothetical example follows the table below).
| Module | Path | Modify `__init__.py` |
| :------------- | :----------: | :----------: |
| Losses | `basicsr/models/losses` | [`basicsr/models/losses/__init__.py`](../basicsr/models/losses/__init__.py) |
| Metrics | `basicsr/metrics` | [`basicsr/metrics/__init__.py`](../basicsr/metrics/__init__.py)|
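So adding, say, a new loss is a two-step change: define the class under `basicsr/models/losses` and export its name from the package `__init__.py`. A hypothetical example (`MyNewLoss` and `my_loss.py` are placeholder names, not part of the codebase):
```python
# basicsr/models/losses/my_loss.py  (hypothetical file)
import torch
from torch import nn

class MyNewLoss(nn.Module):
    """A toy Charbonnier-style loss used only to illustrate the registration step."""

    def __init__(self, loss_weight=1.0, eps=1e-12):
        super().__init__()
        self.loss_weight = loss_weight
        self.eps = eps

    def forward(self, pred, target):
        return self.loss_weight * torch.mean(torch.sqrt((pred - target) ** 2 + self.eps))

# basicsr/models/losses/__init__.py  (add the new name so yaml configs can reference it)
# from .my_loss import MyNewLoss
# __all__ = [..., 'MyNewLoss']
```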
## Conventions
1. In dynamic instantiation, there are requirements on the file suffixes of the following modules; otherwise, automatic instantiation cannot be achieved.
| Module | File Suffix | Example |
| :------------- | :----------: | :----------: |
| Data | `_dataset.py` | `data/paired_image_dataset.py` |
| Model | `_model.py` | `basicsr/models/sr_model.py` |
| Archs | `_arch.py` | `basicsr/archs/srresnet_arch.py`|
1. When logging, the loss items are recommended to start with `l_`, so that all these loss items will be grouped together in tensorboard. For example, in [basicsr/models/srgan_model.py](../basicsr/models/srgan_model.py), we use `l_g_pix`, `l_g_percep`, `l_g_gan`, etc for loss items. In [basicsr/utils/logger.py](../basicsr/utils/logger.py), these items will be grouped together:
```python
if k.startswith('l_'):
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
else:
self.tb_logger.add_scalar(k, v, current_iter)
```
# HOWTOs
[English](HOWTOs.md) **|** [简体中文](HOWTOs_CN.md)
## How to train StyleGAN2
1. Prepare training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset). More details are in [DatasetPreparation.md](DatasetPreparation.md#StyleGAN2)
1. Download the FFHQ dataset. We recommend downloading the tfrecords files from [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset).
1. Extract tfrecords to images or LMDBs (TensorFlow is required to read tfrecords):
> python scripts/data_preparation/extract_images_from_tfrecords.py
1. Modify the config file in `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml`
1. Train with distributed training. More training commands are in [TrainTest.md](TrainTest.md).
> python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml --launcher pytorch
## How to inference StyleGAN2
1. Download pre-trained models from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models` folder.
1. Test.
> python inference/inference_stylegan2.py
1. The results are in the `samples` folder.
## How to inference DFDNet
1. Install [dlib](http://dlib.net/), because DFDNet uses dlib to do face recognition and landmark detection. [Installation reference](https://github.com/davisking/dlib).
1. Clone dlib repo: `git clone git@github.com:davisking/dlib.git`
1. `cd dlib`
1. Install: `python setup.py install`
2. Download the dlib pretrained models from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models/dlib` folder.<br>
You can download them by running the following command OR download the pretrained models manually.
> python scripts/download_pretrained_models.py dlib
3. Download the pretrained DFDNet models, dictionary and face template from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models/DFDNet` folder.<br>
You can download them by running the following command OR download the pretrained models manually.
> python scripts/download_pretrained_models.py DFDNet
4. Prepare the testing dataset in the `datasets` folder; for example, put the images in the `datasets/TestWhole` folder.
5. Test.
> python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole
6. The results are in the `results/DFDNet` folder.
## How to train SwinIR (SR)
We take classical SR X4 on DIV2K as an example.
1. Prepare the training dataset: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/). More details are in [DatasetPreparation.md](DatasetPreparation.md#image-super-resolution)
1. Prepare the validation dataset: Set5. You can download it following [this guidance](DatasetPreparation.md#common-image-sr-datasets).
1. Modify the config file in [`options/train/SwinIR/train_SwinIR_SRx4_scratch.yml`](../options/train/SwinIR/train_SwinIR_SRx4_scratch.yml) accordingly.
1. Train with distributed training. More training commands are in [TrainTest.md](TrainTest.md).
> python -m torch.distributed.launch --nproc_per_node=8 --master_port=4331 basicsr/train.py -opt options/train/SwinIR/train_SwinIR_SRx4_scratch.yml --launcher pytorch --auto_resume
Note that:
1. Different from the original setting in the paper where the X4 model is finetuned from the X2 model, we directly train it from scratch.
1. We also use `EMA (Exponential Moving Average)`. Note that all model training in BasicSR supports EMA (a short EMA sketch follows the results table below).
1. At **250K iterations** of training the X4 model, it already achieves performance comparable to the official model.
| ClassicalSR DIV2KX4 | PSNR (RGB) | PSNR (Y) | SSIM (RGB) | SSIM (Y) |
| :--- | :---: | :---: | :---: | :---: |
| Official | 30.803 | 32.728 | 0.8738|0.9028 |
| Reproduce |30.832 | 32.756 | 0.8739| 0.9025 |
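For reference, EMA simply keeps a shadow copy of the generator weights and updates it as `ema = decay * ema + (1 - decay) * param` each iteration. A minimal sketch (the decay value is illustrative; buffers are omitted for brevity, and this is not the basicsr EMA code):
```python
import copy
import torch

def update_ema(model, model_ema, decay=0.999):
    """One EMA step: shadow = decay * shadow + (1 - decay) * online weights."""
    online = dict(model.named_parameters())
    with torch.no_grad():
        for name, p_ema in model_ema.named_parameters():
            p_ema.mul_(decay).add_(online[name], alpha=1 - decay)

# model_ema = copy.deepcopy(model)  # created once at the start of training
```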
## How to inference SwinIR (SR)
1. Download pre-trained models from the [**official SwinIR repo**](https://github.com/JingyunLiang/SwinIR/releases/tag/v0.0) to the `experiments/pretrained_models/SwinIR` folder.
1. Inference.
> python inference/inference_swinir.py --input datasets/Set5/LRbicx4 --patch_size 48 --model_path experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth --output results/SwinIR_SRX4_DIV2K/Set5
1. The results are in the `results/SwinIR_SRX4_DIV2K/Set5` folder.
1. You may want to calculate the PSNR/SSIM values.
> python scripts/metrics/calculate_psnr_ssim.py --gt datasets/Set5/GTmod12/ --restored results/SwinIR_SRX4_DIV2K/Set5 --crop_border 4
or test with the Y channel with the `--test_y_channel` argument.
> python scripts/metrics/calculate_psnr_ssim.py --gt datasets/Set5/GTmod12/ --restored results/SwinIR_SRX4_DIV2K/Set5 --crop_border 4 --test_y_channel