Commit 2136e796 authored by mashun1's avatar mashun1
Browse files

codeformer

parents
Pipeline #699 canceled with stages
from .upfirdn2d import upfirdn2d
__all__ = ['upfirdn2d']
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include <torch/extension.h>
// Forward declaration of the function that launches the CUDA kernels; this
// file only validates the arguments and forwards them to it.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
// Argument-validation helpers. Each one fails through TORCH_CHECK with the
// stringified argument name in the error message.
//
// The device is queried directly on the tensor: Tensor::type() is deprecated.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so both checks stay together when the macro is used
// as the body of an unbraced if/else.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
// Entry point exposed to Python. Verifies that both tensors are CUDA tensors
// and then hands every argument, unchanged, to upfirdn2d_op.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  torch::Tensor result = upfirdn2d_op(input, kernel,
                                      up_x, up_y,
                                      down_x, down_y,
                                      pad_x0, pad_x1, pad_y0, pad_y1);
  return result;
}
// Python bindings: registers the C++ function above under the name
// "upfirdn2d" in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Integer division that rounds the quotient down instead of toward zero:
// the truncated quotient is decremented whenever quotient * b exceeds a.
// NOTE(review): this equals mathematical floor(a / b) for b > 0; callers in
// this file pass the up-sampling factors, presumably positive -- confirm.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  const int quotient = a / b;
  const bool overshoots = quotient * b > a;
  return overshoots ? quotient - 1 : quotient;
}
// Plain-data argument bundle passed by value to both upfirdn2d kernels.
// Member order is unchanged, so the layout is identical to a one-per-line
// declaration.
struct UpFirDn2DKernelParams {
  // Up-sampling and down-sampling factors along x and y.
  int up_x, up_y;
  int down_x, down_y;

  // Padding amounts. The kernels read only pad_x0 / pad_y0; pad_x1 / pad_y1
  // enter the output-size computation done by the host launcher.
  int pad_x0, pad_x1, pad_y0, pad_y1;

  // Sizes of the four input dimensions, in memory order
  // (major, height, width, minor).
  int major_dim, in_h, in_w, minor_dim;

  // Filter size (rows, columns).
  int kernel_h, kernel_w;

  // Output height and width.
  int out_h, out_w;

  // Upper bounds on the per-block loops over the major axis and over output
  // columns / column tiles.
  int loop_major, loop_x;
};
// Generic fallback kernel: every thread computes its output samples straight
// from global memory, with run-time filter size and no shared-memory tiling.
//
// Index mapping, as established by the code below:
//   - launch x dimension: flattened (out_y, minor_idx) pair;
//   - launch y dimension: output column; each thread visits up to p.loop_x
//     columns spaced blockDim.y apart;
//   - launch z dimension: first of up to p.loop_major consecutive major
//     indices handled by the block.
// The accumulator has type scalar_t (no wider accumulation type is used).
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
// Split the flat x index into an output row and a position on the minor axis.
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
// Threads that start outside the output have nothing to do.
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
// Vertical footprint of this output row: input rows [in_y, in_y + h), with
// both ends clamped to [0, p.in_h]; kernel_y is the filter row paired with
// the first of those input rows.
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
// Horizontal footprint, computed the same way as the vertical one.
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
// Pointers to the first input sample and the first filter tap of the
// footprint.
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
// Pointer increments per step: the input pointer moves forward by one
// pixel / one row, while the filter pointer moves backward by up_x taps /
// up_y filter rows.
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
// Rewind the w horizontal steps and advance both pointers by one row.
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
// Tiled kernel specialised at compile time on the resampling factors, the
// (maximum) filter size and the output tile size.
//
// Each block produces tile_out_h x tile_out_w output samples per iteration:
//   1. the filter is copied once into shared memory (sk), flipped in both
//      axes and zero-filled where the template size exceeds the run-time
//      size p.kernel_h x p.kernel_w;
//   2. for every (major index, column tile) the required input patch is
//      loaded into shared memory (sx), with zeros outside the input;
//   3. each thread accumulates its outputs from sx and sk.
// Threads are addressed through threadIdx.x only. blockIdx.x encodes
// (tile row, minor index), blockIdx.y the first column tile and blockIdx.z
// the first major index.
// Both shared arrays are declared as float regardless of scalar_t.
// The conditions use the non-short-circuit operators | and & on comparison
// results.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
// Number of input rows / columns needed to produce one output tile.
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
// Decode blockIdx.x into the first output row of the tile and the minor index.
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
// The test depends only on block-level values, so a whole block leaves
// together and the barriers below are reached by every remaining thread.
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
// Stage the flipped, zero-padded filter in shared memory.
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
// Top-left corner of the input patch that feeds this output tile.
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
// Wait until every thread has finished reading the previous patch (and, on
// the first iteration, writing sk) before sx is overwritten.
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
// Samples outside the input are stored as zero.
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
// The patch must be complete before any thread starts accumulating.
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
// First filter tap used by this output sample; later taps are up_x / up_y
// apart.
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
// Tiles may overhang the output; only in-range samples are stored.
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
// Host-side launcher for the upfirdn2d kernels.
//
// Arguments:
//   input   4-D tensor indexed as (major, height, width, minor).
//   kernel  2-D filter indexed as (rows, columns), same dtype as input.
//   up_*    up-sampling factors, down_* down-sampling factors (all >= 1).
//   pad_*   padding before (0) / after (1) along x and y.
// Returns a new tensor of shape (major, out_h, out_w, minor).
//
// Both tensors are made contiguous here, so callers need not do it.
// Raises (through TORCH_CHECK / AT_CUDA_CHECK) on malformed arguments or when
// the CUDA runtime reports an error for the launch.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  TORCH_CHECK(input.dim() == 4, "upfirdn2d: input must be 4-D, got ",
              input.dim(), "-D");
  TORCH_CHECK(kernel.dim() == 2, "upfirdn2d: kernel must be 2-D, got ",
              kernel.dim(), "-D");
  TORCH_CHECK(kernel.scalar_type() == input.scalar_type(),
              "upfirdn2d: input and kernel must have the same dtype");
  // The output-size formula divides by the down factors and floor_div divides
  // by the up factors, so reject zero / negative values up front.
  TORCH_CHECK(up_x >= 1 && up_y >= 1 && down_x >= 1 && down_y >= 1,
              "upfirdn2d: up and down factors must be >= 1");

  int curDevice = -1;
  AT_CUDA_CHECK(cudaGetDevice(&curDevice));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // Nothing to compute; launching with a zero-sized grid would be an invalid
  // launch configuration.
  if (out.numel() == 0) {
    return out;
  }

  // Select a specialised tiled kernel when one matches. Later matches
  // deliberately override earlier ones so that the tightest filter size wins.
  const bool up1 = p.up_x == 1 && p.up_y == 1;
  const bool up2 = p.up_x == 2 && p.up_y == 2;
  const bool down1 = p.down_x == 1 && p.down_y == 1;
  const bool down2 = p.down_x == 2 && p.down_y == 2;
  const auto kernel_fits = [&p](int n) {
    return p.kernel_h <= n && p.kernel_w <= n;
  };

  int mode = -1;
  int tile_out_h = -1;
  int tile_out_w = -1;

  if (up1 && down1 && kernel_fits(4)) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }
  if (up1 && down1 && kernel_fits(3)) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }
  if (up2 && down1 && kernel_fits(4)) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }
  if (up2 && down1 && kernel_fits(2)) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }
  if (up1 && down2 && kernel_fits(4)) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }
  if (up1 && down2 && kernel_fits(2)) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  // Each block iterates over loop_major major indices so that grid.z stays
  // at or below 16384.
  p.loop_major = (p.major_dim - 1) / 16384 + 1;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled kernels: one block per (row tile, minor index, column tile).
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic kernel: 4 x 32 threads, each covering up to loop_x columns.
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    case 6:
      // Mode 6 is selected for filters no larger than 2x2, so it uses the 2x2
      // instantiation (previously the 4x4 one, which multiplied extra input
      // samples by zero-filled taps).
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);
      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
      break;
    }
  });

  // Kernel launches do not return a status; surface launch-configuration
  // errors here instead of at some later, unrelated CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return out;
}
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
import torch
from torch.autograd import Function
from torch.nn import functional as F
# Obtain the compiled extension: first try the pre-built module that lives
# next to this file.
try:
from . import upfirdn2d_ext
except ImportError:
# Fallback: compile the extension just-in-time from the C++/CUDA sources,
# but only when the BASICSR_JIT environment variable is exactly 'True'.
# NOTE(review): indentation was lost in this listing; presumably everything
# down to the closing parenthesis belongs to the except branch. If so, a
# failed import with BASICSR_JIT unset leaves `upfirdn2d_ext` unbound, and
# the autograd classes in this file (which call it) fail with NameError when
# first used -- confirm.
import os
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
upfirdn2d_ext = load(
'upfirdn2d',
sources=[
os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
],
)
class UpFirDn2dBackward(Function):
    """Autograd function producing the input gradient of ``UpFirDn2d``.

    Its own ``backward`` re-applies the forward filtering, which makes the
    gradient itself differentiable.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
        """Filter ``grad_output`` with ``grad_kernel``, swapping up and down factors.

        Returns the gradient viewed with the four sizes stored in ``in_size``.
        """
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # The extension expects a (major, height, width, minor) layout.
        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Up- and down-sampling factors trade places relative to the forward op.
        grad_input = upfirdn2d_ext.upfirdn2d(
            flat_grad, grad_kernel,
            down_x, down_y,
            up_x, up_y,
            g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        # Keep everything ``backward`` needs to replay the forward filtering.
        ctx.save_for_backward(kernel)
        pad_x0, pad_x1, pad_y0, pad_y1 = pad
        ctx.up_x, ctx.up_y = up_x, up_y
        ctx.down_x, ctx.down_y = down_x, down_y
        ctx.pad_x0, ctx.pad_x1 = pad_x0, pad_x1
        ctx.pad_y0, ctx.pad_y1 = pad_y0, pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Apply the forward filtering to ``gradgrad_input``.

        Only the first forward argument receives a gradient; the remaining
        eight get ``None``.
        """
        (kernel,) = ctx.saved_tensors
        flat = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_ext.upfirdn2d(
            flat, kernel,
            ctx.up_x, ctx.up_y,
            ctx.down_x, ctx.down_y,
            ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1)
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])

        return (gradgrad_out,) + (None,) * 8
