Commit 76b9024b authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3145 failed with stages
in 0 seconds
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target)**2 + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
"""MSE (L2) loss.
Args:
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
variant of L1Loss).
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
Super-Resolution".
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
eps (float): A value used to control the curvature near zero. Default: 1e-12.
"""
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
super(CharbonnierLoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
self.eps = eps
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
"""Weighted TV loss.
Args:
loss_weight (float): Loss weight. Default: 1.0.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
if reduction not in ['mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
def forward(self, pred, weight=None):
if weight is None:
y_weight = None
x_weight = None
else:
y_weight = weight[:, :, :-1, :]
x_weight = weight[:, :, :, :-1]
y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
loss = x_diff + y_diff
return loss
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
1.0 in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and the loss will multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and the loss will multiplied by the weight.
Default: 0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
perceptual_weight=1.0,
style_weight=0.,
criterion='l1'):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
range_norm=range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
self.criterion = torch.nn.L2loss()
elif self.criterion_type == 'fro':
self.criterion = None
else:
raise NotImplementedError(f'{criterion} criterion has not been supported.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
else:
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
style_loss += torch.norm(
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
else:
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import LOSS_REGISTRY
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan_softplus':
self.loss = self._wgan_softplus_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input, target):
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_softplus_loss(self, input, target):
"""wgan loss with soft plus. softplus is a smooth approximation to the
ReLU function.
In StyleGAN2, it is called:
Logistic loss for discriminator;
Non-saturating loss for generator.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-input).mean() if target else F.softplus(input).mean()
def get_target_label(self, input, target_is_real):
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
return input.new_ones(input.size()) * target_val
def forward(self, input, target_is_real, is_disc=False):
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the targe is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
target_label = self.get_target_label(input, target_is_real)
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight
@LOSS_REGISTRY.register()
class MultiScaleGANLoss(GANLoss):
"""
MultiScaleGANLoss accepts a list of predictions
"""
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
def forward(self, input, target_is_real, is_disc=False):
"""
The input is a list of tensors, or a list of (a list of tensors)
"""
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
# Only compute GAN loss for the last layer
# in case of multiscale feature matching
pred_i = pred_i[-1]
# Safe operation: 0-dim tensor calling self.mean() does nothing
loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
loss += loss_tensor
return loss / len(input)
else:
return super().forward(input, target_is_real, is_disc)
def r1_penalty(real_pred, real_img):
"""R1 regularization for discriminator. The core idea is to
penalize the gradient on real data alone: when the
generator distribution produces the true data distribution
and the discriminator is equal to 0 on the data manifold, the
gradient penalty ensures that the discriminator cannot create
a non-zero gradient orthogonal to the data manifold without
suffering a loss in the GAN game.
Reference: Eq. 9 in Which training methods for GANs do actually converge.
"""
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
"""Calculate gradient penalty for wgan-gp.
Args:
discriminator (nn.Module): Network for the discriminator.
real_data (Tensor): Real input data.
fake_data (Tensor): Fake input data.
weight (Tensor): Weight tensor. Default: None.
Returns:
Tensor: A tensor for gradient penalty.
"""
batch_size = real_data.size(0)
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
# interpolate between real_data and fake_data
interpolates = alpha * real_data + (1. - alpha) * fake_data
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = discriminator(interpolates)
gradients = autograd.grad(
outputs=disc_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
if weight is not None:
gradients = gradients * weight
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
if weight is not None:
gradients_penalty /= torch.mean(weight)
return gradients_penalty
import functools
import torch
from torch.nn import functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are 'none', 'mean' and 'sum'.
Returns:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
else:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean'):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights. Default: None.
reduction (str): Same as built-in losses of PyTorch. Options are
'none', 'mean' and 'sum'. Default: 'mean'.
Returns:
Tensor: Loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
assert weight.dim() == loss.dim()
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
loss = loss * weight
# if weight is not specified or reduction is sum, just reduce the loss
if weight is None or reduction == 'sum':
loss = reduce_loss(loss, reduction)
# if reduction is mean, then compute mean over weight region
elif reduction == 'mean':
if weight.size(1) > 1:
weight = weight.sum()
else:
weight = weight.sum() * loss.size(1)
loss = loss.sum() / weight
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
**kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)
"""
@functools.wraps(loss_func)
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction)
return loss
return wrapper
def get_local_weights(residual, ksize):
"""Get local weights for generating the artifact map of LDL.
It is only called by the `get_refined_artifact_map` function.
Args:
residual (Tensor): Residual between predicted and ground truth images.
ksize (Int): size of the local window.
Returns:
Tensor: weight for each pixel to be discriminated as an artifact pixel
"""
pad = (ksize - 1) // 2
residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
return pixel_level_weight
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
"""Calculate the artifact map of LDL
(Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
Args:
img_gt (Tensor): ground truth images.
img_output (Tensor): output images given by the optimizing model.
img_ema (Tensor): output images given by the ema model.
ksize (Int): size of the local window.
Returns:
overall_weight: weight for each pixel to be discriminated as an artifact pixel
(calculated based on both local and global observations).
"""
residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
overall_weight = patch_level_weight * pixel_level_weight
overall_weight[residual_sr < residual_ema] = 0
return overall_weight
# Metrics
[English](README.md) **|** [简体中文](README_CN.md)
- [约定](#约定)
- [PSNR 和 SSIM](#psnr-和-ssim)
## 约定
因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
- Numpy 类型 (一般是 cv2 的结果)
- UINT8: BGR, [0, 255], (h, w, c)
- float: BGR, [0, 1], (h, w, c). 一般作为中间结果
- Tensor 类型
- float: RGB, [0, 1], (n, c, h, w)
其他约定:
-`_pt` 结尾的是 PyTorch 结果
- PyTorch version 支持 batch 计算
- 颜色转换在 float32 上做;metric计算在 float64 上做
## PSNR 和 SSIM
PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate)[evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))
下面列了各个实现的结果比对.
总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异
- PSNR 比对
|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
|baboon| Y | - |22.441898 | 22.441899 | 22.444916|
|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
|comic | Y | - | 21.720398 | 21.720398 | 21.721663|
- SSIM 比对
|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
|baboon| Y | - |0.453097| 0.453097 | 0.453171|
|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
# Metrics
[English](README.md) **|** [简体中文](README_CN.md)
- [约定](#约定)
- [PSNR 和 SSIM](#psnr-和-ssim)
## 约定
因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
- Numpy 类型 (一般是 cv2 的结果)
- UINT8: BGR, [0, 255], (h, w, c)
- float: BGR, [0, 1], (h, w, c). 一般作为中间结果
- Tensor 类型
- float: RGB, [0, 1], (n, c, h, w)
其他约定:
-`_pt` 结尾的是 PyTorch 结果
- PyTorch version 支持 batch 计算
- 颜色转换在 float32 上做;metric计算在 float64 上做
## PSNR 和 SSIM
PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate)[evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))
下面列了各个实现的结果比对.
总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异
- PSNR 比对
|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
|baboon| Y | - |22.441898 | 22.441899 | 22.444916|
|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
|comic | Y | - | 21.720398 | 21.720398 | 21.721663|
- SSIM 比对
|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
|baboon| Y | - |0.453097| 0.453097 | 0.453171|
|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
from copy import deepcopy
from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_pt, calculate_psnr_pt
__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
def calculate_metric(data, opt):
"""Calculate metric from data and options.
Args:
opt (dict): Configuration. It must contain:
type (str): Model type.
"""
opt = deepcopy(opt)
metric_type = opt.pop('type')
metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
return metric
import numpy as np
import torch
import torch.nn as nn
from scipy import linalg
from tqdm import tqdm
from basicsr.archs.inception import InceptionV3
def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
# we may not resize the input, but in [rosinality/stylegan2-pytorch] it
# does resize the input.
inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
inception = nn.DataParallel(inception).eval().to(device)
return inception
@torch.no_grad()
def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
"""Extract inception features.
Args:
data_generator (generator): A data generator.
inception (nn.Module): Inception model.
len_generator (int): Length of the data_generator to show the
progressbar. Default: None.
device (str): Device. Default: cuda.
Returns:
Tensor: Extracted features.
"""
if len_generator is not None:
pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
else:
pbar = None
features = []
for data in data_generator:
if pbar:
pbar.update(1)
data = data.to(device)
feature = inception(data)[0].view(data.shape[0], -1)
features.append(feature.to('cpu'))
if pbar:
pbar.close()
features = torch.cat(features, 0)
return features
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Args:
mu1 (np.array): The sample mean over activations.
sigma1 (np.array): The covariance matrix over activations for generated samples.
mu2 (np.array): The sample mean over activations, precalculated on an representative data set.
sigma2 (np.array): The covariance matrix over activations, precalculated on an representative data set.
Returns:
float: The Frechet Distance.
"""
assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
# Product might be almost singular
if not np.isfinite(cov_sqrt).all():
print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
offset = np.eye(sigma1.shape[0]) * eps
cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
# Numerical error might give slight imaginary component
if np.iscomplexobj(cov_sqrt):
if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
m = np.max(np.abs(cov_sqrt.imag))
raise ValueError(f'Imaginary component {m}')
cov_sqrt = cov_sqrt.real
mean_diff = mu1 - mu2
mean_norm = mean_diff @ mean_diff
trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
fid = mean_norm + trace
return fid
import numpy as np
from basicsr.utils import bgr2ycbcr
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
import cv2
import math
import numpy as np
import os
from scipy.ndimage import convolve
from scipy.special import gamma
from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.matlab_functions import imresize
from basicsr.utils.registry import METRIC_REGISTRY
def estimate_aggd_param(block):
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
Args:
block (ndarray): 2D Image block.
Returns:
tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
distribution (Estimating the parames in Equation 7 in the paper).
"""
block = block.flatten()
gam = np.arange(0.2, 10.001, 0.001) # len = 9801
gam_reciprocal = np.reciprocal(gam)
r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
left_std = np.sqrt(np.mean(block[block < 0]**2))
right_std = np.sqrt(np.mean(block[block > 0]**2))
gammahat = left_std / right_std
rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
array_position = np.argmin((r_gam - rhatnorm)**2)
alpha = gam[array_position]
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
return (alpha, beta_l, beta_r)
def compute_feature(block):
"""Compute features.
Args:
block (ndarray): 2D Image block.
Returns:
list: Features with length of 18.
"""
feat = []
alpha, beta_l, beta_r = estimate_aggd_param(block)
feat.extend([alpha, (beta_l + beta_r) / 2])
# distortions disturb the fairly regular structure of natural images.
# This deviation can be captured by analyzing the sample distribution of
# the products of pairs of adjacent coefficients computed along
# horizontal, vertical and diagonal orientations.
shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
for i in range(len(shifts)):
shifted_block = np.roll(block, shifts[i], axis=(0, 1))
alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
# Eq. 8
mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
feat.extend([alpha, mean, beta_l, beta_r])
return feat
def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
``Paper: Making a "Completely Blind" Image Quality Analyzer``
This implementation could produce almost the same results as the official
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
Note that we do not include block overlap height and width, since they are
always 0 in the official implementation.
For good performance, it is advisable by the official implementation to
divide the distorted image in to the same size patched as used for the
construction of multivariate Gaussian model.
Args:
img (ndarray): Input image whose quality needs to be computed. The
image must be a gray or Y (of YCbCr) image with shape (h, w).
Range [0, 255] with float type.
mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
model calculated on the pristine dataset.
cov_pris_param (ndarray): Covariance of a pre-defined multivariate
Gaussian model calculated on the pristine dataset.
gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
image.
block_size_h (int): Height of the blocks in to which image is divided.
Default: 96 (the official recommended value).
block_size_w (int): Width of the blocks in to which image is divided.
Default: 96 (the official recommended value).
"""
assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
# crop image
h, w = img.shape
num_block_h = math.floor(h / block_size_h)
num_block_w = math.floor(w / block_size_w)
img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
distparam = [] # dist param is actually the multiscale features
for scale in (1, 2): # perform on two scales (1, 2)
mu = convolve(img, gaussian_window, mode='nearest')
sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
# normalize, as in Eq. 1 in the paper
img_nomalized = (img - mu) / (sigma + 1)
feat = []
for idx_w in range(num_block_w):
for idx_h in range(num_block_h):
# process ecah block
block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
feat.append(compute_feature(block))
distparam.append(np.array(feat))
if scale == 1:
img = imresize(img / 255., scale=0.5, antialiasing=True)
img = img * 255.
distparam = np.concatenate(distparam, axis=1)
# fit a MVG (multivariate Gaussian) model to distorted patch features
mu_distparam = np.nanmean(distparam, axis=0)
# use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
cov_distparam = np.cov(distparam_no_nan, rowvar=False)
# compute niqe quality, Eq. 10 in the paper
invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
quality = np.matmul(
np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
quality = np.sqrt(quality)
quality = float(np.squeeze(quality))
return quality
@METRIC_REGISTRY.register()
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
``Paper: Making a "Completely Blind" Image Quality Analyzer``
This implementation could produce almost the same results as the official
MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
> MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
> Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
We use the official params estimated from the pristine dataset.
We use the recommended block size (96, 96) without overlaps.
Args:
img (ndarray): Input image whose quality needs to be computed.
The input image must be in range [0, 255] with float/int type.
The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
If the input order is 'HWC' or 'CHW', it will be converted to gray
or Y (of YCbCr) image according to the ``convert_to`` argument.
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the metric calculation.
input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
Default: 'HWC'.
convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
Default: 'y'.
Returns:
float: NIQE result.
"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# we use the official params estimated from the pristine dataset.
niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
mu_pris_param = niqe_pris_params['mu_pris_param']
cov_pris_param = niqe_pris_params['cov_pris_param']
gaussian_window = niqe_pris_params['gaussian_window']
img = img.astype(np.float32)
if input_order != 'HW':
img = reorder_image(img, input_order=input_order)
if convert_to == 'y':
img = to_y_channel(img)
elif convert_to == 'gray':
img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
img = np.squeeze(img)
if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border]
# round is necessary for being consistent with MATLAB's result
img = img.round()
niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
return niqe_result
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.color_util import rgb2ycbcr_pt
from basicsr.utils.registry import METRIC_REGISTRY
@METRIC_REGISTRY.register()
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: PSNR result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)
@METRIC_REGISTRY.register()
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: PSNR result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
return 10. * torch.log10(1. / (mse + 1e-8))
@METRIC_REGISTRY.register()
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
"""Calculate SSIM (structural similarity).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()
@METRIC_REGISTRY.register()
def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate SSIM (structural similarity) (PyTorch version).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
ssim = _ssim_pth(img * 255., img2 * 255.)
return ssim
def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: SSIM result.
"""
c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
return ssim_map.mean()
def _ssim_pth(img, img2):
"""Calculate SSIM (structural similarity) (PyTorch version).
It is called by func:`calculate_ssim_pt`.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
Returns:
float: SSIM result.
"""
c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)
mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode
mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2
cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
return ssim_map.mean([1, 2, 3])
import cv2
import torch
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
from basicsr.utils import img2tensor
def test(img_path, img_path2, crop_border, test_y_channel=False):
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)
# --------------------- Numpy ---------------------
psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
# --------------------- PyTorch (CPU) ---------------------
img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
# --------------------- PyTorch (GPU) ---------------------
img = img.cuda()
img2 = img2.cuda()
psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
psnr_pth = calculate_psnr_pt(
torch.repeat_interleave(img, 2, dim=0),
torch.repeat_interleave(img2, 2, dim=0),
crop_border=crop_border,
test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(
torch.repeat_interleave(img, 2, dim=0),
torch.repeat_interleave(img2, 2, dim=0),
crop_border=crop_border,
test_y_channel=test_y_channel)
print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
if __name__ == '__main__':
test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)
test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY
__all__ = ['build_model']
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must contain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment