Commit a75d2bda authored by mashun1

evtexture

000.h5 100
011.h5 100
015.h5 100
020.h5 100
001.h5 100
002.h5 100
003.h5 100
004.h5 100
005.h5 100
006.h5 100
007.h5 100
008.h5 100
009.h5 100
010.h5 100
012.h5 100
013.h5 100
014.h5 100
016.h5 100
017.h5 100
018.h5 100
019.h5 100
021.h5 100
022.h5 100
023.h5 100
024.h5 100
025.h5 100
026.h5 100
027.h5 100
028.h5 100
029.h5 100
030.h5 100
031.h5 100
032.h5 100
033.h5 100
034.h5 100
035.h5 100
036.h5 100
037.h5 100
038.h5 100
039.h5 100
040.h5 100
041.h5 100
042.h5 100
043.h5 100
044.h5 100
045.h5 100
046.h5 100
047.h5 100
048.h5 100
049.h5 100
050.h5 100
051.h5 100
052.h5 100
053.h5 100
054.h5 100
055.h5 100
056.h5 100
057.h5 100
058.h5 100
059.h5 100
060.h5 100
061.h5 100
062.h5 100
063.h5 100
064.h5 100
065.h5 100
066.h5 100
067.h5 100
068.h5 100
069.h5 100
070.h5 100
071.h5 100
072.h5 100
073.h5 100
074.h5 100
075.h5 100
076.h5 100
077.h5 100
078.h5 100
079.h5 100
080.h5 100
081.h5 100
082.h5 100
083.h5 100
084.h5 100
085.h5 100
086.h5 100
087.h5 100
088.h5 100
089.h5 100
090.h5 100
091.h5 100
092.h5 100
093.h5 100
094.h5 100
095.h5 100
096.h5 100
097.h5 100
098.h5 100
099.h5 100
100.h5 100
101.h5 100
102.h5 100
103.h5 100
104.h5 100
105.h5 100
106.h5 100
107.h5 100
108.h5 100
109.h5 100
110.h5 100
111.h5 100
112.h5 100
113.h5 100
114.h5 100
115.h5 100
116.h5 100
117.h5 100
118.h5 100
119.h5 100
120.h5 100
121.h5 100
122.h5 100
123.h5 100
124.h5 100
125.h5 100
126.h5 100
127.h5 100
128.h5 100
129.h5 100
130.h5 100
131.h5 100
132.h5 100
133.h5 100
134.h5 100
135.h5 100
136.h5 100
137.h5 100
138.h5 100
139.h5 100
140.h5 100
141.h5 100
142.h5 100
143.h5 100
144.h5 100
145.h5 100
146.h5 100
147.h5 100
148.h5 100
149.h5 100
150.h5 100
151.h5 100
152.h5 100
153.h5 100
154.h5 100
155.h5 100
156.h5 100
157.h5 100
158.h5 100
159.h5 100
160.h5 100
161.h5 100
162.h5 100
163.h5 100
164.h5 100
165.h5 100
166.h5 100
167.h5 100
168.h5 100
169.h5 100
170.h5 100
171.h5 100
172.h5 100
173.h5 100
174.h5 100
175.h5 100
176.h5 100
177.h5 100
178.h5 100
179.h5 100
180.h5 100
181.h5 100
182.h5 100
183.h5 100
184.h5 100
185.h5 100
186.h5 100
187.h5 100
188.h5 100
189.h5 100
190.h5 100
191.h5 100
192.h5 100
193.h5 100
194.h5 100
195.h5 100
196.h5 100
197.h5 100
198.h5 100
199.h5 100
200.h5 100
201.h5 100
202.h5 100
203.h5 100
204.h5 100
205.h5 100
206.h5 100
207.h5 100
208.h5 100
209.h5 100
210.h5 100
211.h5 100
212.h5 100
213.h5 100
214.h5 100
215.h5 100
216.h5 100
217.h5 100
218.h5 100
219.h5 100
220.h5 100
221.h5 100
222.h5 100
223.h5 100
224.h5 100
225.h5 100
226.h5 100
227.h5 100
228.h5 100
229.h5 100
230.h5 100
231.h5 100
232.h5 100
233.h5 100
234.h5 100
235.h5 100
236.h5 100
237.h5 100
238.h5 100
239.h5 100
240.h5 100
241.h5 100
242.h5 100
243.h5 100
244.h5 100
245.h5 100
246.h5 100
247.h5 100
248.h5 100
249.h5 100
250.h5 100
251.h5 100
252.h5 100
253.h5 100
254.h5 100
255.h5 100
256.h5 100
257.h5 100
258.h5 100
259.h5 100
260.h5 100
261.h5 100
262.h5 100
263.h5 100
264.h5 100
265.h5 100
266.h5 100
267.h5 100
268.h5 100
269.h5 100
foliage.h5 49
city.h5 34
walk.h5 47
calendar.h5 41
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
num_prefetch_queue (int): Length of the prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
num_prefetch_queue (int): Length of the prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
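# NOTE: illustrative usage sketch, not part of the original file. The dataset, batch size and
# queue length below are assumptions chosen only for the demo.
def _demo_prefetch_dataloader():
    from torch.utils.data import TensorDataset
    dummy_set = TensorDataset(torch.randn(16, 3, 32, 32))
    loader = PrefetchDataLoader(num_prefetch_queue=2, dataset=dummy_set, batch_size=4, num_workers=0)
    for (batch,) in loader:
        # batches are produced by the background PrefetchGenerator thread through a bounded queue
        print(batch.shape)  # torch.Size([4, 3, 32, 32])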
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Reference: https://github.com/NVIDIA/apex/issues/304#
It may consume more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()
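# NOTE: illustrative usage sketch, not part of the original file. The dict-returning dataset and
# the opt values are assumptions; CUDAPrefetcher itself expects dict batches and a CUDA device.
def _demo_cuda_prefetcher():
    from torch.utils.data import Dataset

    class _DictSet(Dataset):  # hypothetical dataset yielding dict samples
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {'lq': torch.randn(3, 16, 16), 'gt': torch.randn(3, 64, 64)}

    if not torch.cuda.is_available():
        return
    loader = DataLoader(_DictSet(), batch_size=2)
    prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})  # only 'num_gpu' is read by the prefetcher
    batch = prefetcher.next()
    while batch is not None:
        assert batch['lq'].is_cuda  # tensors were copied on the side stream with non_blocking=True
        batch = prefetcher.next()   # returns the preloaded batch and starts loading the next one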
import cv2
import random
import torch
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
f'the size of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
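# NOTE: illustrative usage sketch, not part of the original file. The image sizes, patch size and
# scale below are assumptions chosen only for the demo.
def _demo_paired_random_crop():
    import numpy as np
    gt = np.random.rand(256, 256, 3).astype(np.float32)
    lq = np.random.rand(64, 64, 3).astype(np.float32)  # 4x smaller than gt
    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=128, scale=4)
    print(gt_patch.shape, lq_patch.shape)  # (128, 128, 3) (32, 32, 3), cropped at matching locations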
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use a vertical flip and a transpose to implement the rotation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Rotation (by multiples of 90 degrees). Default: True.
flows (list[ndarray]): Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
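# NOTE: illustrative usage sketch, not part of the original file. The array shapes are assumptions
# chosen only for the demo.
def _demo_augment():
    import numpy as np
    img = np.random.rand(64, 64, 3).astype(np.float32)
    flow = np.random.rand(64, 64, 2).astype(np.float32)
    aug_img, aug_flow = augment([img], flows=[flow])                    # with flows: returns (imgs, flows)
    aug_only, (hflip, vflip, rot90) = augment(img, return_status=True)  # without flows: status on request
    print(aug_img.shape, aug_flow.shape, hflip, vflip, rot90)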
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
import numpy as np
import torch
from os import path as osp
from torch.utils import data as data
from basicsr.utils import get_root_logger, FileClient
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.data.transforms import mod_crop
from basicsr.utils.img_util import img2tensor
@DATASET_REGISTRY.register()
class VideoWithEventsTestDataset(data.Dataset):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'folder': []}
self.scale = opt['scale']
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.name = opt['name']
if self.io_backend_opt['type'] == 'hdf5':
self.io_backend_opt['h5_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['LR', 'HR']
else:
raise ValueError(f"The {self.io_backend_opt['type']} io backend is not supported; only hdf5 is implemented.")
logger = get_root_logger()
logger.info(f'Generate data info for VideoWithEventsTestDataset - {opt["name"]}')
if 'meta_info_file' in opt:
with open(opt['meta_info_file'], 'r') as fin:
clips = []
clips_num = []
for line in fin:
clips.append(line.split(' ')[0])
clips_num.append(line.split(' ')[1])
else:
raise NotImplementedError
self.imgs_lq, self.imgs_gt, self.event_lqs = {}, {}, {}
self.folders = []
self.lq_paths = []
for clip, num in zip(clips, clips_num):
self.io_backend_opt['h5_clip'] = clip
self.file_client = FileClient(self.io_backend_opt['type'], **self.io_backend_opt)
img_lqs, img_gts, event_lqs = self.file_client.get(list(range(int(num))))
# mod_crop gt image for scale
img_gts = [mod_crop(img, self.scale) for img in img_gts]
self.imgs_lq[clip] = torch.stack(img2tensor(img_lqs), dim=0)
self.imgs_gt[clip] = torch.stack(img2tensor(img_gts), dim=0)
self.event_lqs[clip] = torch.from_numpy(np.stack(event_lqs, axis=0))
self.folders.append(clip)
self.lq_paths.append(osp.join('vid4', osp.splitext(clip)[0]))
self.data_info['folder'].extend([clip] * int(num))
def __getitem__(self, index):
folder = self.folders[index]
lq_path = self.lq_paths[index]
img_lq = self.imgs_lq[folder]
img_gt = self.imgs_gt[folder]
event_lq = self.event_lqs[folder]
voxel_f = event_lq[:len(event_lq) // 2]
voxel_b = event_lq[len(event_lq) // 2:]
return {
'lq': img_lq,
'gt': img_gt,
'voxels_f': voxel_f,
'voxels_b': voxel_b,
'folder': folder,
'lq_path': lq_path
}
def __len__(self):
return len(self.folders)
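# NOTE: hypothetical options sketch, not part of the original file. The keys mirror what __init__
# reads above; every path and value here is a placeholder, not the actual experiment configuration.
_example_test_opt = {
    'name': 'Vid4',
    'dataroot_gt': 'datasets/vid4_hr.h5',             # placeholder HDF5 file with GT frames
    'dataroot_lq': 'datasets/vid4_lr.h5',             # placeholder HDF5 file with LQ frames and events
    'scale': 4,
    'io_backend': {'type': 'hdf5'},
    'meta_info_file': 'datasets/meta_info_vid4.txt',  # lines of "<clip>.h5 <num_frames>", e.g. "calendar.h5 41"
}
# dataset = VideoWithEventsTestDataset(_example_test_opt)  # would load every clip listed in the meta file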
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import LOSS_REGISTRY
from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
# automatically scan and import loss modules for registry
# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
loss_folder = osp.dirname(osp.abspath(__file__))
loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
# import all the loss modules
_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
def build_loss(opt):
"""Build loss from options.
Args:
opt (dict): Configuration. It must contain:
type (str): Loss type.
"""
opt = deepcopy(opt)
loss_type = opt.pop('type')
loss = LOSS_REGISTRY.get(loss_type)(**opt)
logger = get_root_logger()
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
return loss
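# NOTE: illustrative usage sketch, not part of the original file. The option values are assumptions;
# 'L1Loss' is registered in basic_loss.py below.
def _demo_build_loss():
    cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
    return cri_pix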
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target)**2 + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
"""MSE (L2) loss.
Args:
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
variant of L1Loss).
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
Super-Resolution".
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
eps (float): A value used to control the curvature near zero. Default: 1e-12.
"""
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
super(CharbonnierLoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
self.eps = eps
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
"""Weighted TV loss.
Args:
loss_weight (float): Loss weight. Default: 1.0.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
if reduction not in ['mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
def forward(self, pred, weight=None):
if weight is None:
y_weight = None
x_weight = None
else:
y_weight = weight[:, :, :-1, :]
x_weight = weight[:, :, :, :-1]
y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
loss = x_diff + y_diff
return loss
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
1.0 in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
range_norm (bool): If True, normalize images from range [-1, 1] to [0, 1].
Default: False.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and multiplied by this weight.
Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and multiplied by this weight.
Default: 0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
perceptual_weight=1.0,
style_weight=0.,
criterion='l1'):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
range_norm=range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
self.criterion = torch.nn.MSELoss()
elif self.criterion_type == 'fro':
self.criterion = None
else:
raise NotImplementedError(f'{criterion} criterion has not been supported.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
else:
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
style_loss += torch.norm(
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
else:
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
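# NOTE: illustrative usage sketch, not part of the original file. The layer choice and tensor shapes
# are assumptions; pretrained VGG19 weights must be available for VGGFeatureExtractor.
def _demo_perceptual_loss():
    percep_cri = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
                                perceptual_weight=1.0, style_weight=0.)
    sr = torch.rand(1, 3, 64, 64)
    hr = torch.rand(1, 3, 64, 64)
    l_percep, l_style = percep_cri(sr, hr)  # l_style is None because style_weight == 0
    return l_percep, l_style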
import functools
import torch
from torch.nn import functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are 'none', 'mean' and 'sum'.
Returns:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
else:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean'):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights. Default: None.
reduction (str): Same as built-in losses of PyTorch. Options are
'none', 'mean' and 'sum'. Default: 'mean'.
Returns:
Tensor: Loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
assert weight.dim() == loss.dim()
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
loss = loss * weight
# if weight is not specified or reduction is sum, just reduce the loss
if weight is None or reduction == 'sum':
loss = reduce_loss(loss, reduction)
# if reduction is mean, then compute mean over weight region
elif reduction == 'mean':
if weight.size(1) > 1:
weight = weight.sum()
else:
weight = weight.sum() * loss.size(1)
loss = loss.sum() / weight
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
**kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)
"""
@functools.wraps(loss_func)
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction)
return loss
return wrapper
def get_local_weights(residual, ksize):
"""Get local weights for generating the artifact map of LDL.
It is only called by the `get_refined_artifact_map` function.
Args:
residual (Tensor): Residual between predicted and ground truth images.
ksize (int): Size of the local window.
Returns:
Tensor: weight for each pixel to be discriminated as an artifact pixel
"""
pad = (ksize - 1) // 2
residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
return pixel_level_weight
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
"""Calculate the artifact map of LDL
(Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
Args:
img_gt (Tensor): ground truth images.
img_output (Tensor): output images given by the optimizing model.
img_ema (Tensor): output images given by the ema model.
ksize (int): Size of the local window.
Returns:
overall_weight: weight for each pixel to be discriminated as an artifact pixel
(calculated based on both local and global observations).
"""
residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
overall_weight = patch_level_weight * pixel_level_weight
overall_weight[residual_sr < residual_ema] = 0
return overall_weight
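# NOTE: illustrative usage sketch, not part of the original file. The tensor shapes and window size
# are assumptions chosen only for the demo.
def _demo_artifact_map():
    img_gt = torch.rand(1, 3, 64, 64)
    img_sr = torch.rand(1, 3, 64, 64)
    img_ema = torch.rand(1, 3, 64, 64)
    weight = get_refined_artifact_map(img_gt, img_sr, img_ema, ksize=7)
    print(weight.shape)  # torch.Size([1, 1, 64, 64]): one artifact weight per pixel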
from copy import deepcopy
from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim
from .lpips import calculate_lpips
__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_lpips']
def calculate_metric(data, opt):
"""Calculate metric from data and options.
Args:
opt (dict): Configuration. It must contain:
type (str): Model type.
"""
opt = deepcopy(opt)
metric_type = opt.pop('type')
metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
return metric
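# NOTE: illustrative usage sketch, not part of the original file. The image shapes and crop_border
# value are assumptions, and it presumes calculate_psnr(img, img2, crop_border, ...) as in recent BasicSR.
def _demo_calculate_metric():
    import numpy as np
    img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    img2 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    psnr = calculate_metric({'img': img, 'img2': img2}, {'type': 'calculate_psnr', 'crop_border': 0})
    return psnr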