class UpFirDn2d(Function):
    """Autograd wrapper around the compiled ``upfirdn2d`` operator."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        """Run the extension on ``input`` and record what ``backward`` needs.

        ``input`` is unpacked as (batch, channel, height, width) and ``kernel``
        as (rows, columns). The result is viewed as (-1, channel, out_h, out_w).
        """
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        _batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # Collapse batch and channel into the extension's leading axis.
        input = input.reshape(-1, in_h, in_w, 1)

        # The gradient pass filters with the kernel flipped along both axes.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

        ctx.out_size = (out_h, out_w)
        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding handed to the gradient pass.
        ctx.g_pad = (
            kernel_w - pad_x0 - 1,
            in_w * up_x - out_w * down_x + pad_x0 - up_x + 1,
            kernel_h - pad_y0 - 1,
            in_h * up_y - out_h * down_y + pad_y0 - up_y + 1,
        )

        out = upfirdn2d_ext.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
        return out.view(-1, channel, out_h, out_w)

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate to ``UpFirDn2dBackward``; only ``input`` receives a gradient."""
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output, kernel, grad_kernel,
            ctx.up, ctx.down, ctx.pad, ctx.g_pad,
            ctx.in_size, ctx.out_size)

        return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, filter and downsample ``input`` with ``kernel``.

    The same ``up`` / ``down`` factor and the same ``pad`` pair are applied to
    both spatial axes. CPU tensors go through the pure-PyTorch implementation;
    every other device uses the compiled extension.
    """
    pad_before, pad_after = pad[0], pad[1]

    if input.device.type == 'cpu':
        return upfirdn2d_native(input, kernel, up, up, down, down, pad_before, pad_after, pad_before, pad_after)

    return UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad_before, pad_after, pad_before, pad_after))
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Pure-PyTorch implementation of upfirdn2d.

    ``input`` is unpacked as (batch, channel, height, width) and ``kernel`` as
    (rows, columns). Negative padding values crop instead of pad. The result is
    viewed as (-1, channel, out_h, out_w).
    """
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by inserting zeros after every sample.
    x = input.view(-1, in_h, 1, in_w, 1, minor)
    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    up_h, up_w = in_h * up_y, in_w * up_x
    x = x.view(-1, up_h, up_w, minor)

    # Positive padding adds zeros; negative padding crops.
    x = F.pad(x, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    x = x[:, crop_top:x.shape[1] - crop_bottom, crop_left:x.shape[2] - crop_right, :]

    padded_h = up_h + pad_y0 + pad_y1
    padded_w = up_w + pad_x0 + pad_x1

    # Filter every slice as a single-channel image; the kernel is flipped
    # along both axes before being handed to conv2d.
    x = x.permute(0, 3, 1, 2)
    x = x.reshape([-1, 1, padded_h, padded_w])
    weight = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    x = F.conv2d(x, weight)
    x = x.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)

    # Downsample by striding.
    x = x.permute(0, 2, 3, 1)
    x = x[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h) // down_y + 1
    out_w = (padded_w - kernel_w) // down_x + 1
    return x.view(-1, channel, out_h, out_w)
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import sys
import time
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
from utils.misc import gpu_is_available
version_file = './basicsr/version.py'
def readme():
    """Return the text of ``README.md`` in the current directory, read as UTF-8."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
def get_git_hash():
    """Return the output of ``git rev-parse HEAD`` as a stripped ASCII string.

    Returns ``'unknown'`` when launching git raises ``OSError``.
    """

    def _minimal_ext_cmd(cmd):
        # Run the command in a minimal environment: only SYSTEMROOT, PATH and
        # HOME are carried over, and the locale is forced to 'C'.
        inherited = ('SYSTEMROOT', 'PATH', 'HOME')
        env = {key: os.environ.get(key) for key in inherited if os.environ.get(key) is not None}
        # LANGUAGE is used on win32
        env.update({'LANGUAGE': 'C', 'LANG': 'C', 'LC_ALL': 'C'})
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
        stdout, _ = process.communicate()
        return stdout

    try:
        raw = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        return raw.strip().decode('ascii')
    except OSError:
        return 'unknown'
def get_hash():
    """Return a short identifier for the current source revision.

    Inside a git checkout this is the first seven characters of the commit
    hash. Otherwise, if the generated version file exists, it is the part of
    ``__version__`` after the last ``'+'``. Failing both, ``'unknown'``.

    Raises:
        ImportError: the version file exists but ``version`` cannot be imported.
    """
    if os.path.exists('.git'):
        return get_git_hash()[:7]

    if not os.path.exists(version_file):
        return 'unknown'

    try:
        from version import __version__
    except ImportError:
        raise ImportError('Unable to get git version')
    return __version__.split('+')[-1]
def write_version_py():
    """Generate the version module at ``version_file``.

    The short version string is read from ``./basicsr/VERSION``. The generated
    file records the generation time, the version, the revision hash and a
    ``version_info`` tuple in which non-numeric components are quoted.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('./basicsr/VERSION', 'r') as version_src:
        short_version = version_src.read().strip()

    components = []
    for part in short_version.split('.'):
        # Numeric components stay bare; anything else becomes a quoted string.
        components.append(part if part.isdigit() else f'"{part}"')
    version_info = ', '.join(components)

    rendered = content.format(time.asctime(), short_version, sha, version_info)
    with open(version_file, 'w') as target:
        target.write(rendered)
