Unverified Commit 8cfd677c authored by Joey Ballentine, committed by GitHub

Replace chainner_models with Spandrel package (#2146)

* Replace chainner_models with Spandrel

* Update to latest spandrel

* Use spandrel_foss instead

* update spandrel to new FOSS-compliant version
parent ffc4b7c3
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: layernorm.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th April 2023 9:28:20 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps):
ctx.eps = eps
N, C, H, W = x.size()
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors
g = grad_output * weight.view(1, C, 1, 1)
mean_g = g.mean(dim=1, keepdim=True)
mean_gy = (g * y).mean(dim=1, keepdim=True)
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return (
gx,
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
None,
)
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2d, self).__init__()
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
self.eps = eps
def forward(self, x):
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
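# Minimal sanity-check sketch: LayerNorm2d normalizes each spatial position
# across the channel dimension, so with its default affine parameters it
# should match nn.LayerNorm applied to a channels-last view of the input.
def _layernorm2d_demo():
    x = torch.randn(2, 8, 4, 4)
    y = LayerNorm2d(8)(x)
    y_ref = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(y, y_ref, atol=1e-5)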
class GRN(nn.Module):
"""GRN (Global Response Normalization) layer"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
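# Minimal sketch: with gamma and beta initialized to zero (as above), GRN is
# an identity mapping; training then learns how strongly to amplify channels
# whose global L2 response exceeds the cross-channel mean.
def _grn_demo():
    x = torch.randn(1, 4, 8, 8)
    assert torch.allclose(GRN(4)(x), x)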
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: pixelshuffle.py
# Created Date: Friday July 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 1st July 2022 10:18:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn
def pixelshuffle_block(
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
"""
Upsample features according to `upscale_factor`.
"""
padding = kernel_size // 2
conv = nn.Conv2d(
in_channels,
out_channels * (upscale_factor**2),
kernel_size,
        padding=padding,
bias=bias,
)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
return nn.Sequential(*[conv, pixel_shuffle])
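# Minimal sketch: the conv expands channels by upscale_factor**2 and
# nn.PixelShuffle rearranges them into spatial resolution, so a 16x16 input
# doubles to 32x32 with the default factor of 2.
def _pixelshuffle_demo():
    import torch
    up = pixelshuffle_block(3, 3, upscale_factor=2)
    assert up(torch.randn(1, 3, 16, 16)).shape == (1, 3, 32, 32)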
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import functools
import math
import re
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import block as B
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
# which enhanced the implementation that was already here
class RRDBNet(nn.Module):
def __init__(
self,
state_dict,
norm=None,
act: str = "leakyrelu",
upsampler: str = "upconv",
mode: B.ConvMode = "CNA",
) -> None:
"""
ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
and Chen Change Loy.
        This is the old-arch Residual in Residual Dense Block Network and not
        the newest revision available at github.com/xinntao/ESRGAN.
        This is deliberate: the newest network severely limits the potential
        uses of the network with no benefit.
        This network supports model files from both the new and old arch.
Args:
norm: Normalization layer
act: Activation layer
upsampler: Upsample layer. upconv, pixel_shuffle
mode: Convolution mode
"""
super(RRDBNet, self).__init__()
self.model_arch = "ESRGAN"
self.sub_type = "SR"
self.state = state_dict
self.norm = norm
self.act = act
self.upsampler = upsampler
self.mode = mode
self.state_map = {
# currently supports old, new, and newer RRDBNet arch models
# ESRGAN, BSRGAN/RealSR, Real-ESRGAN
"model.0.weight": ("conv_first.weight",),
"model.0.bias": ("conv_first.bias",),
"model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
"model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
),
}
if "params_ema" in self.state:
self.state = self.state["params_ema"]
# self.model_arch = "RealESRGAN"
self.num_blocks = self.get_num_blocks()
self.plus = any("conv1x1" in k for k in self.state.keys())
if self.plus:
self.model_arch = "ESRGAN+"
self.state = self.new_to_old_arch(self.state)
self.key_arr = list(self.state.keys())
self.in_nc: int = self.state[self.key_arr[0]].shape[1]
self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
self.scale: int = self.get_scale()
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
c2x2 = False
if self.state["model.0.weight"].shape[-2] == 2:
c2x2 = True
self.scale = round(math.sqrt(self.scale / 4))
self.model_arch = "ESRGAN-2c2"
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
# Detect if pixelunshuffle was used (Real-ESRGAN)
if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
self.in_nc / 4,
self.in_nc / 16,
):
self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
else:
self.shuffle_factor = None
upsample_block = {
"upconv": B.upconv_block,
"pixel_shuffle": B.pixelshuffle_block,
}.get(self.upsampler)
if upsample_block is None:
raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
if self.scale == 3:
upsample_blocks = upsample_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
upscale_factor=3,
act_type=self.act,
c2x2=c2x2,
)
else:
upsample_blocks = [
upsample_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
act_type=self.act,
c2x2=c2x2,
)
for _ in range(int(math.log(self.scale, 2)))
]
self.model = B.sequential(
# fea conv
B.conv_block(
in_nc=self.in_nc,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
B.ShortcutBlock(
B.sequential(
# rrdb blocks
*[
B.RRDB(
nf=self.num_filters,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=self.norm,
act_type=self.act,
mode="CNA",
plus=self.plus,
c2x2=c2x2,
)
for _ in range(self.num_blocks)
],
# lr conv
B.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=self.norm,
act_type=None,
mode=self.mode,
c2x2=c2x2,
),
)
),
*upsample_blocks,
# hr_conv0
B.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=self.act,
c2x2=c2x2,
),
# hr_conv1
B.conv_block(
in_nc=self.num_filters,
out_nc=self.out_nc,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
)
# Adjust these properties for calculations outside of the model
if self.shuffle_factor:
self.in_nc //= self.shuffle_factor**2
self.scale //= self.shuffle_factor
self.load_state_dict(self.state, strict=False)
def new_to_old_arch(self, state):
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
if "params_ema" in state:
state = state["params_ema"]
if "conv_first.weight" not in state:
# model is already old arch, this is a loose check, but should be sufficient
return state
# add nb to state keys
for kind in ("weight", "bias"):
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
f"model.1.sub./NB/.{kind}"
]
del self.state_map[f"model.1.sub./NB/.{kind}"]
old_state = OrderedDict()
for old_key, new_keys in self.state_map.items():
for new_key in new_keys:
if r"\1" in old_key:
for k, v in state.items():
sub = re.sub(new_key, old_key, k)
if sub != k:
old_state[sub] = v
else:
if new_key in state:
old_state[old_key] = state[new_key]
# upconv layers
max_upconv = 0
for key in state.keys():
match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
if match is not None:
_, key_num, key_type = match.groups()
old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
max_upconv = max(max_upconv, int(key_num) * 3)
# final layers
for key in state.keys():
if key in ("HRconv.weight", "conv_hr.weight"):
old_state[f"model.{max_upconv + 2}.weight"] = state[key]
elif key in ("HRconv.bias", "conv_hr.bias"):
old_state[f"model.{max_upconv + 2}.bias"] = state[key]
elif key in ("conv_last.weight",):
old_state[f"model.{max_upconv + 4}.weight"] = state[key]
elif key in ("conv_last.bias",):
old_state[f"model.{max_upconv + 4}.bias"] = state[key]
# Sort by first numeric value of each layer
def compare(item1, item2):
parts1 = item1.split(".")
parts2 = item2.split(".")
int1 = int(parts1[1])
int2 = int(parts2[1])
return int1 - int2
sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
# Rebuild the output dict in the right order
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
return out_dict
def get_scale(self, min_part: int = 6) -> int:
n = 0
for part in list(self.state):
parts = part.split(".")[1:]
if len(parts) == 2:
part_num = int(parts[0])
if part_num > min_part and parts[1] == "weight":
n += 1
return 2**n
def get_num_blocks(self) -> int:
nbs = []
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
)
for state_key in state_keys:
for k in self.state:
m = re.search(state_key, k)
if m:
nbs.append(int(m.group(1)))
if nbs:
break
        return max(nbs) + 1
def forward(self, x):
if self.shuffle_factor:
_, _, h, w = x.size()
mod_pad_h = (
self.shuffle_factor - h % self.shuffle_factor
) % self.shuffle_factor
mod_pad_w = (
self.shuffle_factor - w % self.shuffle_factor
) % self.shuffle_factor
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
x = self.model(x)
return x[:, :, : h * self.scale, : w * self.scale]
return self.model(x)
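# Usage sketch, assuming a saved ESRGAN-family checkpoint (the path below is
# hypothetical): the constructor infers scale, channel counts, and block count
# from the state dict alone, so no separate config file is needed.
def _rrdbnet_demo():
    state = torch.load("4x_esrgan.pth", map_location="cpu")  # hypothetical path
    model = RRDBNet(state).eval()
    with torch.no_grad():
        sr = model(torch.rand(1, model.in_nc, 64, 64))
    # for a 4x model: sr.shape == (1, model.out_nc, 256, 256)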
# pylint: skip-file
# -----------------------------------------------------------------------------------
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
# -----------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from .timm.drop import DropPath
from .timm.weight_init import trunc_normal_
# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
class WMSA(nn.Module):
"""Self-attention module in Swin Transformer"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
super(WMSA, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
self.scale = self.head_dim**-0.5
self.n_heads = input_dim // head_dim
self.window_size = window_size
self.type = type
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
self.relative_position_params = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
)
# TODO recover
# self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
self.linear = nn.Linear(self.input_dim, self.output_dim)
trunc_normal_(self.relative_position_params, std=0.02)
self.relative_position_params = torch.nn.Parameter(
self.relative_position_params.view(
2 * window_size - 1, 2 * window_size - 1, self.n_heads
)
.transpose(1, 2)
.transpose(0, 1)
)
def generate_mask(self, h, w, p, shift):
"""generating the mask of SW-MSA
Args:
shift: shift parameters in CyclicShift.
Returns:
attn_mask: should be (1 1 w p p),
"""
# supporting square.
attn_mask = torch.zeros(
h,
w,
p,
p,
p,
p,
dtype=torch.bool,
device=self.relative_position_params.device,
)
if self.type == "W":
return attn_mask
s = p - shift
attn_mask[-1, :, :s, :, s:, :] = True
attn_mask[-1, :, s:, :, :s, :] = True
attn_mask[:, -1, :, :s, :, s:] = True
attn_mask[:, -1, :, s:, :, :s] = True
attn_mask = rearrange(
attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
)
return attn_mask
def forward(self, x):
"""Forward pass of Window Multi-head Self-attention module.
Args:
x: input tensor with shape of [b h w c];
attn_mask: attention mask, fill -inf where the value is True;
Returns:
output: tensor shape [b h w c]
"""
if self.type != "W":
x = torch.roll(
x,
shifts=(-(self.window_size // 2), -(self.window_size // 2)),
dims=(1, 2),
)
x = rearrange(
x,
"b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
p1=self.window_size,
p2=self.window_size,
)
h_windows = x.size(1)
w_windows = x.size(2)
# square validation
# assert h_windows == w_windows
x = rearrange(
x,
"b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
p1=self.window_size,
p2=self.window_size,
)
qkv = self.embedding_layer(x)
q, k, v = rearrange(
qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
).chunk(3, dim=0)
sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
# Adding learnable relative embedding
sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
# Using Attn Mask to distinguish different subwindows.
if self.type != "W":
attn_mask = self.generate_mask(
h_windows, w_windows, self.window_size, shift=self.window_size // 2
)
sim = sim.masked_fill_(attn_mask, float("-inf"))
probs = nn.functional.softmax(sim, dim=-1)
output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
output = rearrange(output, "h b w p c -> b w p (h c)")
output = self.linear(output)
output = rearrange(
output,
"b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
w1=h_windows,
p1=self.window_size,
)
if self.type != "W":
output = torch.roll(
output,
shifts=(self.window_size // 2, self.window_size // 2),
dims=(1, 2),
)
return output
def relative_embedding(self):
cord = torch.tensor(
np.array(
[
[i, j]
for i in range(self.window_size)
for j in range(self.window_size)
]
)
)
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
# negative is allowed
return self.relative_position_params[
:, relation[:, :, 0].long(), relation[:, :, 1].long()
]
class Block(nn.Module):
def __init__(
self,
input_dim,
output_dim,
head_dim,
window_size,
drop_path,
type="W",
input_resolution=None,
):
"""SwinTransformer Block"""
super(Block, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ["W", "SW"]
self.type = type
if input_resolution <= window_size:
self.type = "W"
self.ln1 = nn.LayerNorm(input_dim)
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.ln2 = nn.LayerNorm(input_dim)
self.mlp = nn.Sequential(
nn.Linear(input_dim, 4 * input_dim),
nn.GELU(),
nn.Linear(4 * input_dim, output_dim),
)
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
class ConvTransBlock(nn.Module):
def __init__(
self,
conv_dim,
trans_dim,
head_dim,
window_size,
drop_path,
type="W",
input_resolution=None,
):
"""SwinTransformer and Conv Block"""
super(ConvTransBlock, self).__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
self.window_size = window_size
self.drop_path = drop_path
self.type = type
self.input_resolution = input_resolution
assert self.type in ["W", "SW"]
if self.input_resolution <= self.window_size:
self.type = "W"
self.trans_block = Block(
self.trans_dim,
self.trans_dim,
self.head_dim,
self.window_size,
self.drop_path,
self.type,
self.input_resolution,
)
self.conv1_1 = nn.Conv2d(
self.conv_dim + self.trans_dim,
self.conv_dim + self.trans_dim,
1,
1,
0,
bias=True,
)
self.conv1_2 = nn.Conv2d(
self.conv_dim + self.trans_dim,
self.conv_dim + self.trans_dim,
1,
1,
0,
bias=True,
)
self.conv_block = nn.Sequential(
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
)
def forward(self, x):
conv_x, trans_x = torch.split(
self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange("b c h w -> b h w c")(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange("b h w c -> b c h w")(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
class SCUNet(nn.Module):
def __init__(
self,
state_dict,
in_nc=3,
config=[4, 4, 4, 4, 4, 4, 4],
dim=64,
drop_path_rate=0.0,
input_resolution=256,
):
super(SCUNet, self).__init__()
self.model_arch = "SCUNet"
self.sub_type = "SR"
self.num_filters: int = 0
self.state = state_dict
self.config = config
self.dim = dim
self.head_dim = 32
self.window_size = 8
self.in_nc = in_nc
self.out_nc = self.in_nc
self.scale = 1
self.supports_fp16 = True
# drop path rate for each layer
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
begin = 0
self.m_down1 = [
ConvTransBlock(
dim // 2,
dim // 2,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution,
)
for i in range(config[0])
] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
begin += config[0]
self.m_down2 = [
ConvTransBlock(
dim,
dim,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution // 2,
)
for i in range(config[1])
] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
begin += config[1]
self.m_down3 = [
ConvTransBlock(
2 * dim,
2 * dim,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution // 4,
)
for i in range(config[2])
] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
begin += config[2]
self.m_body = [
ConvTransBlock(
4 * dim,
4 * dim,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution // 8,
)
for i in range(config[3])
]
begin += config[3]
self.m_up3 = [
nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
] + [
ConvTransBlock(
2 * dim,
2 * dim,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution // 4,
)
for i in range(config[4])
]
begin += config[4]
self.m_up2 = [
nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
] + [
ConvTransBlock(
dim,
dim,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution // 2,
)
for i in range(config[5])
]
begin += config[5]
self.m_up1 = [
nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
] + [
ConvTransBlock(
dim // 2,
dim // 2,
self.head_dim,
self.window_size,
dpr[i + begin],
"W" if not i % 2 else "SW",
input_resolution,
)
for i in range(config[6])
]
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
self.m_head = nn.Sequential(*self.m_head)
self.m_down1 = nn.Sequential(*self.m_down1)
self.m_down2 = nn.Sequential(*self.m_down2)
self.m_down3 = nn.Sequential(*self.m_down3)
self.m_body = nn.Sequential(*self.m_body)
self.m_up3 = nn.Sequential(*self.m_up3)
self.m_up2 = nn.Sequential(*self.m_up2)
self.m_up1 = nn.Sequential(*self.m_up1)
self.m_tail = nn.Sequential(*self.m_tail)
# self.apply(self._init_weights)
self.load_state_dict(state_dict, strict=True)
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (64 - h % 64) % 64
mod_pad_w = (64 - w % 64) % 64
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
return x
def forward(self, x0):
h, w = x0.size()[-2:]
x0 = self.check_image_size(x0)
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
x = x[:, :, :h, :w]
return x
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
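# Usage sketch, assuming a saved SCUNet checkpoint (hypothetical path): SCUNet
# is a 1x denoiser, so the output matches the input size; check_image_size
# reflection-pads up to a multiple of 64 (the U-Net's total downsampling
# factor) and forward() crops the result back.
def _scunet_demo():
    state = torch.load("scunet_color.pth", map_location="cpu")  # hypothetical path
    model = SCUNet(state).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 100, 100))
    assert out.shape == (1, 3, 100, 100)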
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import block as B
class Get_gradient_nopadding(nn.Module):
def __init__(self):
super(Get_gradient_nopadding, self).__init__()
kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) # type: ignore
self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) # type: ignore
def forward(self, x):
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
x_list.append(x_i)
x = torch.cat(x_list, dim=1)
return x
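# Minimal sketch: Get_gradient_nopadding applies fixed vertical and horizontal
# difference kernels per channel and returns the gradient magnitude map, with
# the same shape as its input.
def _gradient_map_demo():
    g = Get_gradient_nopadding()(torch.rand(1, 3, 32, 32))
    assert g.shape == (1, 3, 32, 32)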
class SPSRNet(nn.Module):
def __init__(
self,
state_dict,
norm=None,
act: str = "leakyrelu",
upsampler: str = "upconv",
mode: B.ConvMode = "CNA",
):
super(SPSRNet, self).__init__()
self.model_arch = "SPSR"
self.sub_type = "SR"
self.state = state_dict
self.norm = norm
self.act = act
self.upsampler = upsampler
self.mode = mode
self.num_blocks = self.get_num_blocks()
self.in_nc: int = self.state["model.0.weight"].shape[1]
self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0]
self.scale = self.get_scale(4)
self.num_filters: int = self.state["model.0.weight"].shape[0]
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
n_upscale = int(math.log(self.scale, 2))
if self.scale == 3:
n_upscale = 1
fea_conv = B.conv_block(
self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
)
rb_blocks = [
B.RRDB(
self.num_filters,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
for _ in range(self.num_blocks)
]
LR_conv = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=norm,
act_type=None,
mode=mode,
)
if upsampler == "upconv":
upsample_block = B.upconv_block
elif upsampler == "pixelshuffle":
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
if self.scale == 3:
a_upsampler = upsample_block(
self.num_filters, self.num_filters, 3, act_type=act
)
else:
a_upsampler = [
upsample_block(self.num_filters, self.num_filters, act_type=act)
for _ in range(n_upscale)
]
self.HR_conv0_new = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=act,
)
self.HR_conv1_new = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.model = B.sequential(
fea_conv,
B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)),
*a_upsampler,
self.HR_conv0_new,
)
self.get_g_nopadding = Get_gradient_nopadding()
self.b_fea_conv = B.conv_block(
self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
)
self.b_concat_1 = B.conv_block(
2 * self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.b_block_1 = B.RRDB(
self.num_filters * 2,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
self.b_concat_2 = B.conv_block(
2 * self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.b_block_2 = B.RRDB(
self.num_filters * 2,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
self.b_concat_3 = B.conv_block(
2 * self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.b_block_3 = B.RRDB(
self.num_filters * 2,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
self.b_concat_4 = B.conv_block(
2 * self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.b_block_4 = B.RRDB(
self.num_filters * 2,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
self.b_LR_conv = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=norm,
act_type=None,
mode=mode,
)
if upsampler == "upconv":
upsample_block = B.upconv_block
elif upsampler == "pixelshuffle":
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
if self.scale == 3:
b_upsampler = upsample_block(
self.num_filters, self.num_filters, 3, act_type=act
)
else:
b_upsampler = [
upsample_block(self.num_filters, self.num_filters, act_type=act)
for _ in range(n_upscale)
]
b_HR_conv0 = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=act,
)
b_HR_conv1 = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = B.conv_block(
self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None
)
self.f_concat = B.conv_block(
self.num_filters * 2,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
)
self.f_block = B.RRDB(
self.num_filters * 2,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm,
act_type=act,
mode="CNA",
)
self.f_HR_conv0 = B.conv_block(
self.num_filters,
self.num_filters,
kernel_size=3,
norm_type=None,
act_type=act,
)
self.f_HR_conv1 = B.conv_block(
self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None
)
self.load_state_dict(self.state, strict=False)
def get_scale(self, min_part: int = 4) -> int:
n = 0
for part in list(self.state):
parts = part.split(".")
if len(parts) == 3:
part_num = int(parts[1])
if part_num > min_part and parts[0] == "model" and parts[2] == "weight":
n += 1
return 2**n
def get_num_blocks(self) -> int:
nb = 0
for part in list(self.state):
parts = part.split(".")
n_parts = len(parts)
if n_parts == 5 and parts[2] == "sub":
nb = int(parts[3])
return nb
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
for i in range(5):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i + 5](x)
x_fea2 = x
for i in range(5):
x = block_list[i + 10](x)
x_fea3 = x
for i in range(5):
x = block_list[i + 15](x)
x_fea4 = x
x = block_list[20:](x)
# short cut
x = x_ori + x
x = self.model[2:](x)
x = self.HR_conv1_new(x)
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
# short cut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
# x_out_branch = self.conv_w(x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = self.f_block(x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = self.f_HR_conv0(x_out)
x_out = self.f_HR_conv1(x_out)
#########
# return x_out_branch, x_out, x_grad
return x_out
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
import torch.nn as nn
import torch.nn.functional as F
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
    It is a compact network that performs upsampling only in the last layer,
    so no convolution is conducted in the HR feature space.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_conv (int): Number of convolution layers in the body network. Default: 16.
upscale (int): Upsampling factor. Default: 4.
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
"""
def __init__(
self,
state_dict,
act_type: str = "prelu",
):
super(SRVGGNetCompact, self).__init__()
self.model_arch = "SRVGG (RealESRGAN)"
self.sub_type = "SR"
self.act_type = act_type
self.state = state_dict
if "params" in self.state:
self.state = self.state["params"]
self.key_arr = list(self.state.keys())
self.in_nc = self.get_in_nc()
self.num_feat = self.get_num_feats()
self.num_conv = self.get_num_conv()
self.out_nc = self.in_nc # :(
self.pixelshuffle_shape = None # Defined in get_scale()
self.scale = self.get_scale()
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1))
# the first activation
if act_type == "relu":
activation = nn.ReLU(inplace=True)
elif act_type == "prelu":
activation = nn.PReLU(num_parameters=self.num_feat)
elif act_type == "leakyrelu":
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation) # type: ignore
# the body structure
for _ in range(self.num_conv):
self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1))
# activation
if act_type == "relu":
activation = nn.ReLU(inplace=True)
elif act_type == "prelu":
activation = nn.PReLU(num_parameters=self.num_feat)
elif act_type == "leakyrelu":
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation) # type: ignore
# the last conv
self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) # type: ignore
# upsample
self.upsampler = nn.PixelShuffle(self.scale)
self.load_state_dict(self.state, strict=False)
def get_num_conv(self) -> int:
return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
def get_num_feats(self) -> int:
return self.state[self.key_arr[0]].shape[0]
def get_in_nc(self) -> int:
return self.state[self.key_arr[0]].shape[1]
def get_scale(self) -> int:
self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
        # Assume out_nc is the same as in_nc
        # I can't think of a better way to do this
self.out_nc = self.in_nc
scale = math.sqrt(self.pixelshuffle_shape / self.out_nc)
if scale - int(scale) > 0:
print(
"out_nc is probably different than in_nc, scale calculation might be wrong"
)
scale = int(scale)
return scale
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.scale, mode="nearest")
out += base
return out
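# Minimal sketch of the scale inference above: the last conv emits
# out_nc * scale**2 channels for PixelShuffle, so the scale is recovered as
# sqrt(channels / out_nc); e.g. 48 channels with out_nc=3 implies scale 4.
def _srvgg_scale_demo():
    pixelshuffle_shape, out_nc = 48, 3
    assert int(math.sqrt(pixelshuffle_shape / out_nc)) == 4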
# From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
import torch
from torch import nn
class SeperableConv2d(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
):
super(SeperableConv2d, self).__init__()
self.depthwise = nn.Conv2d(
in_channels,
in_channels,
kernel_size=kernel_size,
stride=stride,
groups=in_channels,
bias=bias,
padding=padding,
)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
def forward(self, x):
return self.pointwise(self.depthwise(x))
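# Minimal sketch: splitting a 3x3 convolution into depthwise + 1x1 pointwise
# stages cuts the parameter count dramatically at equal channel widths
# (roughly 4.8k vs. 36.9k parameters at 64 channels).
def _separable_conv_demo():
    sep = SeperableConv2d(64, 64, kernel_size=3)
    dense = nn.Conv2d(64, 64, kernel_size=3, padding=1)
    count = lambda m: sum(p.numel() for p in m.parameters())
    assert count(sep) < count(dense)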
class ConvBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
use_act=True,
use_bn=True,
discriminator=False,
**kwargs,
):
super(ConvBlock, self).__init__()
self.use_act = use_act
self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
self.act = (
nn.LeakyReLU(0.2, inplace=True)
if discriminator
else nn.PReLU(num_parameters=out_channels)
)
def forward(self, x):
return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
class UpsampleBlock(nn.Module):
def __init__(self, in_channels, scale_factor):
super(UpsampleBlock, self).__init__()
self.conv = SeperableConv2d(
in_channels,
in_channels * scale_factor**2,
kernel_size=3,
stride=1,
padding=1,
)
self.ps = nn.PixelShuffle(
scale_factor
) # (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
self.act = nn.PReLU(num_parameters=in_channels)
def forward(self, x):
return self.act(self.ps(self.conv(x)))
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super(ResidualBlock, self).__init__()
self.block1 = ConvBlock(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.block2 = ConvBlock(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False
)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
return out + x
class Generator(nn.Module):
"""Swift-SRGAN Generator
Args:
in_channels (int): number of input image channels.
num_channels (int): number of hidden channels.
num_blocks (int): number of residual blocks.
upscale_factor (int): factor to upscale the image [2x, 4x, 8x].
Returns:
torch.Tensor: super resolution image
"""
def __init__(
self,
state_dict,
):
super(Generator, self).__init__()
self.model_arch = "Swift-SRGAN"
self.sub_type = "SR"
self.state = state_dict
if "model" in self.state:
self.state = self.state["model"]
self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
self.num_blocks = len(
set([x.split(".")[1] for x in self.state.keys() if "residual" in x])
)
self.scale: int = 2 ** len(
set([x.split(".")[1] for x in self.state.keys() if "upsampler" in x])
)
in_channels = self.in_nc
num_channels = self.num_filters
num_blocks = self.num_blocks
upscale_factor = self.scale
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
self.initial = ConvBlock(
in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
)
self.residual = nn.Sequential(
*[ResidualBlock(num_channels) for _ in range(num_blocks)]
)
self.convblock = ConvBlock(
num_channels,
num_channels,
kernel_size=3,
stride=1,
padding=1,
use_act=False,
)
self.upsampler = nn.Sequential(
*[
UpsampleBlock(num_channels, scale_factor=2)
for _ in range(upscale_factor // 2)
]
)
self.final_conv = SeperableConv2d(
num_channels, in_channels, kernel_size=9, stride=1, padding=4
)
self.load_state_dict(self.state, strict=False)
def forward(self, x):
initial = self.initial(x)
x = self.residual(initial)
x = self.convblock(x) + initial
x = self.upsampler(x)
return (torch.tanh(self.final_conv(x)) + 1) / 2
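# Usage sketch, assuming a saved Swift-SRGAN checkpoint (hypothetical path):
# the generator's tanh output is rescaled to [0, 1], so results can be
# converted to images without further normalization.
def _swift_srgan_demo():
    state = torch.load("swift_srgan_4x.pth", map_location="cpu")  # hypothetical path
    model = Generator(state).eval()
    with torch.no_grad():
        sr = model(torch.rand(1, 3, 32, 32))
    assert 0.0 <= sr.min() and sr.max() <= 1.0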
# pylint: skip-file
# -----------------------------------------------------------------------------------
# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
# Written by Conde and Choi et al.
# From: https://raw.githubusercontent.com/mv-lab/swin2sr/main/models/network_swin2sr.py
# -----------------------------------------------------------------------------------
import math
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
# Originally from the timm package
from .timm.drop import DropPath
from .timm.helpers import to_2tuple
from .timm.weight_init import trunc_normal_
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(
B, H // window_size, W // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
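# Minimal sketch: window_partition and window_reverse are exact inverses
# whenever H and W are multiples of the window size.
def _window_roundtrip_demo():
    x = torch.randn(2, 16, 16, 4)  # B, H, W, C
    windows = window_partition(x, 8)  # (2 * 2 * 2, 8, 8, 4)
    assert torch.equal(window_reverse(windows, 8, 16, 16), x)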
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.0,
proj_drop=0.0,
pretrained_window_size=[0, 0],
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # type: ignore
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True),
nn.ReLU(inplace=True),
nn.Linear(512, num_heads, bias=False),
)
# get relative_coords_table
relative_coords_h = torch.arange(
-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
)
relative_coords_w = torch.arange(
-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
)
relative_coords_table = (
torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
.permute(1, 2, 0)
.contiguous()
.unsqueeze(0)
) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
else:
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = (
torch.sign(relative_coords_table)
* torch.log2(torch.abs(relative_coords_table) + 1.0)
/ np.log2(8)
)
self.register_buffer("relative_coords_table", relative_coords_table)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
self.v_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # type: ignore
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(
self.logit_scale,
max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device),
).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
-1, self.num_heads
)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return (
f"dim={self.dim}, window_size={self.window_size}, "
f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
)
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
pretrained_window_size=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert (
0 <= self.shift_size < self.window_size
), "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA (recompute the attention mask when testing on images
        # whose size differs from the training resolution)
if self.input_resolution == x_size:
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(
x_windows, mask=self.calculate_mask(x_size).to(x.device)
)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else:
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(self.norm1(x))
# FFN
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
def extra_repr(self) -> str:
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
)
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.reduction(x)
x = self.norm(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
return flops
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
pretrained_window_size=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=norm_layer
)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops() # type: ignore
if self.downsample is not None:
flops += self.downsample.flops()
return flops
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0) # type: ignore
nn.init.constant_(blk.norm1.weight, 0) # type: ignore
nn.init.constant_(blk.norm2.bias, 0) # type: ignore
nn.init.constant_(blk.norm2.weight, 0) # type: ignore
class PatchEmbed(nn.Module):
r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size # type: ignore
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1],
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # type: ignore
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
img_size=224,
patch_size=4,
resi_connection="1conv",
):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(
dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint,
)
if resi_connection == "1conv":
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv = nn.Sequential(
nn.Conv2d(dim, dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1),
)
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=dim,
embed_dim=dim,
norm_layer=None,
)
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=dim,
embed_dim=dim,
norm_layer=None,
)
def forward(self, x, x_size):
return (
self.patch_embed(
self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
)
+ x
)
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchUnEmbed(nn.Module):
r"""Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(
f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
)
super(Upsample, self).__init__(*m)
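# Illustrative decomposition (a sketch, not part of the model): for scale=4 and
# num_feat=64, the loop above builds
#     Conv2d(64, 256, 3) -> PixelShuffle(2) -> Conv2d(64, 256, 3) -> PixelShuffle(2)
# i.e. two 2x stages, each expanding channels 4x and folding them back into space.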
class Upsample_hf(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(
f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
)
super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution # type: ignore
flops = H * W * self.num_feat * 3 * 9
return flops
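# Parameter comparison (illustrative arithmetic only): at scale 4 with
# num_feat=64 and num_out_ch=3, UpsampleOneStep needs a single Conv2d(64, 48, 3)
# before one PixelShuffle(4), while Upsample needs two Conv2d(64, 256, 3)
# layers, which is why this variant is used for lightweight SR.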
class Swin2SR(nn.Module):
r"""Swin2SR
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffle_aux'/'pixelshuffle_hf'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(
self,
state_dict,
**kwargs,
):
super(Swin2SR, self).__init__()
# Defaults
img_size = 128
patch_size = 1
in_chans = 3
embed_dim = 96
depths = [6, 6, 6, 6]
num_heads = [6, 6, 6, 6]
window_size = 7
mlp_ratio = 4.0
qkv_bias = True
drop_rate = 0.0
attn_drop_rate = 0.0
drop_path_rate = 0.1
norm_layer = nn.LayerNorm
ape = False
patch_norm = True
use_checkpoint = False
upscale = 2
img_range = 1.0
upsampler = ""
resi_connection = "1conv"
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.model_arch = "Swin2SR"
self.sub_type = "SR"
self.state = state_dict
if "params_ema" in self.state:
self.state = self.state["params_ema"]
elif "params" in self.state:
self.state = self.state["params"]
state_keys = self.state.keys()
if "conv_before_upsample.0.weight" in state_keys:
if "conv_aux.weight" in state_keys:
upsampler = "pixelshuffle_aux"
elif "conv_up1.weight" in state_keys:
upsampler = "nearest+conv"
else:
upsampler = "pixelshuffle"
supports_fp16 = False
elif "upsample.0.weight" in state_keys:
upsampler = "pixelshuffledirect"
else:
upsampler = ""
        num_feat = (
            self.state["conv_before_upsample.0.weight"].shape[0]
            if "conv_before_upsample.0.weight" in state_keys
            else 64
        )
num_in_ch = self.state["conv_first.weight"].shape[1]
in_chans = num_in_ch
if "conv_last.weight" in state_keys:
num_out_ch = self.state["conv_last.weight"].shape[0]
else:
num_out_ch = num_in_ch
upscale = 1
if upsampler == "nearest+conv":
upsample_keys = [
x for x in state_keys if "conv_up" in x and "bias" not in x
]
for upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux":
upsample_keys = [
x
for x in state_keys
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = self.state[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
elif upsampler == "pixelshuffledirect":
upscale = int(
math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
)
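        # Worked example of the upscale inference above (hypothetical 4x
        # "pixelshuffle" checkpoint with num_feat=64): upsample.0.weight and
        # upsample.2.weight both have shape[0] == 256, so
        # upscale = sqrt(256 // 64) * sqrt(256 // 64) = 4.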
max_layer_num = 0
max_block_num = 0
for key in state_keys:
result = re.match(
r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
)
if result:
layer_num, block_num = result.groups()
max_layer_num = max(max_layer_num, int(layer_num))
max_block_num = max(max_block_num, int(block_num))
depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
if (
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
in state_keys
):
num_heads_num = self.state[
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
].shape[-1]
num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
else:
num_heads = depths
embed_dim = self.state["conv_first.weight"].shape[0]
mlp_ratio = float(
self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
/ embed_dim
)
# TODO: could actually count the layers, but this should do
if "layers.0.conv.4.weight" in state_keys:
resi_connection = "3conv"
else:
resi_connection = "1conv"
window_size = int(
math.sqrt(
self.state[
"layers.0.residual_group.blocks.0.attn.relative_position_index"
].shape[0]
)
)
if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
img_size = int(
math.sqrt(
self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
)
* window_size
)
# The JPEG models are the only ones with window-size 7, and they also use this range
img_range = 255.0 if window_size == 7 else 1.0
self.in_nc = num_in_ch
self.out_nc = num_out_ch
self.num_feat = num_feat
self.embed_dim = embed_dim
self.num_heads = num_heads
self.depths = depths
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.scale = upscale
self.upsampler = upsampler
self.img_size = img_size
self.img_range = img_range
self.resi_connection = resi_connection
self.supports_fp16 = False # Too much weirdness to support this at the moment
self.supports_bfp16 = True
self.min_size_restriction = 16
## END AUTO DETECTION
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # type: ignore
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
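        # e.g. with depths=[6, 6, 6, 6] and drop_path_rate=0.1, dpr holds 24
        # values evenly spaced from 0.0 to 0.1; each RSTB below takes its
        # contiguous slice, so deeper blocks get progressively higher drop rates.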
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection,
)
self.layers.append(layer)
if self.upsampler == "pixelshuffle_hf":
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],  # type: ignore # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection,
)
self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == "1conv":
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
)
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == "pixelshuffle":
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == "pixelshuffle_aux":
self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == "pixelshuffle_hf":
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(upscale, num_feat)
self.upsample_hf = Upsample_hf(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_first_hf = nn.Sequential(
nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
self.conv_before_upsample_hf = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(
upscale,
embed_dim,
num_out_ch,
(patches_resolution[0], patches_resolution[1]),
)
elif self.upsampler == "nearest+conv":
# for real-world SR (less artifacts)
            assert self.upscale == 4, "only 4x upscaling is supported for nearest+conv."
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
        self.load_state_dict(self.state)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore # type: ignore
def no_weight_decay(self):
return {"absolute_pos_embed"}
@torch.jit.ignore # type: ignore
def no_weight_decay_keywords(self):
return {"relative_position_bias_table"}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
return x
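    # Padding example (illustrative): with window_size=8, an input of height 30
    # gives mod_pad_h = (8 - 30 % 8) % 8 = 2, so it is reflect-padded to 32;
    # sizes already divisible by the window size receive no padding.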
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers_hf:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == "pixelshuffle":
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == "pixelshuffle_aux":
bicubic = F.interpolate(
x,
size=(H * self.upscale, W * self.upscale),
mode="bicubic",
align_corners=False,
)
bicubic = self.conv_bicubic(bicubic)
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
x = self.conv_after_aux(aux)
x = (
self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale]
+ bicubic[:, :, : H * self.upscale, : W * self.upscale]
)
x = self.conv_last(x)
aux = aux / self.img_range + self.mean
elif self.upsampler == "pixelshuffle_hf":
# for classical SR with HF
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
x = x_out + x_hf
x_hf = x_hf / self.img_range + self.mean
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == "nearest+conv":
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
)
)
x = self.lrelu(
self.conv_up2(
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
)
)
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
            # NOTE: I removed an "aux" output here; not sure what it was for.
return x[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean # type: ignore
            return (
                x_out[:, :, : H * self.upscale, : W * self.upscale],  # type: ignore
                x[:, :, : H * self.upscale, : W * self.upscale],
                x_hf[:, :, : H * self.upscale, : W * self.upscale],  # type: ignore
            )
else:
return x[:, :, : H * self.upscale, : W * self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops() # type: ignore
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() # type: ignore
return flops
# pylint: skip-file
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
# Originally from the timm package
from .timm.drop import DropPath
from .timm.helpers import to_2tuple
from .timm.weight_init import trunc_normal_
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(
B, H // window_size, W // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
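# Shape round trip (illustrative): for x of shape (2, 8, 8, 96) and
# window_size=4, window_partition(x, 4) returns (2 * 2 * 2, 4, 4, 96), i.e.
# (8, 4, 4, 96), and window_reverse(windows, 4, 8, 8) restores (2, 8, 8, 96).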
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter( # type: ignore
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
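    # Bias-table bookkeeping (illustrative): for a 2x2 window there are
    # (2*2 - 1) * (2*2 - 1) = 9 distinct relative offsets, so the table above
    # has 9 rows per head, and relative_position_index holds values in [0, 8]
    # for each of the 4x4 token pairs.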
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1) # type: ignore
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert (
0 <= self.shift_size < self.window_size
), "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
return attn_mask
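    # Mask example (illustrative): with an 8x8 input, window_size=4 and
    # shift_size=2, the slices above label 3x3 = 9 regions; windows straddling
    # the cyclic-shift seam mix region ids, and each cross-region token pair in
    # such a window gets -100, effectively masking its attention weight.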
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA (recompute the mask when the test image size differs from the training resolution)
if self.input_resolution == x_size:
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(
x_windows, mask=self.calculate_mask(x_size).to(x.device)
)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
)
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
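    # Shape walkthrough (illustrative): for H=W=8 and C=96, each strided view
    # has shape (B, 4, 4, 96); concatenation gives (B, 4, 4, 384) and the
    # linear reduction yields (B, 16, 192), halving resolution while doubling
    # channels.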
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=norm_layer
)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops() # type: ignore
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
img_size=224,
patch_size=4,
resi_connection="1conv",
):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(
dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint,
)
if resi_connection == "1conv":
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv = nn.Sequential(
nn.Conv2d(dim, dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1),
)
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=0,
embed_dim=dim,
norm_layer=None,
)
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=0,
embed_dim=dim,
norm_layer=None,
)
def forward(self, x, x_size):
return (
self.patch_embed(
self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
)
+ x
)
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0], # type: ignore
img_size[1] // patch_size[1], # type: ignore
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim # type: ignore
return flops
class PatchUnEmbed(nn.Module):
r"""Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0], # type: ignore
img_size[1] // patch_size[1], # type: ignore
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(
f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
)
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution # type: ignore
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r"""SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(
self,
state_dict,
**kwargs,
):
super(SwinIR, self).__init__()
# Defaults
img_size = 64
patch_size = 1
in_chans = 3
embed_dim = 96
depths = [6, 6, 6, 6]
num_heads = [6, 6, 6, 6]
window_size = 7
mlp_ratio = 4.0
qkv_bias = True
qk_scale = None
drop_rate = 0.0
attn_drop_rate = 0.0
drop_path_rate = 0.1
norm_layer = nn.LayerNorm
ape = False
patch_norm = True
use_checkpoint = False
upscale = 2
img_range = 1.0
upsampler = ""
resi_connection = "1conv"
num_feat = 64
num_in_ch = in_chans
num_out_ch = in_chans
supports_fp16 = True
self.start_unshuffle = 1
self.model_arch = "SwinIR"
self.sub_type = "SR"
self.state = state_dict
if "params_ema" in self.state:
self.state = self.state["params_ema"]
elif "params" in self.state:
self.state = self.state["params"]
state_keys = self.state.keys()
if "conv_before_upsample.0.weight" in state_keys:
if "conv_up1.weight" in state_keys:
upsampler = "nearest+conv"
else:
upsampler = "pixelshuffle"
supports_fp16 = False
elif "upsample.0.weight" in state_keys:
upsampler = "pixelshuffledirect"
else:
upsampler = ""
        num_feat = (
            self.state["conv_before_upsample.0.weight"].shape[0]
            if "conv_before_upsample.0.weight" in state_keys
            else 64
        )
if "conv_first.1.weight" in self.state:
self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight")
self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias")
self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3))
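        # Pixel-unshuffle example (illustrative): a checkpoint whose first conv
        # takes 48 input channels implies 48 // 3 = 16 = 4**2, i.e. the input
        # is pixel-unshuffled by a factor of 4 before conv_first.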
num_in_ch = self.state["conv_first.weight"].shape[1]
in_chans = num_in_ch
if "conv_last.weight" in state_keys:
num_out_ch = self.state["conv_last.weight"].shape[0]
else:
num_out_ch = num_in_ch
upscale = 1
if upsampler == "nearest+conv":
upsample_keys = [
x for x in state_keys if "conv_up" in x and "bias" not in x
]
for upsample_key in upsample_keys:
upscale *= 2
elif upsampler == "pixelshuffle":
upsample_keys = [
x
for x in state_keys
if "upsample" in x and "conv" not in x and "bias" not in x
]
for upsample_key in upsample_keys:
shape = self.state[upsample_key].shape[0]
upscale *= math.sqrt(shape // num_feat)
upscale = int(upscale)
elif upsampler == "pixelshuffledirect":
upscale = int(
math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
)
max_layer_num = 0
max_block_num = 0
for key in state_keys:
result = re.match(
r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
)
if result:
layer_num, block_num = result.groups()
max_layer_num = max(max_layer_num, int(layer_num))
max_block_num = max(max_block_num, int(block_num))
depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
if (
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
in state_keys
):
num_heads_num = self.state[
"layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
].shape[-1]
num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
else:
num_heads = depths
embed_dim = self.state["conv_first.weight"].shape[0]
mlp_ratio = float(
self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
/ embed_dim
)
# TODO: could actually count the layers, but this should do
if "layers.0.conv.4.weight" in state_keys:
resi_connection = "3conv"
else:
resi_connection = "1conv"
window_size = int(
math.sqrt(
self.state[
"layers.0.residual_group.blocks.0.attn.relative_position_index"
].shape[0]
)
)
if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
img_size = int(
math.sqrt(
self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
)
* window_size
)
# The JPEG models are the only ones with window-size 7, and they also use this range
img_range = 255.0 if window_size == 7 else 1.0
self.in_nc = num_in_ch
self.out_nc = num_out_ch
self.num_feat = num_feat
self.embed_dim = embed_dim
self.num_heads = num_heads
self.depths = depths
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.scale = upscale / self.start_unshuffle
self.upsampler = upsampler
self.img_size = img_size
self.img_range = img_range
self.resi_connection = resi_connection
self.supports_fp16 = False # Too much weirdness to support this at the moment
self.supports_bfp16 = True
self.min_size_restriction = 16
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter( # type: ignore
torch.zeros(1, num_patches, embed_dim)
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[
sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection,
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == "1conv":
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
)
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == "pixelshuffle":
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(
upscale,
embed_dim,
num_out_ch,
(patches_resolution[0], patches_resolution[1]),
)
elif self.upsampler == "nearest+conv":
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
elif self.upscale == 8:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
self.load_state_dict(self.state, strict=False)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore # type: ignore
def no_weight_decay(self):
return {"absolute_pos_embed"}
@torch.jit.ignore # type: ignore
def no_weight_decay_keywords(self):
return {"relative_position_bias_table"}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.start_unshuffle > 1:
x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle)
if self.upsampler == "pixelshuffle":
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == "nearest+conv":
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") # type: ignore
)
)
if self.upscale == 4:
x = self.lrelu(
self.conv_up2(
torch.nn.functional.interpolate( # type: ignore
x, scale_factor=2, mode="nearest"
)
)
)
elif self.upscale == 8:
                x = self.lrelu(
                    self.conv_up2(
                        torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
                    )
                )
                x = self.lrelu(
                    self.conv_up3(
                        torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
                    )
                )
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, : H * self.upscale, : W * self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops() # type: ignore
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() # type: ignore
return flops
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from collections import OrderedDict
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
####################
# Basic blocks
####################
def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
# helper selecting activation
    # neg_slope: negative slope for leakyrelu and the init value for prelu
    # n_prelu: num_parameters for PReLU
act_type = act_type.lower()
if act_type == "relu":
layer = nn.ReLU(inplace)
elif act_type == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
else:
raise NotImplementedError(
"activation layer [{:s}] is not found".format(act_type)
)
return layer
def norm(norm_type: str, nc: int):
# helper selecting normalization layer
norm_type = norm_type.lower()
if norm_type == "batch":
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == "instance":
layer = nn.InstanceNorm2d(nc, affine=False)
else:
raise NotImplementedError(
"normalization layer [{:s}] is not found".format(norm_type)
)
return layer
def pad(pad_type: str, padding):
# helper selecting padding layer
# if padding is 'zero', do by conv layers
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == "reflect":
layer = nn.ReflectionPad2d(padding)
elif pad_type == "replicate":
layer = nn.ReplicationPad2d(padding)
else:
raise NotImplementedError(
"padding layer [{:s}] is not implemented".format(pad_type)
)
return layer
def get_valid_padding(kernel_size, dilation):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
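# Padding arithmetic (illustrative): kernel_size=3 with dilation=2 has an
# effective extent of 3 + 2 * 1 = 5, so get_valid_padding returns 2 and the
# convolution preserves spatial size.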
class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
tmpstr = "Identity .. \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlock(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = x + self.sub(x)
return output
def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlockSPSR(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule):
super(ShortcutBlockSPSR, self).__init__()
self.sub = submodule
def forward(self, x):
return x, self.sub
def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
def sequential(*args):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError("sequential does not support OrderedDict input.")
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
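# Behaviour note (illustrative): None entries are dropped and nested
# nn.Sequential containers are flattened, e.g.
#     sequential(nn.Conv2d(3, 8, 3), None, nn.Sequential(nn.ReLU()))
# yields a flat nn.Sequential(Conv2d, ReLU).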
ConvMode = Literal["CNA", "NAC", "CNAC"]
# 2x2x2 Conv Block
def conv_block_2c2(
in_nc,
out_nc,
act_type="relu",
):
return sequential(
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
act(act_type) if act_type else None,
)
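# Size check (illustrative): the 2x2 conv with padding=1 maps H to H + 1 and
# the following 2x2 conv with padding=0 maps it back to H, so the pair is
# size-preserving with an effective 3x3 receptive field.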
def conv_block(
in_nc: int,
out_nc: int,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
pad_type="zero",
norm_type: str | None = None,
act_type: str | None = "relu",
mode: ConvMode = "CNA",
c2x2=False,
):
"""
Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act -> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
        CNAC --> Conv -> Norm -> Act (residual-path variant; built like CNA here)
"""
if c2x2:
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
padding = padding if pad_type == "zero" else 0
c = nn.Conv2d(
in_nc,
out_nc,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=groups,
)
a = act(act_type) if act_type else None
if mode in ("CNA", "CNAC"):
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == "NAC":
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
# Important!
# input----ReLU(inplace)----Conv--+----output
# |________________________|
# inplace ReLU will modify the input, therefore wrong output
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
else:
assert False, f"Invalid conv mode {mode}"
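# Usage sketch (illustrative): conv_block(64, 64, kernel_size=3, norm_type=None,
# act_type="leakyrelu", mode="CNA") builds Conv2d(64, 64, 3, padding=1)
# followed by LeakyReLU(0.2); with norm_type=None no normalization layer is added.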
####################
# Useful blocks
####################
class ResNetBlock(nn.Module):
"""
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
"""
def __init__(
self,
in_nc,
mid_nc,
out_nc,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
pad_type="zero",
norm_type=None,
act_type="relu",
mode: ConvMode = "CNA",
res_scale=1,
):
super(ResNetBlock, self).__init__()
conv0 = conv_block(
in_nc,
mid_nc,
kernel_size,
stride,
dilation,
groups,
bias,
pad_type,
norm_type,
act_type,
mode,
)
if mode == "CNA":
act_type = None
if mode == "CNAC": # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(
mid_nc,
out_nc,
kernel_size,
stride,
dilation,
groups,
bias,
pad_type,
norm_type,
act_type,
mode,
)
# if in_nc != out_nc:
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
# None, None)
# print('Need a projecter in ResNetBlock.')
# else:
# self.project = lambda x:x
self.res = sequential(conv0, conv1)
self.res_scale = res_scale
def forward(self, x):
res = self.res(x).mul(self.res_scale)
return x + res
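# Illustrative sketch (not part of the original file): with res_scale=0.1 (the
# value EDSR uses) the block computes x + 0.1 * res(x) and preserves the shape.
def _example_resnet_block():
    block = ResNetBlock(64, 64, 64, res_scale=0.1)
    x = torch.randn(1, 64, 24, 24)
    assert block(x).shape == x.shape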
class RRDB(nn.Module):
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def __init__(
self,
nf,
kernel_size=3,
gc=32,
stride=1,
bias: bool = True,
pad_type="zero",
norm_type=None,
act_type="leakyrelu",
mode: ConvMode = "CNA",
_convtype="Conv2D",
_spectral_norm=False,
plus=False,
c2x2=False,
):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB2 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB3 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
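# Illustrative sketch (not part of the original file): an RRDB chains three dense
# blocks and scales the combined residual by a fixed 0.2 before adding the input,
# so the block is shape-preserving.
def _example_rrdb():
    rrdb = RRDB(nf=64, gc=32)
    x = torch.randn(1, 64, 16, 16)
    assert rrdb(x).shape == x.shape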
class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
Args:
nf (int): Channel number of intermediate features (num_feat).
gc (int): Channels for each growth (num_grow_ch: growth channel,
i.e. intermediate channels).
convtype (str): the type of convolution to use. Default: 'Conv2D'
gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
trainable parameters)
plus (bool): enable the additional residual paths from ESRGAN+
(adds trainable parameters)
"""
def __init__(
self,
nf=64,
kernel_size=3,
gc=32,
stride=1,
bias: bool = True,
pad_type="zero",
norm_type=None,
act_type="leakyrelu",
mode: ConvMode = "CNA",
plus=False,
c2x2=False,
):
super(ResidualDenseBlock_5C, self).__init__()
## +
self.conv1x1 = conv1x1(nf, gc) if plus else None
## +
self.conv1 = conv_block(
nf,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv2 = conv_block(
nf + gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv3 = conv_block(
nf + 2 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv4 = conv_block(
nf + 3 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
if mode == "CNA":
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(
nf + 4 * gc,
nf,
3,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=last_act,
mode=mode,
c2x2=c2x2,
)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
if self.conv1x1:
# pylint: disable=not-callable
x2 = x2 + self.conv1x1(x) # +
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
if self.conv1x1:
x4 = x4 + x2 # +
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
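# Illustrative sketch (not part of the original file): dense connectivity means
# conv_k consumes the input plus all previous growth maps, i.e. nf + (k - 1) * gc
# channels, and conv5 projects the concatenation back down to nf.
def _example_rdb_channels():
    rdb = ResidualDenseBlock_5C(nf=64, gc=32)
    assert rdb.conv4[0].in_channels == 64 + 3 * 32  # x, x1, x2, x3 concatenated
    x = torch.randn(1, 64, 16, 16)
    assert rdb(x).shape == x.shape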
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
####################
# Upsampler
####################
def pixelshuffle_block(
in_nc: int,
out_nc: int,
upscale_factor=2,
kernel_size=3,
stride=1,
bias=True,
pad_type="zero",
norm_type: str | None = None,
act_type="relu",
):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(
in_nc,
out_nc * (upscale_factor**2),
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=None,
act_type=None,
)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
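# Illustrative sketch (not part of the original file): the conv expands channels
# by upscale_factor**2 and PixelShuffle trades them for spatial resolution.
def _example_pixelshuffle_block():
    up = pixelshuffle_block(64, 32, upscale_factor=2)
    x = torch.randn(1, 64, 16, 16)
    assert up(x).shape == (1, 32, 32, 32)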
def upconv_block(
in_nc: int,
out_nc: int,
upscale_factor=2,
kernel_size=3,
stride=1,
bias=True,
pad_type="zero",
norm_type: str | None = None,
act_type="relu",
mode="nearest",
c2x2=False,
):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(
in_nc,
out_nc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
c2x2=c2x2,
)
return sequential(upsample, conv)
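# Illustrative sketch (not part of the original file): resize-then-convolve
# upsampling, which avoids the checkerboard artifacts of transposed convolutions
# discussed in the distill.pub article linked above.
def _example_upconv_block():
    up = upconv_block(64, 32, upscale_factor=2)
    x = torch.randn(1, 64, 16, 16)
    assert up(x).shape == (1, 32, 32, 32)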
Tencent is pleased to support the open source community by making GFPGAN available.
Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
Terms of the Apache License Version 2.0:
---------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of this License; and
You must cause any modified files to carry prominent notices stating that You changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Other dependencies and licenses:
Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
---------------------------------------------
1. basicsr
Copyright 2018-2020 BasicSR Authors
This BasicSR project is released under the Apache 2.0 license.
A copy of Apache 2.0 is included in this file.
StyleGAN2
The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
DFDNet
The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Terms of the Nvidia License:
---------------------------------------------
1. Definitions
"Licensor" means any person or entity that distributes its Work.
"Software" means the original work of authorship made available under
this License.
"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.
"Nvidia Processors" means any central processing unit (CPU), graphics
processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by Nvidia or its affiliates.
The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grants in Sections 2.1 and 2.2) will
terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grants in Sections 2.1 and
2.2) will terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.
MIT License
Copyright (c) 2019 Kim Seonghyeon
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Open Source Software licensed under the BSD 3-Clause license:
---------------------------------------------
1. torchvision
Copyright (c) Soumith Chintala 2016,
All rights reserved.
2. torch
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
Terms of the BSD 3-Clause License:
---------------------------------------------
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
---------------------------------------------
1. numpy
Copyright (c) 2005-2020, NumPy Developers.
All rights reserved.
A copy of BSD 3-Clause License is included in this file.
The NumPy repository and source distributions bundle several libraries that are
compatibly licensed. We list these here.
Name: Numpydoc
Files: doc/sphinxext/numpydoc/*
License: BSD-2-Clause
For details, see doc/sphinxext/LICENSE.txt
Name: scipy-sphinx-theme
Files: doc/scipy-sphinx-theme/*
License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
For details, see doc/scipy-sphinx-theme/LICENSE.txt
Name: lapack-lite
Files: numpy/linalg/lapack_lite/*
License: BSD-3-Clause
For details, see numpy/linalg/lapack_lite/LICENSE.txt
Name: tempita
Files: tools/npy_tempita/*
License: MIT
For details, see tools/npy_tempita/license.txt
Name: dragon4
Files: numpy/core/src/multiarray/dragon4.c
License: MIT
For license text, see numpy/core/src/multiarray/dragon4.c
Open Source Software licensed under the MIT license:
---------------------------------------------
1. facexlib
Copyright (c) 2020 Xintao Wang
2. opencv-python
Copyright (c) Olli-Pekka Heinisuo
Please note that only files in cv2 package are used.
Terms of the MIT License:
---------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
---------------------------------------------
1. tqdm
Copyright (c) 2013 noamraph
`tqdm` is a product of collaborative work.
Unless otherwise stated, all authors (see commit logs) retain copyright
for their respective work, and release the work under the MIT licence
(text below).
Exceptions or notable authors are listed below
in reverse chronological order:
* files: *
MPLv2.0 2015-2020 (c) Casper da Costa-Luis
[casperdcl](https://github.com/casperdcl).
* files: tqdm/_tqdm.py
MIT 2016 (c) [PR #96] on behalf of Google Inc.
* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
MIT 2013 (c) Noam Yorav-Raphael, original author.
[PR #96]: https://github.com/tqdm/tqdm/pull/96
Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
-----------------------------------------------
This Source Code Form is subject to the terms of the
Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file,
You can obtain one at https://mozilla.org/MPL/2.0/.
MIT License (MIT)
-----------------
Copyright (c) 2013 noamraph
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
S-Lab License 1.0
Copyright 2022 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
import torch.nn as nn
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2d(
inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
)
class BasicBlock(nn.Module):
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
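# Illustrative sketch (not part of the original file): when the stride or channel
# count changes, the caller must supply a matching 1x1 projection for the
# residual, as ResNetArcFace._make_layer does below.
def _example_basic_block():
    import torch  # local import; this file only imports torch.nn at the top
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(128),
    )
    block = BasicBlock(64, 128, stride=2, downsample=downsample)
    assert block(torch.randn(2, 64, 32, 32)).shape == (2, 128, 16, 16)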
class IRBlock(nn.Module):
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
self.bn0 = nn.BatchNorm2d(inplanes)
self.conv1 = conv3x3(inplanes, inplanes)
self.bn1 = nn.BatchNorm2d(inplanes)
self.prelu = nn.PReLU()
self.conv2 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.use_se = use_se
if self.use_se:
self.se = SEBlock(planes)
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.prelu(out)
return out
class Bottleneck(nn.Module):
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SEBlock(nn.Module):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1, discarding spatial information
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.PReLU(),
nn.Linear(channel // reduction, channel),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
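# Illustrative sketch (not part of the original file): SE squeezes each channel
# to a scalar, passes it through the bottleneck MLP, and re-weights the input
# channel-wise, so the output shape is unchanged.
def _example_se_block():
    import torch  # local import; this file only imports torch.nn at the top
    se = SEBlock(64, reduction=16)
    x = torch.randn(2, 64, 8, 8)
    assert se(x).shape == x.shape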
class ResNetArcFace(nn.Module):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True):
if block == "IRBlock":
block = IRBlock
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn4 = nn.BatchNorm2d(512)
self.dropout = nn.Dropout()
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
)
self.inplanes = planes
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn4(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc5(x)
x = self.bn5(x)
return x
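# Usage sketch (illustrative, not part of the original file): the backbone expects
# 1-channel crops; with 128x128 inputs, the stride-2 maxpool plus three stride-2
# stages leave a 512x8x8 feature map, matching the 512 * 8 * 8 input of fc5:
#
#     net = ResNetArcFace(block="IRBlock", layers=(2, 2, 2, 2), use_se=True)
#     embedding = net(torch.randn(4, 1, 128, 128))  # shape (4, 512)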
"""
Modified from https://github.com/sczhou/CodeFormer
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
This version of the arch was specifically gathered from an old version of GFPGAN. If this is a problem, please contact me.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from torch import Tensor

logger = logging.getLogger(__name__)
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
(z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return (
z_q,
loss,
{
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
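# Usage sketch (illustrative, not part of the original file): nearest-neighbour
# quantization snaps each spatial latent vector to its closest codebook entry,
# with a straight-through estimator to keep gradients flowing:
#
#     vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
#     z = torch.randn(1, 256, 16, 16)
#     z_q, codebook_loss, stats = vq(z)        # z_q: same shape as z
#     tokens = stats["min_encoding_indices"]   # (256, 1) codebook indices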
class GumbelQuantizer(nn.Module):
def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
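# Usage sketch (illustrative, not part of the original file): the Gumbel variant
# replaces the hard nearest-neighbour lookup with a gumbel-softmax draw over
# projected codebook logits (hard one-hot whenever the module is in eval mode):
#
#     gq = GumbelQuantizer(codebook_size=1024, emb_dim=256, num_hiddens=256)
#     z_q, kl_loss, stats = gq(torch.randn(1, 256, 16, 16))  # z_q: (1, 256, 16, 16)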
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
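# Usage sketch (illustrative, not part of the original file): AttnBlock is plain
# single-head self-attention over all h*w spatial positions with a residual
# connection, so the output shape equals the input shape:
#
#     attn = AttnBlock(in_channels=512)
#     out = attn(torch.randn(1, 512, 16, 16))  # shape (1, 512, 16, 16)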
class Encoder(nn.Module):
def __init__(
self,
in_channels,
nf,
out_channels,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,) + tuple(ch_mult)
blocks = []
        # initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
blocks.append(AttnBlock(block_in_ch)) # type: ignore
blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
# normalise and convert to latent size
blocks.append(normalize(block_in_ch)) # type: ignore
blocks.append(
nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1) # type: ignore
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = []
# initial conv
blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class VQAutoEncoder(nn.Module):
def __init__(
self,
img_size,
nf,
ch_mult,
quantizer="nearest",
res_blocks=2,
attn_resolutions=[16],
codebook_size=1024,
emb_dim=256,
beta=0.25,
gumbel_straight_through=False,
gumbel_kl_weight=1e-8,
model_path=None,
):
super().__init__()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions,
)
if self.quantizer_type == "nearest":
self.beta = beta # 0.25
self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight,
)
self.generator = Generator(
nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim
)
        if model_path is not None:
            chkpt = torch.load(model_path, map_location="cpu")
            if "params_ema" in chkpt:
                self.load_state_dict(chkpt["params_ema"])
                logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
            elif "params" in chkpt:
                self.load_state_dict(chkpt["params"])
                logger.info(f"vqgan is loaded from: {model_path} [params]")
            else:
                raise ValueError(
                    "Unexpected checkpoint format: expected a 'params_ema' or 'params' key."
                )
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
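# Usage sketch (illustrative, not part of the original file): a full
# encode -> quantize -> decode round trip at the CodeFormer-style configuration;
# five downsampling stages shrink 512x512 inputs to a 16x16 latent grid:
#
#     vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8])
#     recon, codebook_loss, stats = vqgan(torch.randn(1, 3, 512, 512))
#     # recon: (1, 3, 512, 512)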
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
    Adjust the reference features to have color and illumination similar to
    those of the degraded features.
Args:
content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
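# Usage sketch (illustrative, not part of the original file): AdaIN whitens the
# content features channel-wise, then re-colours them with the style features'
# per-channel mean and standard deviation:
#
#     content = torch.randn(1, 256, 16, 16)
#     style = torch.randn(1, 256, 16, 16)
#     out = adaptive_instance_normalization(content, style)  # shape (1, 256, 16, 16)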
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask # pylint: disable=invalid-unary-operand-type
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
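# Usage sketch (illustrative, not part of the original file): y- and x-sinusoids
# are concatenated along the channel axis, so the output has 2 * num_pos_feats
# channels and the spatial size of the input (input channels are ignored):
#
#     pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#     pos = pe(torch.randn(1, 64, 16, 16))  # shape (1, 256, 16, 16)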
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
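# Usage sketch (illustrative, not part of the original file): a pre-norm
# self-attention layer over (sequence, batch, channel) tokens, matching the
# 16x16 = 256-token latent grid used by CodeFormer below:
#
#     layer = TransformerSALayer(embed_dim=512, nhead=8, dim_mlp=1024)
#     tokens = torch.randn(256, 1, 512)   # (hw, b, c)
#     pos = torch.randn(256, 1, 512)      # positional embedding, same shape
#     out = layer(tokens, query_pos=pos)  # shape (256, 1, 512)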
def normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script # type: ignore
def swish(x):
return x * torch.sigmoid(x)
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, self.out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(
                in_channels, self.out_channels, kernel_size=1, stride=1, padding=0
            )
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
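# Usage sketch (illustrative, not part of the original file): the SFT fusion
# predicts an affine (scale, shift) from the concatenated encoder/decoder features
# and adds it as a residual; w=0 returns dec_feat unchanged, w=1 applies the
# encoder guidance at full strength:
#
#     fuse = Fuse_sft_block(in_ch=256, out_ch=256)
#     enc = torch.randn(1, 256, 32, 32)
#     dec = torch.randn(1, 256, 32, 32)
#     out = fuse(enc, dec, w=0.5)  # shape (1, 256, 32, 32)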
class CodeFormer(VQAutoEncoder):
def __init__(self, state_dict):
dim_embd = 512
n_head = 8
n_layers = 9
codebook_size = 1024
latent_size = 256
connect_list = ["32", "64", "128", "256"]
fix_modules = ["quantize", "generator"]
# This is just a guess as I only have one model to look at
position_emb = state_dict["position_emb"]
dim_embd = position_emb.shape[1]
latent_size = position_emb.shape[0]
        try:
            n_layers = len(
                {x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x}
            )
        except Exception:
            pass
codebook_size = state_dict["quantize.embedding.weight"].shape[0]
# This is also just another guess
n_head_exp = (
state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
)
n_head = 2**n_head_exp
in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
self.model_arch = "CodeFormer"
self.sub_type = "Face SR"
self.scale = 8
self.in_nc = in_nc
self.out_nc = in_nc
self.state = state_dict
self.supports_fp16 = False
self.supports_bf16 = True
self.min_size_restriction = 16
super(CodeFormer, self).__init__(
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) # type: ignore
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(
*[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
)
self.channels = {
"16": 512,
"32": 256,
"64": 256,
"128": 128,
"256": 128,
"512": 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
self.load_state_dict(state_dict)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, weight=0.5, **kwargs):
detach_16 = True
code_only = False
adain = True
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256] # type: ignore
)
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if weight > 0:
x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, weight
)
out = x
# logits doesn't need softmax before cross_entropy loss
# return out, logits, lq_feat
return out, logits
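# Usage sketch (illustrative, not part of the original file): CodeFormer is built
# directly from a checkpoint's state dict; `weight` trades quality (0.0) against
# fidelity to the input (1.0). The checkpoint file name and its "params_ema" key
# below are assumptions, not taken from this file:
#
#     state = torch.load("codeformer.pth", map_location="cpu")["params_ema"]
#     model = CodeFormer(state).eval()
#     with torch.no_grad():
#         restored, logits = model(torch.randn(1, 3, 512, 512), weight=0.5)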
# pylint: skip-file
# type: ignore
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
import torch
from torch import nn
from torch.autograd import Function
# Placeholder for the compiled fused-activation CUDA extension; it is not bundled
# here, so the ops below fail unless the extension is provided at runtime.
fused_act_ext = None
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused_act_ext.fused_bias_act(
grad_output, empty, out, 3, 1, negative_slope, scale
)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
grad_bias = grad_input.sum(dim).detach()
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
(out,) = ctx.saved_tensors
gradgrad_out = fused_act_ext.fused_bias_act(
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
)
return gradgrad_out, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
out = fused_act_ext.fused_bias_act(
input, bias, empty, 3, 0, negative_slope, scale
)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
(out,) = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.negative_slope, ctx.scale
)
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(channel))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
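# Reference semantics (an illustrative pure-PyTorch sketch, not the compiled op):
# the fused kernel adds a per-channel bias, applies a leaky ReLU, then rescales:
#
#     import torch.nn.functional as F
#
#     def fused_leaky_relu_ref(input, bias, negative_slope=0.2, scale=2**0.5):
#         shape = [1, -1] + [1] * (input.ndim - 2)  # broadcast bias over channel dim
#         return F.leaky_relu(input + bias.view(shape), negative_slope) * scale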
# pylint: skip-file
# type: ignore
import math
import random
import torch
from torch import nn
from .gfpganv1_arch import ResUpBlock
from .stylegan2_bilinear_arch import (
ConvLayer,
EqualConv2d,
EqualLinear,
ResBlock,
ScaledLeakyReLU,
StyleGAN2GeneratorBilinear,
)
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
    It is the bilinear version, which avoids the complicated UpFirDnSmooth function that is not
    deployment-friendly. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
lr_mlp=0.01,
narrow=1,
sft_half=False,
):
super(StyleGAN2GeneratorBilinearSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
lr_mlp=lr_mlp,
narrow=narrow,
)
self.sft_half = sft_half
def forward(
self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False,
):
"""Forward function for StyleGAN2GeneratorBilinearSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs,
):
out = conv1(out, latent[:, i], noise=noise1)
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class GFPGANBilinear(nn.Module):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
    It is the bilinear version, which avoids the complicated UpFirDnSmooth function that is not
    deployment-friendly. It can be easily converted to the clean version: GFPGANv1Clean.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
lr_mlp=0.01,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False,
):
super(GFPGANBilinear, self).__init__()
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
self.min_size_restriction = 512
        unet_narrow = narrow * 0.5  # by default, use half of the input channels
channels = {
"4": int(512 * unet_narrow),
"8": int(512 * unet_narrow),
"16": int(512 * unet_narrow),
"32": int(512 * unet_narrow),
"64": int(256 * channel_multiplier * unet_narrow),
"128": int(128 * channel_multiplier * unet_narrow),
"256": int(64 * channel_multiplier * unet_narrow),
"512": int(32 * channel_multiplier * unet_narrow),
"1024": int(16 * channel_multiplier * unet_narrow),
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2 ** (int(math.log(out_size, 2)))
self.conv_body_first = ConvLayer(
3, channels[f"{first_out_size}"], 1, bias=True, activate=True
)
# downsample
in_channels = channels[f"{first_out_size}"]
self.conv_body_down = nn.ModuleList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f"{2**(i - 1)}"]
self.conv_body_down.append(ResBlock(in_channels, out_channels))
in_channels = out_channels
self.final_conv = ConvLayer(
in_channels, channels["4"], 3, bias=True, activate=True
)
# upsample
in_channels = channels["4"]
self.conv_body_up = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
in_channels = out_channels
# to RGB
self.toRGB = nn.ModuleList()
for i in range(3, self.log_size + 1):
self.toRGB.append(
EqualConv2d(
channels[f"{2**i}"],
3,
1,
stride=1,
padding=0,
bias=True,
bias_init_val=0,
)
)
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = EqualLinear(
channels["4"] * 4 * 4,
linear_out_channel,
bias=True,
bias_init_val=0,
lr_mul=1,
activation=None,
)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
lr_mlp=lr_mlp,
narrow=narrow,
sft_half=sft_half,
)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
EqualConv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
ScaledLeakyReLU(0.2),
EqualConv2d(
out_channels,
sft_out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=1,
),
)
)
self.condition_shift.append(
nn.Sequential(
EqualConv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
ScaledLeakyReLU(0.2),
EqualConv2d(
out_channels,
sft_out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
)
)
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANBilinear.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
# encoder
feat = self.conv_body_first(x)
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = self.final_conv(feat)
# style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
# generate rgb images
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
# decoder
image, _ = self.stylegan_decoder(
[style_code],
conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise,
)
return image, out_rgbs
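# Usage sketch (illustrative, not part of the original file): the U-Net encoder
# yields a style code plus per-resolution SFT (scale, shift) conditions, which the
# StyleGAN2 decoder consumes; inputs and outputs are 3-channel images at out_size:
#
#     gfpgan = GFPGANBilinear(out_size=512)
#     restored, out_rgbs = gfpgan(torch.randn(1, 3, 512, 512))  # (1, 3, 512, 512)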
# pylint: skip-file
# type: ignore
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from .fused_act import FusedLeakyReLU
from .stylegan2_arch import (
ConvLayer,
EqualConv2d,
EqualLinear,
ResBlock,
ScaledLeakyReLU,
StyleGAN2Generator,
)
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. An outer product
            is applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
resample_kernel=(1, 3, 3, 1),
lr_mlp=0.01,
narrow=1,
sft_half=False,
):
super(StyleGAN2GeneratorSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
resample_kernel=resample_kernel,
lr_mlp=lr_mlp,
narrow=narrow,
)
self.sft_half = sft_half
def forward(
self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False,
):
"""Forward function for StyleGAN2GeneratorSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs,
):
out = conv1(out, latent[:, i], noise=noise1)
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ConvUpLayer(nn.Module):
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input. Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
        activate (bool): Whether to use activation. Default: True.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
bias=True,
bias_init_val=0,
activate=True,
):
super(ConvUpLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
# self.scale is used to scale the convolution weights, which is related to the common initializations.
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(
torch.randn(out_channels, in_channels, kernel_size, kernel_size)
)
if bias and not activate:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter("bias", None)
# activation
if activate:
if bias:
self.activation = FusedLeakyReLU(out_channels)
else:
self.activation = ScaledLeakyReLU(0.2)
else:
self.activation = None
def forward(self, x):
# bilinear upsample
out = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
# conv
out = F.conv2d(
out,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
# activation
if self.activation is not None:
out = self.activation(out)
return out
class ResUpBlock(nn.Module):
"""Residual block with upsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
"""
def __init__(self, in_channels, out_channels):
super(ResUpBlock, self).__init__()
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvUpLayer(
in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True
)
self.skip = ConvUpLayer(
in_channels, out_channels, 1, bias=False, activate=False
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
skip = self.skip(x)
out = (out + skip) / math.sqrt(2)
return out
class GFPGANv1(nn.Module):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. An outer product
            is applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
channel_multiplier=1,
resample_kernel=(1, 3, 3, 1),
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
lr_mlp=0.01,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False,
):
super(GFPGANv1, self).__init__()
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
        unet_narrow = narrow * 0.5  # by default, use half of the input channels
channels = {
"4": int(512 * unet_narrow),
"8": int(512 * unet_narrow),
"16": int(512 * unet_narrow),
"32": int(512 * unet_narrow),
"64": int(256 * channel_multiplier * unet_narrow),
"128": int(128 * channel_multiplier * unet_narrow),
"256": int(64 * channel_multiplier * unet_narrow),
"512": int(32 * channel_multiplier * unet_narrow),
"1024": int(16 * channel_multiplier * unet_narrow),
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2 ** (int(math.log(out_size, 2)))
self.conv_body_first = ConvLayer(
3, channels[f"{first_out_size}"], 1, bias=True, activate=True
)
# downsample
in_channels = channels[f"{first_out_size}"]
self.conv_body_down = nn.ModuleList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f"{2**(i - 1)}"]
self.conv_body_down.append(
ResBlock(in_channels, out_channels, resample_kernel)
)
in_channels = out_channels
self.final_conv = ConvLayer(
in_channels, channels["4"], 3, bias=True, activate=True
)
# upsample
in_channels = channels["4"]
self.conv_body_up = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
in_channels = out_channels
# to RGB
self.toRGB = nn.ModuleList()
for i in range(3, self.log_size + 1):
self.toRGB.append(
EqualConv2d(
channels[f"{2**i}"],
3,
1,
stride=1,
padding=0,
bias=True,
bias_init_val=0,
)
)
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = EqualLinear(
channels["4"] * 4 * 4,
linear_out_channel,
bias=True,
bias_init_val=0,
lr_mul=1,
activation=None,
)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
resample_kernel=resample_kernel,
lr_mlp=lr_mlp,
narrow=narrow,
sft_half=sft_half,
)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
EqualConv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
ScaledLeakyReLU(0.2),
EqualConv2d(
out_channels,
sft_out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=1,
),
)
)
self.condition_shift.append(
nn.Sequential(
EqualConv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
ScaledLeakyReLU(0.2),
EqualConv2d(
out_channels,
sft_out_channels,
3,
stride=1,
padding=1,
bias=True,
bias_init_val=0,
),
)
)
def forward(
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
):
"""Forward function for GFPGANv1.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
# encoder
feat = self.conv_body_first(x)
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = self.final_conv(feat)
# style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
# generate rgb images
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
# decoder
image, _ = self.stylegan_decoder(
[style_code],
conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise,
)
return image, out_rgbs
class FacialComponentDiscriminator(nn.Module):
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""
def __init__(self):
super(FacialComponentDiscriminator, self).__init__()
        # It uses a VGG-style architecture with a fixed model size
self.conv1 = ConvLayer(
3,
64,
3,
downsample=False,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True,
)
self.conv2 = ConvLayer(
64,
128,
3,
downsample=True,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True,
)
self.conv3 = ConvLayer(
128,
128,
3,
downsample=False,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True,
)
self.conv4 = ConvLayer(
128,
256,
3,
downsample=True,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True,
)
self.conv5 = ConvLayer(
256,
256,
3,
downsample=False,
resample_kernel=(1, 3, 3, 1),
bias=True,
activate=True,
)
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False, **kwargs):
"""Forward function for FacialComponentDiscriminator.
Args:
x (Tensor): Input images.
return_feats (bool): Whether to return intermediate features. Default: False.
"""
feat = self.conv1(x)
feat = self.conv3(self.conv2(feat))
rlt_feats = []
if return_feats:
rlt_feats.append(feat.clone())
feat = self.conv5(self.conv4(feat))
if return_feats:
rlt_feats.append(feat.clone())
out = self.final_conv(feat)
if return_feats:
return out, rlt_feats
else:
return out, None
# pylint: skip-file
# type: ignore
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
narrow=1,
sft_half=False,
):
super(StyleGAN2GeneratorCSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow,
)
self.sft_half = sft_half
def forward(
self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False,
):
"""Forward function for StyleGAN2GeneratorCSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs,
):
out = conv1(out, latent[:, i], noise=noise1)
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ResBlock(nn.Module):
"""Residual block with bilinear upsampling/downsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
"""
def __init__(self, in_channels, out_channels, mode="down"):
super(ResBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
if mode == "down":
self.scale_factor = 0.5
elif mode == "up":
self.scale_factor = 2
def forward(self, x):
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
# upsample/downsample
out = F.interpolate(
out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
# skip
x = F.interpolate(
x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
skip = self.skip(x)
out = out + skip
return out
class GFPGANv1Clean(nn.Module):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
state_dict,
):
super(GFPGANv1Clean, self).__init__()
out_size = 512
num_style_feat = 512
channel_multiplier = 2
decoder_load_path = None
fix_decoder = False
num_mlp = 8
input_is_latent = True
different_w = True
narrow = 1
sft_half = True
self.model_arch = "GFPGAN"
self.sub_type = "Face SR"
self.scale = 8
self.in_nc = 3
self.out_nc = 3
self.state = state_dict
self.supports_fp16 = False
self.supports_bf16 = True
self.min_size_restriction = 512
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
        unet_narrow = narrow * 0.5  # by default, use half of the input channels
channels = {
"4": int(512 * unet_narrow),
"8": int(512 * unet_narrow),
"16": int(512 * unet_narrow),
"32": int(512 * unet_narrow),
"64": int(256 * channel_multiplier * unet_narrow),
"128": int(128 * channel_multiplier * unet_narrow),
"256": int(64 * channel_multiplier * unet_narrow),
"512": int(32 * channel_multiplier * unet_narrow),
"1024": int(16 * channel_multiplier * unet_narrow),
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2 ** (int(math.log(out_size, 2)))
self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
# downsample
in_channels = channels[f"{first_out_size}"]
self.conv_body_down = nn.ModuleList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f"{2**(i - 1)}"]
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
in_channels = out_channels
self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
# upsample
in_channels = channels["4"]
self.conv_body_up = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
in_channels = out_channels
# to RGB
self.toRGB = nn.ModuleList()
for i in range(3, self.log_size + 1):
self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow,
sft_half=sft_half,
)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2**i}"]
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
)
)
self.condition_shift.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
)
)
self.load_state_dict(state_dict)
def forward(
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
            randomize_noise (bool): Randomize noise; used when 'noise' is None. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
# encoder
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
# style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
# generate rgb images
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
# decoder
image, _ = self.stylegan_decoder(
[style_code],
conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise,
)
return image, out_rgbs
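# Usage sketch (illustrative, not part of the original file): GFPGANv1Clean is
# constructed straight from a state dict, with all architecture hyper-parameters
# fixed above. The checkpoint file name and its "params_ema" key are assumptions:
#
#     state = torch.load("GFPGANv1.4.pth", map_location="cpu")["params_ema"]
#     model = GFPGANv1Clean(state).eval()
#     with torch.no_grad():
#         restored, _ = model(torch.randn(1, 3, 512, 512))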