"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "bb121214c2daaff0f82483ee83ebc8ad636ef7ff"
Commit 9e459ea3 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .gdn import GDN
class MaskedConv2d(nn.Conv2d):
r"""Masked 2D convolution implementation, mask future "unseen" pixels.
Useful for building auto-regressive network components.
Introduced in `"Conditional Image Generation with PixelCNN Decoders"
<https://arxiv.org/abs/1606.05328>`_.
Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
first layer (which also masks the "current pixel"), `mask_type='B'` for the
following layers.
"""
def __init__(self, *args, mask_type="A", **kwargs):
super().__init__(*args, **kwargs)
if mask_type not in ("A", "B"):
raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
self.register_buffer("mask", torch.ones_like(self.weight.data))
_, _, h, w = self.mask.size()
self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0
self.mask[:, :, h // 2 + 1 :] = 0
def forward(self, x):
        # TODO(begaintj): weight assignment is not supported by torchscript
self.weight.data *= self.mask
return super().forward(x)
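# Usage sketch (illustrative, not part of the original file): a type-"A" mask
# zeroes the centre pixel and everything after it in raster-scan order, while
# type "B" keeps the centre pixel; spatial dimensions are preserved.
#
#   >>> ctx = MaskedConv2d(8, 16, kernel_size=5, padding=2, mask_type="A")
#   >>> ctx(torch.zeros(1, 8, 16, 16)).shape
#   torch.Size([1, 16, 16, 16])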
def conv3x3(in_ch, out_ch, stride=1):
"""3x3 convolution with padding."""
return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
def subpel_conv3x3(in_ch, out_ch, r=1):
"""3x3 sub-pixel convolution for up-sampling."""
return nn.Sequential(
nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
)
def conv1x1(in_ch, out_ch, stride=1):
"""1x1 convolution."""
return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
class ResidualBlockWithStride(nn.Module):
"""Residual block with a stride on the first convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
stride (int): stride value (default: 2)
"""
def __init__(self, in_ch, out_ch, stride=2):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
self.gdn = GDN(out_ch)
if stride != 1 or in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch, stride=stride)
else:
self.skip = None
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.gdn(out)
if self.skip is not None:
identity = self.skip(x)
out += identity
return out
class ResidualBlockUpsample(nn.Module):
"""Residual block with sub-pixel upsampling on the last convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
upsample (int): upsampling factor (default: 2)
"""
def __init__(self, in_ch, out_ch, upsample=2):
super().__init__()
self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv = conv3x3(out_ch, out_ch)
self.igdn = GDN(out_ch, inverse=True)
self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
def forward(self, x):
identity = x
out = self.subpel_conv(x)
out = self.leaky_relu(out)
out = self.conv(out)
out = self.igdn(out)
identity = self.upsample(x)
out += identity
return out
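# Shape note (illustrative): both the main path and the identity path apply a
# separately parameterised sub-pixel convolution, so for upsample=2 an input of
# size (N, in_ch, H, W) yields (N, out_ch, 2H, 2W) on both branches.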
class ResidualBlock(nn.Module):
"""Simple residual block with two 3x3 convolutions.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
if in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch)
else:
self.skip = None
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.leaky_relu(out)
if self.skip is not None:
identity = self.skip(x)
out = out + identity
return out
class AttentionBlock(nn.Module):
"""Self attention block.
Simplified variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Args:
        N (int): Number of channels
"""
def __init__(self, N):
super().__init__()
class ResidualUnit(nn.Module):
"""Simple residual unit."""
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
conv1x1(N, N // 2),
nn.ReLU(inplace=True),
conv3x3(N // 2, N // 2),
nn.ReLU(inplace=True),
conv1x1(N // 2, N),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
identity = x
out = self.conv(x)
out += identity
out = self.relu(out)
return out
self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
self.conv_b = nn.Sequential(
ResidualUnit(),
ResidualUnit(),
ResidualUnit(),
conv1x1(N, N),
)
def forward(self, x):
identity = x
a = self.conv_a(x)
b = self.conv_b(x)
out = a * torch.sigmoid(b)
out += identity
return out
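# Behaviour sketch (illustrative): every branch preserves the N-channel shape,
# and the output is a sigmoid-gated residual, out = conv_a(x) * sigmoid(conv_b(x)) + x,
# so the block can interpolate between the identity and attention-weighted features.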
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .priors import *
from .waseda import *
from .ours import *
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings
from compressai.layers import *
class AttModule(nn.Module):
def __init__(self, N):
super(AttModule, self).__init__()
self.forw_att = AttentionBlock(N)
self.back_att = AttentionBlock(N)
def forward(self, x, rev=False):
if not rev:
return self.forw_att(x)
else:
return self.back_att(x)
class EnhModule(nn.Module):
def __init__(self, nf):
super(EnhModule, self).__init__()
self.forw_enh = EnhBlock(nf)
self.back_enh = EnhBlock(nf)
def forward(self, x, rev=False):
if not rev:
return self.forw_enh(x)
else:
return self.back_enh(x)
class EnhBlock(nn.Module):
def __init__(self, nf):
super(EnhBlock, self).__init__()
self.layers = nn.Sequential(
DenseBlock(3, nf),
nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
DenseBlock(nf, 3)
)
def forward(self, x):
return x + self.layers(x) * 0.2
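# Note (illustrative): the 0.2 residual scaling is a common stabilisation trick
# for deep dense blocks (as in ESRGAN-style residual-in-residual networks); it
# keeps the block close to the identity early in training.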
class InvComp(nn.Module):
def __init__(self, M):
super(InvComp, self).__init__()
self.in_nc = 3
self.out_nc = M
self.operations = nn.ModuleList()
# 1st level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
# 2nd level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
# 3rd level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
# 4th level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
def forward(self, x, rev=False):
if not rev:
for op in self.operations:
x = op.forward(x, False)
b, c, h, w = x.size()
x = torch.mean(x.view(b, c//self.out_nc, self.out_nc, h, w), dim=1)
else:
times = self.in_nc // self.out_nc
x = x.repeat(1, times, 1, 1)
for op in reversed(self.operations):
x = op.forward(x, True)
return x
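# Shape arithmetic (illustrative): starting from 3 input channels, the four
# SqueezeLayer(2) stages grow the channel count 3 -> 12 -> 48 -> 192 -> 768
# while shrinking H and W by 16x. With M = 192 the forward pass averages the
# 768 channels over 768 // 192 = 4 groups; the reverse pass repeats the latent
# 4x along the channel dimension before running the flow backwards.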
class CouplingLayer(nn.Module):
    def __init__(self, split_len1, split_len2, kernel_size, clamp=1.0):
        super(CouplingLayer, self).__init__()
        self.split_len1 = split_len1
        self.split_len2 = split_len2
        self.clamp = clamp
        self.G1 = Bottleneck(self.split_len1, self.split_len2, kernel_size)
        self.G2 = Bottleneck(self.split_len2, self.split_len1, kernel_size)
        self.H1 = Bottleneck(self.split_len1, self.split_len2, kernel_size)
        self.H2 = Bottleneck(self.split_len2, self.split_len1, kernel_size)
def forward(self, x, rev=False):
x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
if not rev:
y1 = x1.mul(torch.exp( self.clamp * (torch.sigmoid(self.G2(x2)) * 2 - 1) )) + self.H2(x2)
y2 = x2.mul(torch.exp( self.clamp * (torch.sigmoid(self.G1(y1)) * 2 - 1) )) + self.H1(y1)
else:
y2 = (x2 - self.H1(x1)).div(torch.exp( self.clamp * (torch.sigmoid(self.G1(x1)) * 2 - 1) ))
y1 = (x1 - self.H2(y2)).div(torch.exp( self.clamp * (torch.sigmoid(self.G2(y2)) * 2 - 1) ))
return torch.cat((y1, y2), 1)
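# Coupling algebra (for reference): writing s_i(t) = exp(clamp * (2*sigmoid(G_i(t)) - 1)),
# the forward map is
#   y1 = x1 * s2(x2) + H2(x2)
#   y2 = x2 * s1(y1) + H1(y1)
# and the reverse branch solves these equations exactly, so the layer is
# bijective by construction (the sigmoid keeps the log-scale within [-clamp, clamp]).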
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(Bottleneck, self).__init__()
# P = ((S-1)*W-S+F)/2, with F = filter size, S = stride
padding = (kernel_size - 1) // 2
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding)
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
initialize_weights_xavier([self.conv1, self.conv2], 0.1)
initialize_weights(self.conv3, 0)
def forward(self, x):
conv1 = self.lrelu(self.conv1(x))
conv2 = self.lrelu(self.conv2(conv1))
conv3 = self.conv3(conv2)
return conv3
class SqueezeLayer(nn.Module):
def __init__(self, factor):
super().__init__()
self.factor = factor
def forward(self, input, reverse=False):
if not reverse:
output = self.squeeze2d(input, self.factor) # Squeeze in forward
return output
else:
output = self.unsqueeze2d(input, self.factor)
return output
def jacobian(self, x, rev=False):
return 0
@staticmethod
def squeeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
x = input.view(B, C, H // factor, factor, W // factor, factor)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
x = x.view(B, factor * factor * C, H // factor, W // factor)
return x
@staticmethod
def unsqueeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
factor2 = factor ** 2
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert C % (factor2) == 0, "{}".format(C)
x = input.view(B, factor, factor, C // factor2, H, W)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
x = x.view(B, C // (factor2), H * factor, W * factor)
return x
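# Shape sketch (illustrative): squeeze2d trades spatial resolution for channels
# and unsqueeze2d is its exact inverse.
#
#   >>> x = torch.rand(1, 3, 8, 8)
#   >>> SqueezeLayer.squeeze2d(x, 2).shape
#   torch.Size([1, 12, 4, 4])
#   >>> torch.equal(SqueezeLayer.unsqueeze2d(SqueezeLayer.squeeze2d(x, 2), 2), x)
#   True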
class InvertibleConv1x1(nn.Module):
def __init__(self, num_channels):
super().__init__()
w_shape = [num_channels, num_channels]
w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
self.w_shape = w_shape
def get_weight(self, input, reverse):
w_shape = self.w_shape
if not reverse:
weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
else:
weight = torch.inverse(self.weight.double()).float() \
.view(w_shape[0], w_shape[1], 1, 1)
return weight
    def forward(self, input, reverse=False):
        # get_weight returns W for the forward pass and W^-1 for the reverse
        # pass, so a single 1x1 convolution handles both directions.
        weight = self.get_weight(input, reverse)
        return F.conv2d(input, weight)
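# Note (illustrative): the weight is initialised from a QR decomposition, so it
# starts orthogonal; `get_weight` computes the exact inverse (in float64 for
# numerical stability) for the reverse pass, making the 1x1 convolution invertible.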
class DenseBlock(nn.Module):
def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
super(DenseBlock, self).__init__()
self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
if init == 'xavier':
initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
else:
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
initialize_weights(self.conv5, 0)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5
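# Connectivity sketch (illustrative): conv_k receives the input concatenated
# with all previous feature maps, so its input width is channel_in + (k-1)*gc;
# conv5 is zero-initialised, so the whole block initially outputs zeros and
# residual wrappers such as EnhBlock start as (approximately) the identity.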
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def initialize_weights_xavier(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.xavier_normal_(m.weight)
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.xavier_normal_(m.weight)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
class conv_iresnet_block_simplified(nn.Module):
    def __init__(self, in_shape, int_ch, numTraceSamples=0, numSeriesTerms=0,
                 stride=1, coeff=.97, input_nonlin=True,
                 actnorm=True, n_power_iter=5, nonlin="elu", train=False):
"""
buid invertible bottleneck block
:param in_shape: shape of the input (channels, height, width)
:param int_ch: dimension of intermediate layers
:param stride: 1 if no downsample 2 if downsample
:param coeff: desired lipschitz constant
:param input_nonlin: if true applies a nonlinearity on the input
:param actnorm: if true uses actnorm like GLOW
:param n_power_iter: number of iterations for spectral normalization
:param nonlin: the nonlinearity to use
"""
super(conv_iresnet_block_simplified, self).__init__()
assert stride in (1, 2)
self.stride = stride
self.squeeze = IRes_Squeeze(stride)
self.coeff = coeff
self.numTraceSamples = numTraceSamples
self.numSeriesTerms = numSeriesTerms
self.n_power_iter = n_power_iter
nonlin = {
"relu": nn.ReLU,
"elu": nn.ELU,
"softplus": nn.Softplus,
"sorting": lambda: MaxMinGroup(group_size=2, axis=1)
}[nonlin]
# set shapes for spectral norm conv
in_ch, h, w = in_shape
layers = []
if input_nonlin:
layers.append(nonlin())
in_ch = in_ch * stride**2
kernel_size1 = 1
layers.append(self._wrapper_spectral_norm(nn.Conv2d(in_ch, int_ch, kernel_size=kernel_size1, padding=0),
(in_ch, h, w), kernel_size1))
layers.append(nonlin())
kernel_size3 = 1
layers.append(self._wrapper_spectral_norm(nn.Conv2d(int_ch, in_ch, kernel_size=kernel_size3, padding=0),
(int_ch, h, w), kernel_size3))
self.bottleneck_block = nn.Sequential(*layers)
if actnorm:
self.actnorm = ActNorm2D(in_ch, train=train)
else:
self.actnorm = None
def forward(self, x, rev=False, ignore_logdet=False, maxIter=25):
if not rev:
""" bijective or injective block forward """
if self.stride == 2:
x = self.squeeze.forward(x)
if self.actnorm is not None:
x, an_logdet = self.actnorm(x)
else:
an_logdet = 0.0
Fx = self.bottleneck_block(x)
if (self.numTraceSamples == 0 and self.numSeriesTerms == 0) or ignore_logdet:
trace = torch.tensor(0.)
else:
trace = power_series_matrix_logarithm_trace(Fx, x, self.numSeriesTerms, self.numTraceSamples)
y = Fx + x
return y, trace + an_logdet
else:
y = x
for iter_index in range(maxIter):
summand = self.bottleneck_block(x)
x = y - summand
if self.actnorm is not None:
x = self.actnorm.inverse(x)
if self.stride == 2:
x = self.squeeze.inverse(x)
return x
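    # Note (illustrative): the reverse branch inverts y = x + F(x) via the
    # fixed-point iteration x <- y - F(x); spectral normalisation keeps the
    # Lipschitz constant of F below `coeff` < 1, so the iteration converges
    # (Banach fixed-point theorem), as in i-ResNet.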
def _wrapper_spectral_norm(self, layer, shapes, kernel_size):
if kernel_size == 1:
            # use the fully-connected spectral norm, since the bound is tight for 1x1 convolutions
return spectral_norm_fc(layer, self.coeff,
n_power_iterations=self.n_power_iter)
else:
            # use the convolutional spectral norm, since the fc bound is not tight here
return spectral_norm_conv(layer, self.coeff, shapes,
n_power_iterations=self.n_power_iter)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings
from .priors import JointAutoregressiveHierarchicalPriors
from .our_utils import *
from compressai.layers import *
from .waseda import Cheng2020Anchor
class InvCompress(Cheng2020Anchor):
def __init__(self, N=192, **kwargs):
super().__init__(N=N)
self.g_a = None
self.g_s = None
self.enh = EnhModule(64)
self.inv = InvComp(M=N)
self.attention = AttModule(N)
def g_a_func(self, x):
x = self.enh(x)
x = self.inv(x)
x = self.attention(x)
return x
def g_s_func(self, x):
        x = self.attention(x, rev=True)
x = self.inv(x, rev=True)
x = self.enh(x, rev=True)
return x
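    # Shape sketch (illustrative): for a (1, 3, 256, 256) input, g_a_func
    # produces a (1, N, 16, 16) latent (InvComp downsamples 16x), and g_s_func
    # applies the attention, flow and enhancement stages in reverse to map a
    # latent of that size back to (1, 3, 256, 256).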
def forward(self, x):
y = self.g_a_func(x)
z = self.h_a(y)
z_hat, z_likelihoods = self.entropy_bottleneck(z)
params = self.h_s(z_hat)
y_hat = self.gaussian_conditional.quantize(
y, "noise" if self.training else "dequantize"
)
ctx_params = self.context_prediction(y_hat)
gaussian_params = self.entropy_parameters(
torch.cat((params, ctx_params), dim=1)
)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
_, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
x_hat = self.g_s_func(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods}
}
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["h_a.0.weight"].size(0)
net = cls(N)
net.load_state_dict(state_dict)
return net
def compress(self, x):
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
y = self.g_a_func(x)
z = self.h_a(y)
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
y_hat = F.pad(y, (padding, padding, padding, padding))
y_strings = []
for i in range(y.size(0)):
string = self._compress_ar(
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_strings.append(string)
return {"strings": [y_strings, z_strings], "shape": z.size()[-2:], "y": y}
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 2
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
y_hat = torch.zeros(
(z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
device=z_hat.device,
)
for i, y_string in enumerate(strings[0]):
self._decompress_ar(
y_string,
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
x_hat = self.g_s_func(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def find_named_module(module, query):
"""Helper function to find a named module. Returns a `nn.Module` or `None`
Args:
module (nn.Module): the root module
query (str): the module name to find
Returns:
nn.Module or None
"""
return next((m for n, m in module.named_modules() if n == query), None)
def find_named_buffer(module, query):
"""Helper function to find a named buffer. Returns a `torch.Tensor` or `None`
Args:
module (nn.Module): the root module
query (str): the buffer name to find
Returns:
torch.Tensor or None
"""
return next((b for n, b in module.named_buffers() if n == query), None)
def _update_registered_buffer(
module,
buffer_name,
state_dict_key,
state_dict,
policy="resize_if_empty",
dtype=torch.int,
):
new_size = state_dict[state_dict_key].size()
registered_buf = find_named_buffer(module, buffer_name)
if policy in ("resize_if_empty", "resize"):
if registered_buf is None:
raise RuntimeError(f'buffer "{buffer_name}" was not registered')
if policy == "resize" or registered_buf.numel() == 0:
registered_buf.resize_(new_size)
elif policy == "register":
if registered_buf is not None:
raise RuntimeError(f'buffer "{buffer_name}" was already registered')
module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
else:
raise ValueError(f'Invalid policy "{policy}"')
def update_registered_buffers(
module,
module_name,
buffer_names,
state_dict,
policy="resize_if_empty",
dtype=torch.int,
):
"""Update the registered buffers in a module according to the tensors sized
in a state_dict.
(There's no way in torch to directly load a buffer with a dynamic size)
Args:
module (nn.Module): the module
module_name (str): module name in the state dict
buffer_names (list(str)): list of the buffer names to resize in the module
state_dict (dict): the state dict
policy (str): Update policy, choose from
('resize_if_empty', 'resize', 'register')
dtype (dtype): Type of buffer to be registered (when policy is 'register')
"""
valid_buffer_names = [n for n, _ in module.named_buffers()]
for buffer_name in buffer_names:
if buffer_name not in valid_buffer_names:
raise ValueError(f'Invalid buffer name "{buffer_name}"')
for buffer_name in buffer_names:
_update_registered_buffer(
module,
buffer_name,
f"{module_name}.{buffer_name}",
state_dict,
policy,
dtype,
)
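# Typical call site (illustrative; the buffer names below follow CompressAI's
# GaussianConditional and may differ across versions):
#
#   update_registered_buffers(
#       self.gaussian_conditional,
#       "gaussian_conditional",
#       ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
#       state_dict,
#   )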
def conv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
)
def deconv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
output_padding=stride - 1,
padding=kernel_size // 2,
)
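# Shape check (illustrative): with kernel_size=5, stride=2, padding=2 and
# output_padding=1, ConvTranspose2d maps H to (H - 1)*2 - 2*2 + (5 - 1) + 1 + 1 = 2*H,
# exactly undoing the 2x downsampling of `conv` above.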
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from compressai.layers import (
AttentionBlock,
ResidualBlock,
ResidualBlockUpsample,
ResidualBlockWithStride,
conv3x3,
subpel_conv3x3,
)
from .priors import JointAutoregressiveHierarchicalPriors
class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
"""Anchor model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
convolutions for up-sampling.
Args:
N (int): Number of channels
"""
def __init__(self, N=192, **kwargs):
super().__init__(N=N, M=N, **kwargs)
self.g_a = nn.Sequential(
ResidualBlockWithStride(3, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
conv3x3(N, N, stride=2),
)
self.h_a = nn.Sequential(
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N, stride=2),
nn.LeakyReLU(inplace=True),
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N, stride=2),
)
self.h_s = nn.Sequential(
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
subpel_conv3x3(N, N, 2),
nn.LeakyReLU(inplace=True),
conv3x3(N, N * 3 // 2),
nn.LeakyReLU(inplace=True),
subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
nn.LeakyReLU(inplace=True),
conv3x3(N * 3 // 2, N * 2),
)
self.g_s = nn.Sequential(
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
subpel_conv3x3(N, 3, 2),
)
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["g_a.0.conv1.weight"].size(0)
net = cls(N)
net.load_state_dict(state_dict)
return net
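# Shape sketch (illustrative): g_a downsamples by 16x (three strided residual
# blocks plus a final strided conv3x3) and h_a adds another 4x, so a 256x256
# input gives y of spatial size 16x16 and z of size 4x4. h_s outputs 2*N
# channels, which the autoregressive prior later splits into scales and means.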
class Cheng2020Attention(Cheng2020Anchor):
"""Self-attention model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
and sub-pixel convolutions for up-sampling.
Args:
N (int): Number of channels
"""
def __init__(self, N=192, **kwargs):
super().__init__(N=N, **kwargs)
self.g_a = nn.Sequential(
ResidualBlockWithStride(3, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
conv3x3(N, N, stride=2),
AttentionBlock(N),
)
self.g_s = nn.Sequential(
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
subpel_conv3x3(N, 3, 2),
)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bound_ops import LowerBound
from .ops import ste_round
from .parametrizers import NonNegativeParametrizer
__all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"]
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class LowerBoundFunction(torch.autograd.Function):
"""Autograd function for the `LowerBound` operator."""
@staticmethod
def forward(ctx, input_, bound):
ctx.save_for_backward(input_, bound)
return torch.max(input_, bound)
@staticmethod
def backward(ctx, grad_output):
input_, bound = ctx.saved_tensors
pass_through_if = (input_ >= bound) | (grad_output < 0)
return pass_through_if.type(grad_output.dtype) * grad_output, None
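# Gradient rule (for reference): the gradient passes through when the input is
# already above the bound, or when grad_output < 0 (gradient descent would then
# increase the input, letting clipped values move back above the bound).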
class LowerBound(nn.Module):
"""Lower bound operator, computes `torch.max(x, bound)` with a custom
gradient.
    The derivative is replaced by the identity function when `x` is above the
    `bound`, or when the gradient would move `x` towards the `bound`; otherwise
    the gradient is zero.
"""
def __init__(self, bound):
super().__init__()
self.register_buffer("bound", torch.Tensor([float(bound)]))
@torch.jit.unused
def lower_bound(self, x):
return LowerBoundFunction.apply(x, self.bound)
def forward(self, x):
if torch.jit.is_scripting():
return torch.max(x, self.bound)
return self.lower_bound(x)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def ste_round(x):
"""
Rounding with non-zero gradients. Gradients are approximated by replacing
the derivative by the identity function.
Used in `"Lossy Image Compression with Compressive Autoencoders"
<https://arxiv.org/abs/1703.00395>`_
.. note::
        Implemented with the PyTorch `detach()` reparametrization trick:
        `x_round = round(x) - x.detach() + x`
"""
return torch.round(x) - x.detach() + x
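# Example (illustrative): the value matches torch.round, but the gradient is
# the identity instead of being zero almost everywhere.
#
#   >>> x = torch.tensor([0.4, 1.6], requires_grad=True)
#   >>> ste_round(x).sum().backward()
#   >>> x.grad
#   tensor([1., 1.])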
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .bound_ops import LowerBound
class NonNegativeParametrizer(nn.Module):
"""
Non negative reparametrization.
Used for stability during training.
"""
def __init__(self, minimum=0, reparam_offset=2 ** -18):
super().__init__()
self.minimum = float(minimum)
self.reparam_offset = float(reparam_offset)
pedestal = self.reparam_offset ** 2
self.register_buffer("pedestal", torch.Tensor([pedestal]))
bound = (self.minimum + self.reparam_offset ** 2) ** 0.5
self.lower_bound = LowerBound(bound)
def init(self, x):
return torch.sqrt(torch.max(x + self.pedestal, self.pedestal))
def forward(self, x):
out = self.lower_bound(x)
out = out ** 2 - self.pedestal
return out
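# Round-trip sketch (illustrative): `init` maps a target value x to
# v = sqrt(max(x + pedestal, pedestal)), and `forward` returns
# max(v, bound)**2 - pedestal >= minimum, so the stored parameter v is
# unconstrained while the effective value stays non-negative.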
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
YCBCR_WEIGHTS = {
# Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
"ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}
def _check_input_tensor(tensor: Tensor) -> None:
if (
not isinstance(tensor, Tensor)
or not tensor.is_floating_point()
or not len(tensor.size()) in (3, 4)
or not tensor.size(-3) == 3
):
raise ValueError(
"Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
)
def rgb2ycbcr(rgb: Tensor) -> Tensor:
"""RGB to YCbCr conversion for torch Tensor.
Using ITU-R BT.709 coefficients.
Args:
rgb (torch.Tensor): 3D or 4D floating point RGB tensor
Returns:
ycbcr (torch.Tensor): converted tensor
"""
_check_input_tensor(rgb)
r, g, b = rgb.chunk(3, -3)
Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
y = Kr * r + Kg * g + Kb * b
cb = 0.5 * (b - y) / (1 - Kb) + 0.5
cr = 0.5 * (r - y) / (1 - Kr) + 0.5
ycbcr = torch.cat((y, cb, cr), dim=-3)
return ycbcr
def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
"""YCbCr to RGB conversion for torch Tensor.
Using ITU-R BT.709 coefficients.
Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor
Returns:
rgb (torch.Tensor): converted tensor
"""
_check_input_tensor(ycbcr)
y, cb, cr = ycbcr.chunk(3, -3)
Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
r = y + (2 - 2 * Kr) * (cr - 0.5)
b = y + (2 - 2 * Kb) * (cb - 0.5)
g = (y - Kr * r - Kb * b) / Kg
rgb = torch.cat((r, g, b), dim=-3)
return rgb
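# Consistency check (illustrative): the two conversions are exact inverses up
# to floating point error.
#
#   >>> x = torch.rand(1, 3, 8, 8)
#   >>> torch.allclose(ycbcr2rgb(rgb2ycbcr(x)), x, atol=1e-5)
#   True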
def yuv_444_to_420(
yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
"""Convert a 444 tensor to a 420 representation.
Args:
yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
of 3 (Nx1xHxW) tensors.
mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
``'avg_pool'``
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
"""
if mode not in ("avg_pool",):
raise ValueError(f'Invalid downsampling mode "{mode}".')
if mode == "avg_pool":
def _downsample(tensor):
return F.avg_pool2d(tensor, kernel_size=2, stride=2)
if isinstance(yuv, torch.Tensor):
y, u, v = yuv.chunk(3, 1)
else:
y, u, v = yuv
return (y, _downsample(u), _downsample(v))
def yuv_420_to_444(
yuv: Tuple[Tensor, Tensor, Tensor],
mode: str = "bilinear",
return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
"""Convert a 420 input to a 444 representation.
Args:
yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
(Nx1xHxW) format
mode (str): algorithm used for upsampling: ``'bilinear'`` |
``'nearest'`` Default ``'bilinear'``
return_tuple (bool): return input as tuple of tensors instead of a
concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
tensor (default: False)
Returns:
(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
444
"""
if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
raise ValueError("Expected a tuple of 3 torch tensors")
if mode not in ("bilinear", "nearest"):
raise ValueError(f'Invalid upsampling mode "{mode}".')
if mode in ("bilinear", "nearest"):
def _upsample(tensor):
return F.interpolate(tensor, scale_factor=2, mode=mode, align_corners=False)
y, u, v = yuv
u, v = _upsample(u), _upsample(v)
if return_tuple:
return y, u, v
return torch.cat((y, u, v), dim=1)
from . import functional as F_transforms
__all__ = [
"RGB2YCbCr",
"YCbCr2RGB",
"YUV444To420",
"YUV420To444",
]
class RGB2YCbCr:
"""Convert a RGB tensor to YCbCr.
The tensor is expected to be in the [0, 1] floating point range, with a
shape of (3xHxW) or (Nx3xHxW).
"""
def __call__(self, rgb):
"""
Args:
rgb (torch.Tensor): 3D or 4D floating point RGB tensor
Returns:
ycbcr(torch.Tensor): converted tensor
"""
return F_transforms.rgb2ycbcr(rgb)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YCbCr2RGB:
"""Convert a YCbCr tensor to RGB.
The tensor is expected to be in the [0, 1] floating point range, with a
shape of (3xHxW) or (Nx3xHxW).
"""
def __call__(self, ycbcr):
"""
Args:
            ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor
Returns:
rgb(torch.Tensor): converted tensor
"""
return F_transforms.ycbcr2rgb(ycbcr)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YUV444To420:
"""Convert a YUV 444 tensor to a 420 representation.
Args:
mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
``'avg_pool'``
Example:
>>> x = torch.rand(1, 3, 32, 32)
>>> y, u, v = YUV444To420()(x)
>>> y.size() # 1, 1, 32, 32
>>> u.size() # 1, 1, 16, 16
"""
def __init__(self, mode: str = "avg_pool"):
self.mode = str(mode)
def __call__(self, yuv):
"""
Args:
yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
a tuple of 3 (Nx1xHxW) tensors.
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
"""
return F_transforms.yuv_444_to_420(yuv, mode=self.mode)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YUV420To444:
"""Convert a YUV 420 input to a 444 representation.
Args:
mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
Default ``'bilinear'``
return_tuple (bool): return input as tuple of tensors instead of a
concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
tensor (default: False)
Example:
>>> y = torch.rand(1, 1, 32, 32)
>>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
>>> x = YUV420To444()((y, u, v))
>>> x.size() # 1, 3, 32, 32
"""
def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
self.mode = str(mode)
self.return_tuple = bool(return_tuple)
def __call__(self, yuv):
"""
Args:
yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
(Nx1xHxW) format
Returns:
(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
444
"""
        return F_transforms.yuv_420_to_444(yuv, mode=self.mode, return_tuple=self.return_tuple)
def __repr__(self):
return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collect performance metrics of published traditional or end-to-end image
codecs.
"""
import argparse
import json
import multiprocessing as mp
import os
import sys
from collections import defaultdict
from itertools import starmap
from typing import List
from .codecs import AV1, BPG, HM, JPEG, JPEG2000, TFCI, VTM, Codec, WebP
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]
# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
rv = codec.run(*args)
return i, rv
def collect(codec: Codec, dataset: str, qualities: List[int], num_jobs: int = 1):
filepaths = [
os.path.join(dataset, f)
for f in os.listdir(dataset)
if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
]
# print(filepaths)
pool = mp.Pool(num_jobs) if num_jobs > 1 else None
if len(filepaths) == 0:
print("No images found in the dataset directory")
sys.exit(1)
args = [(codec, i, f, q) for i, q in enumerate(qualities) for f in filepaths]
if pool:
rv = pool.starmap(func, args)
else:
rv = list(starmap(func, args))
results = [defaultdict(float) for _ in range(len(qualities))]
for i, metrics in rv:
for k, v in metrics.items():
results[i][k] += v
for i, _ in enumerate(results):
for k, v in results[i].items():
results[i][k] = v / len(filepaths)
# list of dict -> dict of list
out = defaultdict(list)
for r in results:
for k, v in r.items():
out[k].append(v)
return out
def setup_args():
description = "Collect codec metrics."
parser = argparse.ArgumentParser(description=description)
subparsers = parser.add_subparsers(dest="codec", help="Select codec")
subparsers.required = True
return parser, subparsers
def setup_common_args(parser):
parser.add_argument("dataset", type=str)
parser.add_argument(
"-j",
"--num-jobs",
type=int,
metavar="N",
default=1,
help="Number of parallel jobs (default: %(default)s)",
)
parser.add_argument(
"-q",
"--quality",
dest="qualities",
metavar="Q",
        default=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90],
nargs="*",
type=int,
help="quality parameter (default: %(default)s)",
)
    # alternative sweep: [3, 5, 10, 15, 17, 20, 22, 25, 27, 30, 32, 35, 37, 40, 42, 45, 47, 50]
    # newly added option
parser.add_argument(
"--name",
dest="name",
default="ans",
type=str,
help="name for json",
)
def main(argv):
import time
start = time.time()
parser, subparsers = setup_args()
for c in codecs:
cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
setup_common_args(cparser)
c.setup_args(cparser)
args = parser.parse_args(argv)
codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
codec = codec_cls(args)
results = collect(codec, args.dataset, args.qualities, args.num_jobs)
output = {
"name": codec.name,
"description": codec.description,
"results": results,
}
print(json.dumps(output, indent=2))
end = time.time()
print('total time:', end - start)
output_dir = '/home/felix/disk2/compressai_v2/codes/results/log'
output_json_path = os.path.join(output_dir, args.name+'.json')
with open(output_json_path, 'w') as f:
json.dump(output, f, indent=2)
if __name__ == "__main__":
main(sys.argv[1:])
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate an end-to-end compression model on an image dataset.
"""
import argparse
import json
import math
import os
import sys
import time
from collections import defaultdict
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_msssim import ms_ssim
from torchvision import transforms
import compressai
from compressai.zoo import models as pretrained_models
from compressai.zoo.image import model_architectures as architectures
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
def collect_images(rootpath: str) -> List[str]:
return [
os.path.join(rootpath, f)
for f in os.listdir(rootpath)
if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
]
def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
mse = F.mse_loss(a, b).item()
return -10 * math.log10(mse)
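# Note: assumes inputs scaled to [0, 1], i.e. PSNR with a peak value of 1.0.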
def read_image(filepath: str) -> torch.Tensor:
assert os.path.isfile(filepath)
img = Image.open(filepath).convert("RGB")
# test_transforms = transforms.Compose(
# [transforms.CenterCrop(256), transforms.ToTensor()]
# )
# return test_transforms(img)
return transforms.ToTensor()(img)
@torch.no_grad()
def inference(model, x, savedir = "", idx = 1):
x = x.unsqueeze(0)
h, w = x.size(2), x.size(3)
p = 64 # maximum 6 strides of 2
new_h = (h + p - 1) // p * p
new_w = (w + p - 1) // p * p
padding_left = (new_w - w) // 2
padding_right = new_w - w - padding_left
padding_top = (new_h - h) // 2
padding_bottom = new_h - h - padding_top
x_padded = F.pad(
x,
(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
start = time.time()
out_enc = model.compress(x_padded)
enc_time = time.time() - start
start = time.time()
out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
dec_time = time.time() - start
out_dec["x_hat"] = F.pad(
out_dec["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom)
)
num_pixels = x.size(0) * x.size(2) * x.size(3)
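    # bpp = (total compressed bytes * 8 bits) / number of pixels in the input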
bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
if savedir != "":
if not os.path.exists(savedir):
os.makedirs(savedir)
cur_psnr = psnr(x, out_dec["x_hat"])
cur_ssim = ms_ssim(x, out_dec["x_hat"], data_range=1.0).item()
        tran1 = transforms.ToPILImage()
        cur_img = tran1(out_dec["x_hat"][0])
        cur_img.save(
            os.path.join(
                savedir, f"{idx:02d}_{cur_psnr:.2f}_{bpp:.3f}_{cur_ssim:.3f}.png"
            )
        )
return {
"psnr": psnr(x, out_dec["x_hat"]),
"ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(),
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
@torch.no_grad()
def inference_entropy_estimation(model, x):
x = x.unsqueeze(0)
start = time.time()
out_net = model.forward(x)
# print(out_net['x_hat'][0,0,:5,:5])
elapsed_time = time.time() - start
num_pixels = x.size(0) * x.size(2) * x.size(3)
bpp = sum(
(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
for likelihoods in out_net["likelihoods"].values()
)
    return {
        "psnr": psnr(x, out_net["x_hat"]),
        "ms-ssim": ms_ssim(x, out_net["x_hat"], data_range=1.0).item(),
        "bpp": bpp.item(),
        "encoding_time": elapsed_time / 2.0,  # rough estimate: half of one forward pass
        "decoding_time": elapsed_time / 2.0,
    }
def load_pretrained(model: str, metric: str, quality: int) -> nn.Module:
return pretrained_models[model](
quality=quality, metric=metric, pretrained=True
).eval()
def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
return architectures[arch].from_state_dict(torch.load(checkpoint_path)).eval()
def eval_model(model, filepaths, entropy_estimation=False, half=False, savedir = ""):
device = next(model.parameters()).device
metrics = defaultdict(float)
for idx, f in enumerate(sorted(filepaths)):
x = read_image(f).to(device)
if not entropy_estimation:
print('evaluating index', idx)
if half:
model = model.half()
x = x.half()
rv = inference(model, x, savedir, idx)
else:
rv = inference_entropy_estimation(model, x)
print('bpp', rv['bpp'])
print('psnr', rv['psnr'])
print('ms-ssim', rv['ms-ssim'])
print()
for k, v in rv.items():
metrics[k] += v
for k, v in metrics.items():
metrics[k] = v / len(filepaths)
return metrics
def setup_args():
parent_parser = argparse.ArgumentParser(
add_help=False,
)
# Common options.
parent_parser.add_argument("dataset", type=str, help="dataset path")
parent_parser.add_argument(
"-a",
"--arch",
type=str,
choices=pretrained_models.keys(),
help="model architecture",
required=True,
)
parent_parser.add_argument(
"-c",
"--entropy-coder",
choices=compressai.available_entropy_coders(),
default=compressai.available_entropy_coders()[0],
help="entropy coder (default: %(default)s)",
)
parent_parser.add_argument(
"--cuda",
action="store_true",
help="enable CUDA",
)
parent_parser.add_argument(
"--half",
action="store_true",
help="convert model to half floating point (fp16)",
)
parent_parser.add_argument(
"--entropy-estimation",
action="store_true",
help="use evaluated entropy estimation (no entropy coding)",
)
parent_parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="verbose mode",
)
parent_parser.add_argument(
"-s",
"--savedir",
type=str,
default="",
)
parent_parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
parser = argparse.ArgumentParser(
description="Evaluate a model on an image dataset.", add_help=True
)
subparsers = parser.add_subparsers(
help="model source", dest="source")#, required=True
# )
# Options for pretrained models
pretrained_parser = subparsers.add_parser("pretrained", parents=[parent_parser])
pretrained_parser.add_argument(
"-m",
"--metric",
type=str,
choices=["mse", "ms-ssim"],
default="mse",
help="metric trained against (default: %(default)s)",
)
pretrained_parser.add_argument(
"-q",
"--quality",
dest="qualities",
nargs="+",
type=int,
default=(1,),
)
checkpoint_parser = subparsers.add_parser("checkpoint", parents=[parent_parser])
# checkpoint_parser.add_argument(
# "-p",
# "--path",
# dest="paths",
# type=str,
# nargs="*",
# required=True,
# help="checkpoint path",
# )
checkpoint_parser.add_argument("-exp", "--experiment", type=str, required=True, help="Experiment name")
return parser
def main(argv):
args = setup_args().parse_args(argv)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
filepaths = collect_images(args.dataset)
if len(filepaths) == 0:
print("No images found in directory.")
sys.exit(1)
compressai.set_entropy_coder(args.entropy_coder)
if args.source == "pretrained":
runs = sorted(args.qualities)
opts = (args.arch, args.metric)
load_func = load_pretrained
log_fmt = "\rEvaluating {0} | {run:d}"
elif args.source == "checkpoint":
# runs = args.paths
checkpoint_updated_dir = os.path.join('../experiments', args.experiment, 'checkpoint_updated')
checkpoint_updated = os.path.join(checkpoint_updated_dir, os.listdir(checkpoint_updated_dir)[0])
runs = [checkpoint_updated]
opts = (args.arch,)
load_func = load_checkpoint
log_fmt = "\rEvaluating {run:s}"
results = defaultdict(list)
for run in runs:
if args.verbose:
sys.stderr.write(log_fmt.format(*opts, run=run))
sys.stderr.flush()
model = load_func(*opts, run)
if args.cuda and torch.cuda.is_available():
model = model.to("cuda")
metrics = eval_model(model, filepaths, args.entropy_estimation, args.half, args.savedir)
for k, v in metrics.items():
results[k].append(v)
if args.verbose:
sys.stderr.write("\n")
sys.stderr.flush()
description = (
"entropy estimation" if args.entropy_estimation else args.entropy_coder
)
output = {
"name": args.arch,
"description": f"Inference ({description})",
"results": results,
}
print(json.dumps(output, indent=2))
if __name__ == "__main__":
main(sys.argv[1:])