def get_version():
    """Return ``__version__`` as defined by the generated version file.

    The file is executed in a dedicated namespace dictionary. Reading the
    result back through ``locals()`` after a bare ``exec()`` is not reliable:
    on Python 3.13+ names bound by ``exec`` inside a function no longer show
    up in a later ``locals()`` call, so the lookup raised ``KeyError``.

    Raises:
        KeyError: the version file does not define ``__version__``.
    """
    namespace = {}
    with open(version_file, 'r') as f:
        # NOTE: this executes the file's contents; version_file is generated
        # by this script, so it is treated as trusted input.
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Build the extension description for one compiled op.

    Args:
        name (str): Extension name; the module is built as ``{module}.{name}``.
        module (str): Dotted package path; also the directory (dots replaced
            by path separators) that the source paths are relative to.
        sources (list[str]): Sources that are always compiled.
        sources_cuda (list[str], optional): Extra sources compiled only for
            CUDA builds. Default: None.

    Returns:
        A ``CUDAExtension`` when a GPU is available or ``FORCE_CUDA=1`` is set,
        otherwise a ``CppExtension``.
    """
    # Work on a copy so the caller's list is never modified.
    all_sources = list(sources)
    cuda_sources = list(sources_cuda) if sources_cuda is not None else []

    define_macros = []
    extra_compile_args = {'cxx': []}

    # `gpu_is_available` used to be tested as an object, which is always true
    # for a function. Call it when it is callable; otherwise use its value.
    has_gpu = gpu_is_available() if callable(gpu_is_available) else bool(gpu_is_available)

    if has_gpu or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        all_sources += cuda_sources
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in all_sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
def get_requirements(filename='requirements.txt'):
    """Return the lines of ``filename`` with newline characters removed.

    The path is resolved relative to the current directory.
    """
    requirements_path = os.path.join('.', filename)
    with open(requirements_path, 'r') as req_file:
        return [entry.replace('\n', '') for entry in req_file]
if __name__ == '__main__':
    # Compiled ops are built only when '--cuda_ext' is on the command line.
    # The flag is removed afterwards so that setuptools never sees it.
    ext_modules = []
    if '--cuda_ext' in sys.argv:
        # (name, module, sources, sources_cuda) for every compiled op.
        ext_specs = (
            ('deform_conv_ext', 'ops.dcn', ['src/deform_conv_ext.cpp'],
             ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            ('fused_act_ext', 'ops.fused_act', ['src/fused_bias_act.cpp'],
             ['src/fused_bias_act_kernel.cu']),
            ('upfirdn2d_ext', 'ops.upfirdn2d', ['src/upfirdn2d.cpp'],
             ['src/upfirdn2d_kernel.cu']),
        )
        for ext_name, ext_module, ext_sources, ext_sources_cuda in ext_specs:
            ext_modules.append(
                make_cuda_ext(
                    name=ext_name,
                    module=ext_module,
                    sources=ext_sources,
                    sources_cuda=ext_sources_cuda))
        sys.argv.remove('--cuda_ext')

    # The version file must exist before get_version() reads it.
    write_version_py()

    setup_kwargs = dict(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False,
    )
    setup(**setup_kwargs)
import argparse
import datetime
import logging
import math
import copy
import random
import time
import torch
from os import path as osp
from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
import warnings
# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
warnings.filterwarnings("ignore", category=UserWarning)
def parse_options(root_path, is_train=True):
    """Parse the command line and the option file into the option dict.

    Besides loading the options this function
      * initialises distributed training unless the launcher is ``'none'``,
      * stores ``dist``, ``rank`` and ``world_size`` in the options,
      * picks a random ``manual_seed`` in [1, 10000] when none is configured,
      * seeds the random generators with ``manual_seed + rank``.

    Returns:
        The option dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = parse(args.opt, root_path, is_train=is_train)

    # distributed settings
    launcher = args.launcher
    opt['dist'] = launcher != 'none'
    if not opt['dist']:
        print('Disable distributed.', flush=True)
    elif launcher == 'slurm' and 'dist_params' in opt:
        init_dist(launcher, **opt['dist_params'])
    else:
        init_dist(launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: each rank is seeded differently
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt
def init_loggers(opt):
    """Create the text logger and, if enabled, the tensorboard logger.

    Environment information and the full option dump are written to the log
    first. A wandb logger is initialised when a wandb project is configured;
    that requires the tensorboard logger to be switched on.

    Returns:
        tuple: ``(logger, tb_logger)``; ``tb_logger`` is None when disabled.
    """
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    logger_opt = opt['logger']

    # initialize wandb logger before tensorboard logger to allow proper sync:
    wandb_opt = logger_opt.get('wandb')
    if wandb_opt is not None and wandb_opt.get('project') is not None:
        assert logger_opt.get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)

    if not logger_opt.get('use_tb_logger'):
        return logger, None
    return logger, init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
def create_train_val_dataloader(opt, logger):
    """Build the training and validation dataloaders described in ``opt``.

    Args:
        opt (dict): Options; ``opt['datasets']`` maps a phase name
            (``'train'`` or ``'val'``) to that dataset's options.
        logger: Logger used to report dataset statistics.

    Returns:
        tuple: ``(train_loader, train_sampler, val_loader, total_epochs,
        total_iters)``; ``val_loader`` is None without a ``'val'`` phase.

    Raises:
        ValueError: a phase other than ``'train'`` / ``'val'`` is configured,
            or no ``'train'`` phase is configured at all (this used to end in
            an UnboundLocalError at the return statement).
    """
    train_loader, val_loader = None, None
    train_sampler, total_epochs, total_iters = None, None, None

    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            train_loader = build_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            # An epoch covers the enlarged dataset once across all processes.
            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info('Training statistics:'
                        f'\n\tNumber of train images: {len(train_set)}'
                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')

        elif phase == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(
                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    if train_loader is None:
        raise ValueError("No 'train' dataset is configured in opt['datasets'].")

    return train_loader, train_sampler, val_loader, total_epochs, total_iters
# Full training driver: parses the options, optionally loads a resume state,
# sets up directories, loggers, dataloaders and the model, then runs the
# iteration loop with periodic logging, checkpointing and validation.
# `root_path` is forwarded unchanged to parse_options.
def train_pipeline(root_path):
# parse options, set distributed setting, set random seed
opt = parse_options(root_path, is_train=True)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# load resume states if necessary
# (the checkpoint's storages are mapped onto the current CUDA device)
if opt['path'].get('resume_state'):
device_id = torch.cuda.current_device()
resume_state = torch.load(
opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
else:
resume_state = None
# mkdir for experiments and logger
# (only for a fresh run; the tensorboard directory is handled by rank 0 only)
if resume_state is None:
make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
mkdir_and_rename(osp.join('tb_logger', opt['name']))
# initialize loggers
logger, tb_logger = init_loggers(opt)
# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger)
train_loader, train_sampler, val_loader, total_epochs, total_iters = result
# create model
if resume_state: # resume training
check_resume(opt, resume_state['iter'])
model = build_model(opt)
model.resume_training(resume_state) # handle optimizers and schedulers
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
start_epoch = resume_state['epoch']
current_iter = resume_state['iter']
else:
model = build_model(opt)
start_epoch = 0
current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)
# dataloader prefetcher
# (a missing prefetch_mode is treated like 'cpu'; 'cuda' requires pin_memory)
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.")
# training
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
# data_time / iter_time hold a start timestamp while a step is running and
# are overwritten with the elapsed duration just before being logged.
data_time, iter_time = time.time(), time.time()
start_time = time.time()
for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch)
prefetcher.reset()
train_data = prefetcher.next()
while train_data is not None:
data_time = time.time() - data_time
current_iter += 1
# stop as soon as the configured number of iterations has been reached
if current_iter > total_iters:
break
# update learning rate
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
# training
model.feed_data(train_data)
model.optimize_parameters(current_iter)
iter_time = time.time() - iter_time
# log
if current_iter % opt['logger']['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': model.get_current_learning_rate()})
log_vars.update({'time': iter_time, 'data_time': data_time})
log_vars.update(model.get_current_log())
msg_logger(log_vars)
# save models and training states
if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
model.save(epoch, current_iter)
# validation
if opt.get('val') is not None and opt['datasets'].get('val') is not None \
and (current_iter % opt['val']['val_freq'] == 0):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
# restart both timers before fetching the next batch
data_time = time.time()
iter_time = time.time()
train_data = prefetcher.next()
# end of iter
# end of epoch
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}')
logger.info('Save the latest model.')
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
# final validation pass, when a validation set is configured
if opt.get('val') is not None and opt['datasets'].get('val'):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
if tb_logger:
tb_logger.close()
# Script entry point: the root path is the parent of the directory that
# contains this file.
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
from .file_client import FileClient
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
__all__ = [
# file_client.py
'FileClient',
# img_util.py
'img2tensor',
'tensor2img',
'imfrombytes',
'imwrite',
'crop_border',
# logger.py
'MessageLogger',
'init_tb_logger',
'init_wandb_logger',
'get_root_logger',
'get_env_info',
# misc.py
'set_random_seed',
'get_time_str',
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'check_resume',
'sizeof_fmt'
]
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialise distributed training for the given launcher.

    The multiprocessing start method is set to ``'spawn'`` if none has been
    chosen yet.

    Args:
        launcher (str): ``'pytorch'`` or ``'slurm'``.
        backend (str): Backend of torch.distributed. Default: 'nccl'.
        **kwargs: Forwarded to the launcher-specific initialiser.

    Raises:
        ValueError: ``launcher`` is neither ``'pytorch'`` nor ``'slurm'``.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
    """Initialise the process group for the PyTorch launcher.

    The process is bound to CUDA device ``RANK % device_count`` (``RANK`` is
    read from the environment) before the process group is created.
    """
    global_rank = int(os.environ['RANK'])
    device_count = torch.cuda.device_count()
    torch.cuda.set_device(global_rank % device_count)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    The master port is taken from ``port`` when given, otherwise from an
    existing ``MASTER_PORT`` environment variable, otherwise ``29500``.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']

    num_gpus = torch.cuda.device_count()
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)

    # The first host name of the node list serves as the master address.
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')

    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'

    os.environ.update({
        'MASTER_ADDR': addr,
        'WORLD_SIZE': str(ntasks),
        'LOCAL_RANK': str(local_rank),
        'RANK': str(proc_id),
    })
    dist.init_process_group(backend=backend)
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or no
    process group has been initialised.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
def master_only(func):
    """Decorator that runs ``func`` only on the rank-0 process.

    On every other rank the wrapped call does nothing and returns None.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank = get_dist_info()[0]
        if current_rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse
from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    session = requests.Session()
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    # When a download-warning cookie is present, repeat the request with the
    # cookie's value as confirmation token.
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size: a ranged request reports the total in Content-Range
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    print(response_file_size)
    size_headers = response_file_size.headers
    file_size = int(size_headers['Content-Range'].split('/')[1]) if 'Content-Range' in size_headers else None

    save_response_content(response, save_path, file_size)
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    ``'download_warning'``, or None when there is no such cookie."""
    matching_values = (value for key, value in response.cookies.items() if key.startswith('download_warning'))
    return next(matching_values, None)
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: Object whose ``iter_content(chunk_size)`` yields byte chunks.
        destination (str): Path of the file to write.
        file_size (int, optional): Total size in bytes. When given, a progress
            bar counting chunks is shown. Default: None.
        chunk_size (int): Chunk size requested from the response.
            Default: 32768.
    """
    pbar = None
    readable_file_size = None
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                # Count the bytes actually received: a chunk (typically the
                # last one) may be shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    finally:
        # Close the bar even when the download or the write fails.
        if pbar is not None:
            pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path of the file at ``url``, downloading it if needed.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which to store the file. Should be a
            full path. If None, ``<torch hub dir>/checkpoints`` is used.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name for the stored file. If None, the last path
            component of the url is used. Default: None.

    Returns:
        str: Absolute path to the (possibly already cached) file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    url_basename = os.path.basename(urlparse(url).path)
    target_name = url_basename if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    # Only hit the network when the file is not present yet.
    if os.path.exists(cached_file):
        return cached_file
    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
\ No newline at end of file
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class for storage backends.

    Concrete backends must implement two apis: ``get()``, which reads a file
    as a byte stream, and ``get_text()``, which reads a file as text.
    """

    @abstractmethod
    def get(self, filepath):
        """Read ``filepath`` as bytes (to be implemented by subclasses)."""

    @abstractmethod
    def get_text(self, filepath):
        """Read ``filepath`` as text (to be implemented by subclasses)."""
class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        # Optionally extend sys.path so the ``mc`` module can be found.
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch the value stored under ``filepath`` and return the converted buffer."""
        key = str(filepath)
        import mc
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        """Return the whole content of ``filepath`` as bytes."""
        with open(str(filepath), 'rb') as stream:
            content = stream.read()
        return content

    def get_text(self, filepath):
        """Return the whole content of ``filepath`` as a string (mode 'r')."""
        with open(str(filepath), 'r') as stream:
            content = stream.read()
        return content
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | Path | list | tuple): Lmdb database path(s). A single
            path of any type is converted with ``str()``; a list/tuple is
            converted element-wise.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.
        **kwargs: Forwarded unchanged to ``lmdb.open``.

    Attributes:
        db_paths (list): Lmdb database paths.
        _client (dict): Maps each client key to its opened lmdb env.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        # Previously only ``list`` and ``str`` were recognised; any other type
        # (e.g. pathlib.Path or a tuple) left ``self.db_paths`` undefined and
        # failed later with an AttributeError.
        if isinstance(db_paths, (list, tuple)):
            self.db_paths = [str(v) for v in db_paths]
        else:
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.

        Returns:
            The value returned by the transaction's ``get`` for the
            ascii-encoded key.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
class FileClient(object):
    """A general file client to access files in different backend.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary file.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Registry mapping backend names to their implementing classes.
    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        backend_cls = self._backends.get(backend) if backend in self._backends else None
        if backend_cls is None:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        # All remaining keyword arguments configure the backend itself.
        self.client = backend_cls(**kwargs)

    def get(self, filepath, client_key='default'):
        """Read ``filepath`` as bytes through the configured backend.

        ``client_key`` is used only for lmdb, where different fileclients
        have different lmdb environments; other backends ignore it.
        """
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Read ``filepath`` as text through the configured backend."""
        return self.client.get_text(filepath)
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image(s) of layout (h, w, c) to tensor(s) of layout (c, h, w).

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb (applied only to
            3-channel images).
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. A list input yields a list;
            a single ndarray yields a single tensor.
    """

    def _convert_one(array):
        needs_swap = array.shape[2] == 3 and bgr2rgb
        if needs_swap:
            # cvtColor is given float32 rather than float64 data.
            if array.dtype == 'float64':
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(array.transpose(2, 0, 1))
        return chw.float() if float32 else chw

    if not isinstance(imgs, list):
        return _convert_one(imgs)
    return [_convert_one(single) for single in imgs]
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].
    The input tensors are left unmodified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
        ndarray of shape (H x W). With ``rgb2bgr`` the channel order is BGR.
        A single result is returned directly instead of a one-element list.

    Raises:
        TypeError: If the input is neither a tensor nor a list of tensors, or
            a tensor has a dimension other than 2, 3 or 4 after the leading
            ``squeeze(0)``.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Out-of-place clamp: for a float32 CPU tensor the chain
        # squeeze/float/detach/cpu shares storage with the caller's tensor,
        # so an in-place clamp_ would silently modify the input.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Tile the batch into a roughly square grid image.
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.

    It now only supports torch tensor with shape (1, c, h, w).
    The input tensor is left unmodified. Note that, unlike ``tensor2img``,
    the conversion to uint8 truncates instead of rounding.

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        ndarray: uint8 image of shape (h, w, c).
    """
    # Out-of-place clamp: squeeze()/detach() share storage with the caller's
    # tensor, so clamp_ would overwrite the values of the input.
    output = tensor.squeeze(0).detach().clamp(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output
def imfrombytes(content, flag='color', float32=False):
    """Decode an image from its encoded bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, the result is
            also divided by 255. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    decode_modes = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    raw_buffer = np.frombuffer(content, np.uint8)
    decoded = cv2.imdecode(raw_buffer, decode_modes[flag])
    if not float32:
        return decoded
    return decoded.astype(np.float32) / 255.
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write an image to ``file_path`` through ``cv2.imwrite``.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not (the return value of ``cv2.imwrite``).
    """
    if auto_mkdir:
        parent_dir = os.path.dirname(file_path)
        os.makedirs(os.path.abspath(parent_dir), exist_ok=True)
    written = cv2.imwrite(file_path, img, params)
    return written
def crop_border(imgs, crop_border):
    """Remove ``crop_border`` pixels from every side of the image(s).

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray] | ndarray: Cropped images, in the same container kind
            as the input. With ``crop_border == 0`` the input object itself
            is returned.
    """
    if crop_border == 0:
        return imgs

    # The same window is applied to the first two axes.
    window = slice(crop_border, -crop_border)
    if isinstance(imgs, list):
        return [single[window, window, ...] for single in imgs]
    return imgs[window, window, ...]
\ No newline at end of file
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path. Must end with '.lmdb'.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
        SystemExit: If ``lmdb_path`` already exists (``sys.exit(1)``).
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        # Reserve 10x the estimate as head room.
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    try:
        txn = env.begin(write=True)
        # The context manager guarantees the meta file is flushed and closed
        # even when reading/encoding an image fails half-way.
        with open(osp.join(lmdb_path, 'meta_info.txt'), 'w') as txt_file:
            for idx, (path, key) in enumerate(zip(img_path_list, keys)):
                pbar.update(1)
                pbar.set_description(f'Write {key}')
                key_byte = key.encode('ascii')
                if multiprocessing_read:
                    img_byte = dataset[key]
                    h, w, c = shapes[key]
                else:
                    _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
                    h, w, c = img_shape

                txn.put(key_byte, img_byte)
                # write meta information
                txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
                # Commit periodically (note: this also fires for idx == 0).
                if idx % batch == 0:
                    txn.commit()
                    txn = env.begin(write=True)
            txn.commit()
    finally:
        # Always release the progress bar and the lmdb environment.
        pbar.close()
        env.close()
    print('\nFinish writing lmdb.')
def read_img_worker(path, key, compress_level):
    """Read image worker.

    Reads the image at ``path`` unchanged and re-encodes it as PNG.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape as (h, w, c); c is 1 for 2D images.

    Raises:
        IOError: If the image cannot be read from ``path``.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising;
    # fail with the offending path rather than an AttributeError on ``ndim``.
    if img is None:
        raise IOError(f'Failed to read image: {path}')
    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))
class LmdbMaker():
    """Incremental writer for an lmdb folder plus its ``meta_info.txt``.

    Args:
        lmdb_path (str): Lmdb save path. Must end with '.lmdb' and must not
            exist yet.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        # plain configuration
        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.counter = 0  # number of put() calls so far

        # resources that stay open until close()
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')

    def put(self, img_byte, key, img_shape):
        """Store one encoded image under ``key`` and record its meta line."""
        self.counter += 1
        self.txn.put(key.encode('ascii'), img_byte)

        # write meta information
        height, width, channels = img_shape
        self.txt_file.write(f'{key}.png ({height},{width},{channels}) {self.compress_level}\n')

        # Commit and start a fresh transaction every ``batch`` items.
        if self.counter % self.batch != 0:
            return
        self.txn.commit()
        self.txn = self.env.begin(write=True)

    def close(self):
        """Commit the pending transaction and release env and meta file."""
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
import datetime
import logging
import time
from .dist_util import get_dist_info, master_only
initialized_logger = {}
class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval
                and 'use_tb_logger' (bool).
            train (dict): Contains 'total_iter' (int) for total iters.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        # Reference point for the ETA estimate.
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        The entries 'epoch', 'iter', 'lrs' (and 'time'/'data_time' when
        'time' is present) are popped from ``log_vars``; every remaining
        entry is printed and, if enabled, sent to tensorboard.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        pieces = [f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(']
        pieces.extend(f'{v:.3e},' for v in lrs)
        pieces.append(')] ')
        message = ''.join(pieces)

        # time and estimated time
        if 'time' in log_vars:
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            elapsed = time.time() - self.start_time
            seconds_per_iter = elapsed / (current_iter - self.start_iter + 1)
            remaining_iters = self.max_iters - current_iter - 1
            eta_str = str(datetime.timedelta(seconds=int(seconds_per_iter * remaining_iters)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for name, value in log_vars.items():
            message += f'{name}: {value:.4e} '
            # tensorboard logger
            if self.use_tb_logger:
                self.tb_logger.add_scalar(name, value, current_iter)
        self.logger.info(message)
@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard ``SummaryWriter`` writing to ``log_dir``.

    Args:
        log_dir (str): Directory for the tensorboard event files.

    Returns:
        SummaryWriter: The created writer.
    """
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)
@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log.

    Args:
        opt (dict): Config. Reads ``opt['name']`` and
            ``opt['logger']['wandb']`` ('project' and optional 'resume_id').
    """
    import wandb
    logger = logging.getLogger('basicsr')

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if not resume_id:
        # Fresh run: generate a new id and forbid resuming.
        wandb_id = wandb.util.generate_id()
        resume = 'never'
    else:
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')

    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added (opened in append mode so previous logs are kept).

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False

    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    else:
        # Apply log_level on rank 0 regardless of log_file. Previously the
        # level was only set when a log file was given, so without one the
        # logger inherited the root logger's WARNING level and INFO messages
        # were dropped.
        logger.setLevel(log_level)
        if log_file is not None:
            # add file handler ('a': keep the previous log)
            file_handler = logging.FileHandler(log_file, 'a')
            file_handler.setFormatter(logging.Formatter(format_str))
            file_handler.setLevel(log_level)
            logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
def get_env_info():
    """Build a banner followed by version information.

    Currently, only the software versions (BasicSR, PyTorch, TorchVision)
    are reported.

    Returns:
        str: The assembled message.
    """
    import torch
    import torchvision

    from basicsr.version import __version__
    banner = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
    version_block = ('\nVersion Information: '
                     f'\n\tBasicSR: {__version__}'
                     f'\n\tPyTorch: {torch.__version__}'
                     f'\n\tTorchVision: {torchvision.__version__}')
    return banner + version_block
\ No newline at end of file
import math
import numpy as np
import torch
def cubic(x):
    """Piecewise cubic kernel used by calculate_weights_indices.

    For |x| <= 1 the value is 1.5|x|^3 - 2.5|x|^2 + 1, for 1 < |x| <= 2 it is
    -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2, and 0 elsewhere.

    Args:
        x (Tensor): Input values.

    Returns:
        Tensor: Kernel values, same shape as ``x``.
    """
    magnitude = torch.abs(x)
    squared = magnitude**2
    cubed = magnitude**3

    # Masks selecting the two non-zero pieces of the kernel.
    near_mask = (magnitude <= 1).type_as(magnitude)
    far_mask = ((magnitude > 1) * (magnitude <= 2)).type_as(magnitude)

    near_part = (1.5 * cubed - 2.5 * squared + 1) * near_mask
    far_part = (-0.5 * cubed + 2.5 * squared - 4 * magnitude + 2) * far_mask
    return near_part + far_part
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name. Not used inside this function; the cubic
            kernel is always applied.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)`` where ``weights``
            and ``indices`` have one row per output position, and the two
            ints are the symmetric padding lengths at the start and the end.
    """
    widen_kernel = (scale < 1) and antialiasing
    if widen_kernel:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # Left-most pixel that can be involved in the computation.
    left = torch.floor(u - kernel_width / 2)

    # Maximum number of pixels that can be involved in the computation.
    # An extra pixel is OK here; if the corresponding weights are all zero,
    # it will be eliminated at the end of this function.
    p = math.ceil(kernel_width) + 2

    # Row k of ``indices`` holds the input pixels used for output pixel k.
    tap_offsets = torch.linspace(0, p - 1, p).view(1, p).expand(out_length, p)
    indices = left.view(out_length, 1).expand(out_length, p) + tap_offsets

    # Row k of ``weights`` holds the weights used for output pixel k.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if widen_kernel:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    row_totals = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / row_totals.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. Only the first and
    # the last column are considered.
    zero_counts = torch.sum((weights == 0), 0)
    if not math.isclose(zero_counts[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(zero_counts[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()

    # Padding needed on both ends, then shift indices into the padded frame.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor | ndarray: Resized image w/o round. A tensor input yields a
            tensor of shape (c, h, w); a numpy input yields an ndarray of
            shape (h, w, c).
    """
    numpy_type = type(img).__module__ == np.__name__
    if numpy_type:
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()

    in_c, in_h, in_w = img.size()
    out_h = math.ceil(in_h * scale)
    out_w = math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, pad_top, pad_bottom = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                          antialiasing)
    weights_w, indices_w, pad_left, pad_right = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                          antialiasing)

    # ---- process H dimension: symmetric copying ----
    padded_h = torch.FloatTensor(in_c, in_h + pad_top + pad_bottom, in_w)
    padded_h.narrow(1, pad_top, in_h).copy_(img)

    top_rows = img[:, :pad_top, :]
    reversed_order = torch.arange(top_rows.size(1) - 1, -1, -1).long()
    padded_h.narrow(1, 0, pad_top).copy_(top_rows.index_select(1, reversed_order))

    bottom_rows = img[:, -pad_bottom:, :]
    reversed_order = torch.arange(bottom_rows.size(1) - 1, -1, -1).long()
    padded_h.narrow(1, pad_top + in_h, pad_bottom).copy_(bottom_rows.index_select(1, reversed_order))

    resized_h = torch.FloatTensor(in_c, out_h, in_w)
    taps = weights_h.size(1)
    for row in range(out_h):
        first = int(indices_h[row][0])
        for channel in range(in_c):
            # (in_w x taps) matrix times the tap weights of this output row
            resized_h[channel, row, :] = padded_h[channel, first:first + taps, :].transpose(0, 1).mv(weights_h[row])

    # ---- process W dimension: symmetric copying ----
    padded_w = torch.FloatTensor(in_c, out_h, in_w + pad_left + pad_right)
    padded_w.narrow(2, pad_left, in_w).copy_(resized_h)

    left_cols = resized_h[:, :, :pad_left]
    reversed_order = torch.arange(left_cols.size(2) - 1, -1, -1).long()
    padded_w.narrow(2, 0, pad_left).copy_(left_cols.index_select(2, reversed_order))

    right_cols = resized_h[:, :, -pad_right:]
    reversed_order = torch.arange(right_cols.size(2) - 1, -1, -1).long()
    padded_w.narrow(2, pad_left + in_w, pad_right).copy_(right_cols.index_select(2, reversed_order))

    resized = torch.FloatTensor(in_c, out_h, out_w)
    taps = weights_w.size(1)
    for col in range(out_w):
        first = int(indices_w[col][0])
        for channel in range(in_c):
            resized[channel, :, col] = padded_w[channel, :, first:first + taps].mv(weights_w[col])

    if numpy_type:
        resized = resized.numpy().transpose(1, 2, 0)
    return resized
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    src_dtype = img.dtype
    normalized = _convert_input_type_range(img)
    if not y_only:
        transform = [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]
        converted = np.matmul(normalized, transform) + [16, 128, 128]
    else:
        converted = np.dot(normalized, [65.481, 128.553, 24.966]) + 16.0
    return _convert_output_type_range(converted, src_dtype)
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    src_dtype = img.dtype
    normalized = _convert_input_type_range(img)
    if not y_only:
        transform = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]
        converted = np.matmul(normalized, transform) + [16, 128, 128]
    else:
        converted = np.dot(normalized, [24.966, 128.553, 65.481]) + 16.0
    return _convert_output_type_range(converted, src_dtype)
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    src_dtype = img.dtype
    # Work in the [0, 255] range.
    scaled = _convert_input_type_range(img) * 255
    transform = [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]]
    offset = [-222.921, 135.576, -276.836]
    converted = np.matmul(scaled, transform) * 255.0 + offset
    return _convert_output_type_range(converted, src_dtype)
def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
            and range as input image.
    """
    src_dtype = img.dtype
    # Work in the [0, 255] range.
    scaled = _convert_input_type_range(img) * 255
    transform = [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], [0, -0.00318811, 0.00625893]]
    offset = [-276.836, 135.576, -222.921]
    converted = np.matmul(scaled, transform) * 255.0 + offset
    return _convert_output_type_range(converted, src_dtype)
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
convertion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace convertion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp
from .dist_util import master_only
from .logger import get_root_logger
# True when the installed torch is at least 1.12.0; the device helpers below
# only touch torch.backends.mps when this flag is set.
# Only the leading "major.minor.patch" is parsed, so version strings with
# extra suffixes (e.g. "2.1.0.dev20230101+cu118") no longer raise IndexError
# at import time. An unparsable version is treated as "not high".
_TORCH_VERSION_MATCH = re.match(r"([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
IS_HIGH_VERSION = (_TORCH_VERSION_MATCH is not None
                   and [int(part) for part in _TORCH_VERSION_MATCH.groups()] >= [1, 12, 0])
def gpu_is_available():
    """Report whether an accelerator backend (MPS or CUDA with cuDNN) is usable.

    Returns:
        bool: True if MPS is available (only queried when IS_HIGH_VERSION is
            set), or if both CUDA and cuDNN are available; False otherwise.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())
def get_device(gpu_id=None):
    """Pick the torch device to run on.

    Preference order: MPS (only queried when IS_HIGH_VERSION is set), then
    CUDA (requires cuDNN as well), then CPU.

    Args:
        gpu_id (int, optional): Device index appended to the device string as
            ``:<gpu_id>`` for MPS/CUDA. Default: None (no index).

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If gpu_id is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + index_suffix)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + index_suffix)
    # The CPU fallback never carries a device index.
    return torch.device('cpu')
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed (int): Seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_time_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def mkdir_and_rename(path):
    """Create ``path``, archiving whatever already lives there.

    An existing ``path`` is first renamed to ``<path>_archived_<timestamp>``,
    then a fresh directory is created at ``path``.

    Args:
        path (str): Folder path.
    """
    path_taken = osp.exists(path)
    if path_taken:
        archived_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {archived_name}', flush=True)
        os.rename(path, archived_name)
    os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
    """Create the directories an experiment needs.

    The root directory (``experiments_root`` when training, ``results_root``
    otherwise) goes through mkdir_and_rename; every other entry of
    ``opt['path']`` is created as a directory unless its key contains
    'strict_load', 'pretrain_network' or 'resume'.

    Args:
        opt (dict): Options; reads ``opt['path']`` and ``opt['is_train']``.
    """
    # Work on a copy so popping the root key does not alter opt['path'].
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    non_dir_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in remaining.items():
        if any(marker in key for marker in non_dir_markers):
            continue
        os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with '.' are skipped. Sub-directories (hidden or
    not) are descended into only when ``recursive`` is True.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None (every file is yielded).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, yield paths that include
            dir_path; otherwise paths are relative to dir_path.
            Default: False.

    Returns:
        A generator over the matching file paths.

    Raises:
        TypeError: If suffix is neither None, a string nor a tuple.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the directory handle even if the
        # consumer stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    return_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only real directories are descended into; a hidden
                    # *file* used to reach os.scandir here and raise
                    # NotADirectoryError.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When ``opt['path']['resume_state']`` is set, every ``pretrain_network_*``
    path is overwritten with the checkpoint of the resumed iteration, except
    for networks listed in ``opt['path']['ignore_resume_networks']``.

    Args:
        opt (dict): Options. ``opt['path']`` is modified in place.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    # get all the networks
    networks = [key for key in opt.keys() if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{net}') is not None for net in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')

    # set pretrained model paths
    for net in networks:
        name = f'pretrain_{net}'
        basename = net.replace('network_', '')
        ignored = path_opt.get('ignore_resume_networks')
        if ignored is None or basename not in ignored:
            path_opt[name] = osp.join(path_opt['models'], f'net_{basename}_{resume_iter}.pth')
            logger.info(f"Set {name} to {path_opt[name]}")
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    The size is repeatedly divided by 1024 until its magnitude drops below
    1024 or the largest prefix ('Y') is reached.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    last = len(prefixes) - 1
    step = 0
    while step < last and abs(size) >= 1024.0:
        size /= 1024.0
        step += 1
    return f'{size:3.1f} {prefixes[step]}{suffix}'
import yaml
import time
from collections import OrderedDict
from os import path as osp
from basicsr.utils.misc import get_time_str
def ordered_yaml():
    """Support OrderedDict for yaml.

    Registers an OrderedDict representer on the Dumper and makes the Loader
    build mappings as OrderedDict. The C-accelerated classes are used when
    they can be imported.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    # Dump OrderedDict like a plain mapping, keeping its item order.
    Dumper.add_representer(OrderedDict, lambda dumper, data: dumper.represent_dict(data.items()))
    # Load every mapping node as an OrderedDict.
    Loader.add_constructor(mapping_tag, lambda loader, node: OrderedDict(loader.construct_pairs(node)))
    return Loader, Dumper
def parse(opt_path, root_path, is_train=True):
    """Parse option file.

    Besides loading the YAML file, this fills in derived entries: the run
    name, per-dataset ``phase``/``scale``, expanded user paths, and the
    experiment/result directory layout under ``root_path``.

    Args:
        opt_path (str): Option file path.
        root_path (str): Root folder under which the 'experiments' (training)
            or 'results' (testing) directories are placed.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.
    """
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        # NOTE(review): this is PyYAML's full Loader/CLoader rather than the
        # safe loader - presumably fine for trusted option files; confirm
        # before parsing untrusted input.
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train

    # When resuming, reuse the experiment name encoded in the state path
    # (third component from the end, split on "/"); otherwise prefix the
    # configured name with a timestamp.
    resume_state_path = opt['path'].get('resume_state', None)
    if resume_state_path:
        opt['name'] = resume_state_path.split("/")[-3]
    else:
        opt['name'] = f"{get_time_str()}_{opt['name']}"

    # datasets
    for phase_key, dataset in opt['datasets'].items():
        # for several datasets, e.g., test_1, test_2
        dataset['phase'] = phase_key.split('_')[0]
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        for root_key in ('dataroot_gt', 'dataroot_lq'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])

    # paths
    path_opt = opt['path']
    for key, val in path_opt.items():
        if val is None:
            continue
        if 'resume_state' in key or 'pretrain_network' in key:
            path_opt[key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        path_opt['experiments_root'] = experiments_root
        path_opt['models'] = osp.join(experiments_root, 'models')
        path_opt['training_states'] = osp.join(experiments_root, 'training_states')
        path_opt['log'] = experiments_root
        path_opt['visualization'] = osp.join(experiments_root, 'visualization')
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        path_opt['results_root'] = results_root
        path_opt['log'] = results_root
        path_opt['visualization'] = osp.join(results_root, 'visualization')

    return opt
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively between ``key:[`` and ``]`` lines,
    one indentation level deeper.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    pad = ' ' * (indent_level * 2)
    pieces = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            pieces.append(pad + key + ':[')
            pieces.append(dict2str(value, indent_level + 1))
            pieces.append(pad + ']\n')
        else:
            pieces.append(pad + key + ': ' + str(value) + '\n')
    return ''.join(pieces)
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from torch.nn import functional as F
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import get_device
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
            input images into tiles, and then process each of them. Finally, they will be merged into one image.
            0 denotes for do not use tile. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device, optional): Device to run on. When None, it is chosen by get_device(gpu_id).
            Default: None.
        gpu_id (int, optional): Device index forwarded to get_device when `device` is None. Default: None.
    """

    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        self.device = get_device(gpu_id) if device is None else device
        # if the model_path starts with https, it will first download models to the folder: weights/realesrgan
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
        # NOTE(review): torch.load unpickles the checkpoint; only load files from trusted sources.
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible.

        The HWC numpy image is turned into a 1xCxHxW tensor on `self.device` and stored in `self.img`.
        All padding is reflect padding added on the right and bottom sides only.

        Args:
            img (ndarray): Image laid out as (H, W, C).
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        """Run the model on the whole pre-processed image and store the result in `self.output`."""
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher

        Raises:
            RuntimeError: Re-raised from the model when inference on a tile fails.
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    # Re-raise: continuing would either hit an unbound `output_tile` (first tile)
                    # or silently paste the previous tile's result into this tile's slot.
                    raise

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Crop the (scaled) mod pad and pre pad off `self.output` and return it."""
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upsample one image.

        Args:
            img (ndarray): Input image: 2-D (gray), or 3-D with 3 or 4 channels. Colour inputs are converted
                with cv2.COLOR_BGR2RGB, i.e. they are treated as BGR(A). A maximum value above 256 is taken
                to mean a 16-bit image.
            outscale (float, optional): Final scale relative to the input size. When it differs from
                `self.scale`, the network output is resized with Lanczos interpolation. Default: None.
            alpha_upsampler (str): 'realesrgan' runs the alpha channel through the network as well; any other
                value resizes it with cv2.resize. Default: 'realesrgan'.

        Returns:
            tuple: (output image as uint8 or uint16 ndarray, image mode 'L' | 'RGB' | 'RGBA').

        Raises:
            RuntimeError: Re-raised when inference fails before an output image was produced.
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        output_img = None
        try:
            with torch.no_grad():
                self.pre_process(img)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_img_t = self.post_process()
                output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
                if img_mode == 'L':
                    output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
            del output_img_t
            torch.cuda.empty_cache()
        except RuntimeError as error:
            print(f"Failed inference for RealESRGAN: {error}")
            if output_img is None:
                # Nothing was produced; surface the real failure instead of the
                # UnboundLocalError the code below would otherwise raise.
                raise

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
class PrefetchReader(threading.Thread):
    """Prefetch images.

    A background thread reads the listed images into a bounded queue; the
    object itself is the iterator that hands them out. A ``None`` item in the
    queue marks the end of the stream.

    Args:
        img_list (list[str]): A image list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        """Thread body: read every image in order, then enqueue the end marker."""
        # NOTE(review): if cv2.imread yields None for an unreadable file, that
        # item is indistinguishable from the end marker and iteration stops
        # early - confirm whether that is intended.
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))
        self.que.put(None)

    def __next__(self):
        item = self.que.get()
        if item is not None:
            return item
        raise StopIteration

    def __iter__(self):
        return self
class IOConsumer(threading.Thread):
    """Worker thread that writes images taken from a queue to disk.

    Each message is a mapping with an 'output' image and a 'save_path'; the
    string 'quit' stops the worker.

    Args:
        opt: Options object; stored but not used by this class itself.
        que (queue.Queue): Queue the messages are read from.
        qid (int): Identifier of this worker, used in the final log line.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        """Thread body: save images until the 'quit' message arrives."""
        msg = self._queue.get()
        while not (isinstance(msg, str) and msg == 'quit'):
            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
            msg = self._queue.get()
        print(f'IO worker {self.qid} is done.')
\ No newline at end of file
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        """Store `obj` under `name`; registering the same name twice is an error."""
        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.

        Returns:
            The registered object when `obj` is given, otherwise a decorator
            that registers and returns its argument.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)
        # Hand the object back so a bare `@REGISTRY.register` (no parentheses)
        # does not rebind the decorated name to None.
        return obj

    def get(self, name):
        """Return the object registered under `name`.

        Raises:
            KeyError: If nothing is registered under `name`.
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        # Iterates over (name, object) pairs.
        return iter(self._obj_map.items())

    def keys(self):
        """Return the registered names."""
        return self._obj_map.keys()
# Module-level registries, one per kind of pluggable component. Each is an
# independent name -> object mapping; the string is only used in the
# registry's error messages.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
'''
The code is modified from the Real-ESRGAN:
https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py
'''
import cv2
import sys
import numpy as np
# Import ffmpeg-python, installing it on the fly when it is missing.
try:
    import ffmpeg
except ImportError:
    # `pip.main` is no longer part of pip's importable API (pip >= 10), so
    # run pip as a subprocess of the current interpreter instead.
    import subprocess
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--user', 'ffmpeg-python'])
    import ffmpeg
def get_video_meta_info(video_path):
    """Collect basic metadata of a video via ffmpeg.probe.

    Only the first video stream is inspected.

    Args:
        video_path (str): Path of the video passed to ffmpeg.

    Returns:
        dict: Keys 'width', 'height', 'fps', 'audio' (the audio part of an
            ffmpeg input for this file, or None when the file has no audio
            stream) and 'nb_frames'.
    """
    streams = ffmpeg.probe(video_path)['streams']
    video_streams = [stream for stream in streams if stream['codec_type'] == 'video']
    has_audio = any(stream['codec_type'] == 'audio' for stream in streams)
    first_video = video_streams[0]
    return {
        'width': first_video['width'],
        'height': first_video['height'],
        # NOTE(review): eval() runs on text taken from the probe output
        # (presumably a "num/den" rate string) - consider a safer parser if
        # input files are untrusted.
        'fps': eval(first_video['avg_frame_rate']),
        'audio': ffmpeg.input(video_path).audio if has_audio else None,
        'nb_frames': int(first_video['nb_frames']),
    }
class VideoReader:
    """Read raw BGR frames of a video through an ffmpeg subprocess pipe.

    Args:
        video_path (str): Path of the video to decode.
    """

    def __init__(self, video_path):
        self.paths = []  # for image&folder type
        self.audio = None
        try:
            decode_graph = ffmpeg.input(video_path).output(
                'pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='error')
            self.stream_reader = decode_graph.run_async(pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')
        except FileNotFoundError:
            print('Please install ffmpeg (not ffmpeg-python) by running\n',
                  '\t$ conda install -c conda-forge ffmpeg')
            sys.exit(0)

        meta = get_video_meta_info(video_path)
        self.width = meta['width']
        self.height = meta['height']
        self.input_fps = meta['fps']
        self.audio = meta['audio']
        self.nb_frames = meta['nb_frames']

        self.idx = 0

    def get_resolution(self):
        """Return (height, width) of the video."""
        return self.height, self.width

    def get_fps(self):
        """Return the input frame rate, falling back to 24 when it is unknown."""
        return 24 if self.input_fps is None else self.input_fps

    def get_audio(self):
        """Return the audio stream handle, or None."""
        return self.audio

    def __len__(self):
        return self.nb_frames

    def get_frame_from_stream(self):
        """Read the next frame from the ffmpeg pipe; None once the pipe is exhausted."""
        bytes_per_frame = self.width * self.height * 3  # 3 bytes for one pixel
        raw = self.stream_reader.stdout.read(bytes_per_frame)
        if not raw:
            return None
        return np.frombuffer(raw, np.uint8).reshape([self.height, self.width, 3])

    def get_frame_from_list(self):
        """Read the next frame from `self.paths`; None once `nb_frames` frames were read."""
        if self.idx >= self.nb_frames:
            return None
        frame = cv2.imread(self.paths[self.idx])
        self.idx += 1
        return frame

    def get_frame(self):
        """Return the next frame (always taken from the ffmpeg stream)."""
        return self.get_frame_from_stream()

    def close(self):
        """Close the subprocess stdin and wait for ffmpeg to exit."""
        self.stream_reader.stdin.close()
        self.stream_reader.wait()
class VideoWriter:
    """Encode raw BGR frames into an H.264 video through an ffmpeg subprocess pipe.

    Args:
        video_save_path (str): Output video path (overwritten if it exists).
        height (int): Frame height.
        width (int): Frame width.
        fps: Frame rate of the output video.
        audio: Audio stream to copy into the output, or None for no audio.
    """

    def __init__(self, video_save_path, height, width, fps, audio):
        if height > 2160:
            print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
                  'We highly recommend to decrease the outscale(aka, -s).')

        # Raw frames arrive on stdin; the only difference between the two
        # branches is whether the audio stream is muxed in (copied as-is).
        raw_frames = ffmpeg.input(
            'pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', framerate=fps)
        if audio is not None:
            encode_graph = raw_frames.output(
                audio,
                video_save_path,
                pix_fmt='yuv420p',
                vcodec='libx264',
                loglevel='error',
                acodec='copy')
        else:
            encode_graph = raw_frames.output(
                video_save_path, pix_fmt='yuv420p', vcodec='libx264', loglevel='error')
        self.stream_writer = encode_graph.overwrite_output().run_async(
            pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')

    def write_frame(self, frame):
        """Send one frame (cast to uint8) to the encoder."""
        try:
            payload = frame.astype(np.uint8).tobytes()
            self.stream_writer.stdin.write(payload)
        except BrokenPipeError:
            print('Please re-install ffmpeg and libx264 by running\n',
                  '\t$ conda install -c conda-forge ffmpeg\n',
                  '\t$ conda install -c conda-forge x264')
            sys.exit(0)

    def close(self):
        """Close the encoder's stdin and wait for ffmpeg to exit."""
        self.stream_writer.stdin.close()
        self.stream_writer.wait()
\ No newline at end of file
"""
FFHQ 1024 x 1024 -> 512 x 512
"""
import os
from zipfile import ZipFile
from PIL import Image
from glob import glob
from tqdm import tqdm
# Extract the images from the zip archives and resize them.
def process_in_zip(zip_dir: str,
                   output_dir: str):
    """Resize every image found in the zip archives of a folder to 512 x 512.

    Each ``*.zip`` directly inside ``zip_dir`` is scanned; members ending in
    .png/.jpg/.jpeg (case-insensitive) are resized with Lanczos resampling and
    saved under ``output_dir`` at the same relative path they have inside the
    archive.

    Args:
        zip_dir (str): Folder containing the zip archives.
        output_dir (str): Folder the resized images are written to.
    """
    # NOTE(review): member names are joined onto output_dir unchecked -
    # presumably the archives are trusted; confirm before using on others.
    for zfile in glob(os.path.join(zip_dir, "*.zip")):
        with ZipFile(zfile, "r") as zip_ref:
            print("extract from ", zfile)
            for file_info in tqdm(zip_ref.infolist()):
                file_name = file_info.filename
                # The suffixes carry their dot so that names merely ending in
                # "jpg"/"jpeg" (e.g. "notes_jpg") are not treated as images.
                if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                output_path = os.path.join(output_dir, file_name)
                parent_dir = os.path.dirname(output_path)
                if parent_dir:
                    # os.makedirs('') raises, which happens for a top-level
                    # member when output_dir is empty.
                    os.makedirs(parent_dir, exist_ok=True)
                with zip_ref.open(file_name) as file, Image.open(file) as ori_image:
                    resized_image = ori_image.resize((512, 512), Image.LANCZOS)
                    resized_image.save(output_path)
if __name__ == "__main__":
    # Command-line entry point: both options are forwarded to process_in_zip.
    import argparse

    parser = argparse.ArgumentParser()
    for option in ("--zip_dir", "--output_dir"):
        parser.add_argument(option, type=str)
    args = parser.parse_args()

    process_in_zip(args.zip_dir, args.output_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment