Commit 9e459ea3 authored by jerrrrry

Initial commit
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .gdn import GDN
class MaskedConv2d(nn.Conv2d):
r"""Masked 2D convolution implementation, mask future "unseen" pixels.
Useful for building auto-regressive network components.
Introduced in `"Conditional Image Generation with PixelCNN Decoders"
<https://arxiv.org/abs/1606.05328>`_.
Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
first layer (which also masks the "current pixel"), `mask_type='B'` for the
following layers.
"""
def __init__(self, *args, mask_type="A", **kwargs):
super().__init__(*args, **kwargs)
if mask_type not in ("A", "B"):
raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
self.register_buffer("mask", torch.ones_like(self.weight.data))
_, _, h, w = self.mask.size()
self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0
self.mask[:, :, h // 2 + 1 :] = 0
def forward(self, x):
        # TODO(begaintj): weight assignment is not supported by torchscript
self.weight.data *= self.mask
return super().forward(x)
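# Illustrative sketch (not in the original source): for a 5x5 kernel the mask
# zeroes the raster-scan "future"; type "A" also zeroes the center element:
#
#   >>> m = MaskedConv2d(1, 1, kernel_size=5, padding=2, mask_type="A")
#   >>> m.mask[0, 0, 2, :]   # center row: [1, 1, 0, 0, 0] ("B" keeps the center)
#   >>> m.mask[0, 0, 3, :]   # rows below the center are all zero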
def conv3x3(in_ch, out_ch, stride=1):
"""3x3 convolution with padding."""
return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
def subpel_conv3x3(in_ch, out_ch, r=1):
"""3x3 sub-pixel convolution for up-sampling."""
return nn.Sequential(
nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
)
def conv1x1(in_ch, out_ch, stride=1):
"""1x1 convolution."""
return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
class ResidualBlockWithStride(nn.Module):
"""Residual block with a stride on the first convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
stride (int): stride value (default: 2)
"""
def __init__(self, in_ch, out_ch, stride=2):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
self.gdn = GDN(out_ch)
if stride != 1 or in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch, stride=stride)
else:
self.skip = None
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.gdn(out)
if self.skip is not None:
identity = self.skip(x)
out += identity
return out
class ResidualBlockUpsample(nn.Module):
"""Residual block with sub-pixel upsampling on the last convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
upsample (int): upsampling factor (default: 2)
"""
def __init__(self, in_ch, out_ch, upsample=2):
super().__init__()
self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv = conv3x3(out_ch, out_ch)
self.igdn = GDN(out_ch, inverse=True)
self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
def forward(self, x):
identity = x
out = self.subpel_conv(x)
out = self.leaky_relu(out)
out = self.conv(out)
out = self.igdn(out)
identity = self.upsample(x)
out += identity
return out
class ResidualBlock(nn.Module):
"""Simple residual block with two 3x3 convolutions.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
if in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch)
else:
self.skip = None
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.leaky_relu(out)
if self.skip is not None:
identity = self.skip(x)
out = out + identity
return out
class AttentionBlock(nn.Module):
"""Self attention block.
Simplified variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Args:
        N (int): Number of channels
"""
def __init__(self, N):
super().__init__()
class ResidualUnit(nn.Module):
"""Simple residual unit."""
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
conv1x1(N, N // 2),
nn.ReLU(inplace=True),
conv3x3(N // 2, N // 2),
nn.ReLU(inplace=True),
conv1x1(N // 2, N),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
identity = x
out = self.conv(x)
out += identity
out = self.relu(out)
return out
self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
self.conv_b = nn.Sequential(
ResidualUnit(),
ResidualUnit(),
ResidualUnit(),
conv1x1(N, N),
)
def forward(self, x):
identity = x
a = self.conv_a(x)
b = self.conv_b(x)
out = a * torch.sigmoid(b)
out += identity
return out
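# Note (assumption, not in the original source): conv_b ends with a 1x1 conv
# whose sigmoid acts as a per-element gate in (0, 1) on the conv_a features,
# so the block rescales its residual before adding the identity shortcut;
# shapes are preserved throughout:
#
#   >>> att = AttentionBlock(128)
#   >>> att(torch.rand(1, 128, 16, 16)).shape   # torch.Size([1, 128, 16, 16])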
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .priors import *
from .waseda import *
from .ours import *
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings
from compressai.layers import *
class AttModule(nn.Module):
def __init__(self, N):
super(AttModule, self).__init__()
self.forw_att = AttentionBlock(N)
self.back_att = AttentionBlock(N)
def forward(self, x, rev=False):
if not rev:
return self.forw_att(x)
else:
return self.back_att(x)
class EnhModule(nn.Module):
def __init__(self, nf):
super(EnhModule, self).__init__()
self.forw_enh = EnhBlock(nf)
self.back_enh = EnhBlock(nf)
def forward(self, x, rev=False):
if not rev:
return self.forw_enh(x)
else:
return self.back_enh(x)
class EnhBlock(nn.Module):
def __init__(self, nf):
super(EnhBlock, self).__init__()
self.layers = nn.Sequential(
DenseBlock(3, nf),
nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
DenseBlock(nf, 3)
)
def forward(self, x):
return x + self.layers(x) * 0.2
class InvComp(nn.Module):
def __init__(self, M):
super(InvComp, self).__init__()
self.in_nc = 3
self.out_nc = M
self.operations = nn.ModuleList()
# 1st level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
# 2nd level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 5)
self.operations.append(b)
# 3rd level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
# 4th level
b = SqueezeLayer(2)
self.operations.append(b)
self.in_nc *= 4
b = InvertibleConv1x1(self.in_nc)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
b = CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, 3)
self.operations.append(b)
def forward(self, x, rev=False):
if not rev:
for op in self.operations:
x = op.forward(x, False)
b, c, h, w = x.size()
x = torch.mean(x.view(b, c//self.out_nc, self.out_nc, h, w), dim=1)
else:
times = self.in_nc // self.out_nc
x = x.repeat(1, times, 1, 1)
for op in reversed(self.operations):
x = op.forward(x, True)
return x
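# Channel bookkeeping sketch (not in the original source): each SqueezeLayer
# multiplies channels by 4, so a 3-channel input carries 3 * 4**4 = 768
# channels at 1/16 resolution after the four levels. The forward pass averages
# groups of 768 // M channels down to the M latent channels; the reverse pass
# tiles the latent back up before undoing the operations. E.g. with M = 192:
#
#   >>> inv = InvComp(M=192)
#   >>> y = inv(torch.rand(1, 3, 64, 64))   # shape (1, 192, 4, 4)
#   >>> x = inv(y, rev=True)                # shape (1, 3, 64, 64)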
class CouplingLayer(nn.Module):
    def __init__(self, split_len1, split_len2, kernel_size, clamp=1.0):
        super(CouplingLayer, self).__init__()
        self.split_len1 = split_len1
        self.split_len2 = split_len2
        self.clamp = clamp
        self.G1 = Bottleneck(self.split_len1, self.split_len2, kernel_size)
        self.G2 = Bottleneck(self.split_len2, self.split_len1, kernel_size)
        self.H1 = Bottleneck(self.split_len1, self.split_len2, kernel_size)
        self.H2 = Bottleneck(self.split_len2, self.split_len1, kernel_size)
def forward(self, x, rev=False):
x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
        if not rev:
            y1 = x1.mul(torch.exp(self.clamp * (torch.sigmoid(self.G2(x2)) * 2 - 1))) + self.H2(x2)
            y2 = x2.mul(torch.exp(self.clamp * (torch.sigmoid(self.G1(y1)) * 2 - 1))) + self.H1(y1)
        else:
            y2 = (x2 - self.H1(x1)).div(torch.exp(self.clamp * (torch.sigmoid(self.G1(x1)) * 2 - 1)))
            y1 = (x1 - self.H2(y2)).div(torch.exp(self.clamp * (torch.sigmoid(self.G2(y2)) * 2 - 1)))
return torch.cat((y1, y2), 1)
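# Invertibility sketch (not in the original source): writing
# s(t) = exp(clamp * (2 * sigmoid(G(t)) - 1)), the forward pass computes
#   y1 = x1 * s_G2(x2) + H2(x2)
#   y2 = x2 * s_G1(y1) + H1(y1)
# so the reverse branch recovers the inputs exactly by undoing the steps in
# the opposite order:
#   x2 = (y2 - H1(y1)) / s_G1(y1),   x1 = (y1 - H2(x2)) / s_G2(x2)
# (in the rev branch above, x1 and x2 hold the coded y1 and y2).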
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(Bottleneck, self).__init__()
# P = ((S-1)*W-S+F)/2, with F = filter size, S = stride
padding = (kernel_size - 1) // 2
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding)
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
initialize_weights_xavier([self.conv1, self.conv2], 0.1)
initialize_weights(self.conv3, 0)
def forward(self, x):
conv1 = self.lrelu(self.conv1(x))
conv2 = self.lrelu(self.conv2(conv1))
conv3 = self.conv3(conv2)
return conv3
class SqueezeLayer(nn.Module):
def __init__(self, factor):
super().__init__()
self.factor = factor
def forward(self, input, reverse=False):
if not reverse:
output = self.squeeze2d(input, self.factor) # Squeeze in forward
return output
else:
output = self.unsqueeze2d(input, self.factor)
return output
def jacobian(self, x, rev=False):
return 0
@staticmethod
def squeeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
x = input.view(B, C, H // factor, factor, W // factor, factor)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
x = x.view(B, factor * factor * C, H // factor, W // factor)
return x
@staticmethod
def unsqueeze2d(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
factor2 = factor ** 2
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert C % (factor2) == 0, "{}".format(C)
x = input.view(B, factor, factor, C // factor2, H, W)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
x = x.view(B, C // (factor2), H * factor, W * factor)
return x
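# Shape sketch (not in the original source): squeeze2d trades spatial
# resolution for channels and unsqueeze2d inverts it exactly:
#
#   >>> x = torch.rand(2, 3, 8, 8)
#   >>> y = SqueezeLayer.squeeze2d(x, 2)                  # shape (2, 12, 4, 4)
#   >>> torch.equal(SqueezeLayer.unsqueeze2d(y, 2), x)    # True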
class InvertibleConv1x1(nn.Module):
def __init__(self, num_channels):
super().__init__()
w_shape = [num_channels, num_channels]
w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
self.w_shape = w_shape
def get_weight(self, input, reverse):
w_shape = self.w_shape
if not reverse:
weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
else:
weight = torch.inverse(self.weight.double()).float() \
.view(w_shape[0], w_shape[1], 1, 1)
return weight
    def forward(self, input, reverse=False):
        weight = self.get_weight(input, reverse)
        # get_weight already returns the inverse matrix when reverse=True,
        # so both directions reduce to the same 1x1 convolution
        return F.conv2d(input, weight)
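# Round-trip sketch (not in the original source): the weight starts as a
# random orthogonal matrix (QR of a Gaussian) and the reverse path convolves
# with its explicit inverse, so forward followed by reverse is numerically
# the identity:
#
#   >>> iconv = InvertibleConv1x1(48)
#   >>> x = torch.rand(1, 48, 8, 8)
#   >>> torch.allclose(iconv(iconv(x), reverse=True), x, atol=1e-5)   # True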
class DenseBlock(nn.Module):
def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
super(DenseBlock, self).__init__()
self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
if init == 'xavier':
initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
else:
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
initialize_weights(self.conv5, 0)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def initialize_weights_xavier(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.xavier_normal_(m.weight)
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.xavier_normal_(m.weight)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
class conv_iresnet_block_simplified(nn.Module):
    def __init__(self, in_shape, int_ch, numTraceSamples=0, numSeriesTerms=0,
                 stride=1, coeff=.97, input_nonlin=True,
                 actnorm=True, n_power_iter=5, nonlin="elu", train=False):
"""
        build invertible bottleneck block
:param in_shape: shape of the input (channels, height, width)
:param int_ch: dimension of intermediate layers
:param stride: 1 if no downsample 2 if downsample
:param coeff: desired lipschitz constant
:param input_nonlin: if true applies a nonlinearity on the input
:param actnorm: if true uses actnorm like GLOW
:param n_power_iter: number of iterations for spectral normalization
:param nonlin: the nonlinearity to use
"""
super(conv_iresnet_block_simplified, self).__init__()
assert stride in (1, 2)
self.stride = stride
self.squeeze = IRes_Squeeze(stride)
self.coeff = coeff
self.numTraceSamples = numTraceSamples
self.numSeriesTerms = numSeriesTerms
self.n_power_iter = n_power_iter
nonlin = {
"relu": nn.ReLU,
"elu": nn.ELU,
"softplus": nn.Softplus,
"sorting": lambda: MaxMinGroup(group_size=2, axis=1)
}[nonlin]
# set shapes for spectral norm conv
in_ch, h, w = in_shape
layers = []
if input_nonlin:
layers.append(nonlin())
in_ch = in_ch * stride**2
kernel_size1 = 1
layers.append(self._wrapper_spectral_norm(nn.Conv2d(in_ch, int_ch, kernel_size=kernel_size1, padding=0),
(in_ch, h, w), kernel_size1))
layers.append(nonlin())
kernel_size3 = 1
layers.append(self._wrapper_spectral_norm(nn.Conv2d(int_ch, in_ch, kernel_size=kernel_size3, padding=0),
(int_ch, h, w), kernel_size3))
self.bottleneck_block = nn.Sequential(*layers)
if actnorm:
self.actnorm = ActNorm2D(in_ch, train=train)
else:
self.actnorm = None
def forward(self, x, rev=False, ignore_logdet=False, maxIter=25):
if not rev:
""" bijective or injective block forward """
if self.stride == 2:
x = self.squeeze.forward(x)
if self.actnorm is not None:
x, an_logdet = self.actnorm(x)
else:
an_logdet = 0.0
Fx = self.bottleneck_block(x)
if (self.numTraceSamples == 0 and self.numSeriesTerms == 0) or ignore_logdet:
trace = torch.tensor(0.)
else:
trace = power_series_matrix_logarithm_trace(Fx, x, self.numSeriesTerms, self.numTraceSamples)
y = Fx + x
return y, trace + an_logdet
else:
y = x
for iter_index in range(maxIter):
summand = self.bottleneck_block(x)
x = y - summand
if self.actnorm is not None:
x = self.actnorm.inverse(x)
if self.stride == 2:
x = self.squeeze.inverse(x)
return x
def _wrapper_spectral_norm(self, layer, shapes, kernel_size):
if kernel_size == 1:
            # use spectral norm fc, because the bound is tight for 1x1 convolutions
return spectral_norm_fc(layer, self.coeff,
n_power_iterations=self.n_power_iter)
else:
            # use spectral norm based on conv, because the bound is not tight
return spectral_norm_conv(layer, self.coeff, shapes,
n_power_iterations=self.n_power_iter)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings
from .priors import JointAutoregressiveHierarchicalPriors
from .our_utils import *
from compressai.layers import *
from .waseda import Cheng2020Anchor
class InvCompress(Cheng2020Anchor):
def __init__(self, N=192, **kwargs):
super().__init__(N=N)
self.g_a = None
self.g_s = None
self.enh = EnhModule(64)
self.inv = InvComp(M=N)
self.attention = AttModule(N)
def g_a_func(self, x):
x = self.enh(x)
x = self.inv(x)
x = self.attention(x)
return x
def g_s_func(self, x):
        x = self.attention(x, rev=True)
x = self.inv(x, rev=True)
x = self.enh(x, rev=True)
return x
def forward(self, x):
y = self.g_a_func(x)
z = self.h_a(y)
z_hat, z_likelihoods = self.entropy_bottleneck(z)
params = self.h_s(z_hat)
y_hat = self.gaussian_conditional.quantize(
y, "noise" if self.training else "dequantize"
)
ctx_params = self.context_prediction(y_hat)
gaussian_params = self.entropy_parameters(
torch.cat((params, ctx_params), dim=1)
)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
_, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
x_hat = self.g_s_func(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods}
}
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["h_a.0.weight"].size(0)
net = cls(N)
net.load_state_dict(state_dict)
return net
def compress(self, x):
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
y = self.g_a_func(x)
z = self.h_a(y)
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
y_hat = F.pad(y, (padding, padding, padding, padding))
y_strings = []
for i in range(y.size(0)):
string = self._compress_ar(
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_strings.append(string)
return {"strings": [y_strings, z_strings], "shape": z.size()[-2:], "y": y}
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 2
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
y_hat = torch.zeros(
(z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
device=z_hat.device,
)
for i, y_string in enumerate(strings[0]):
self._decompress_ar(
y_string,
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
x_hat = self.g_s_func(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
# pylint: disable=E0611,E0401
from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import GDN, MaskedConv2d
from .utils import conv, deconv, update_registered_buffers
# pylint: enable=E0611,E0401
__all__ = [
"CompressionModel",
"FactorizedPrior",
"ScaleHyperprior",
"MeanScaleHyperprior",
"JointAutoregressiveHierarchicalPriors",
]
class CompressionModel(nn.Module):
"""Base class for constructing an auto-encoder with at least one entropy
bottleneck module.
Args:
entropy_bottleneck_channels (int): Number of channels of the entropy
bottleneck
"""
def __init__(self, entropy_bottleneck_channels, init_weights=True):
super().__init__()
self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)
if init_weights:
self._initialize_weights()
def aux_loss(self):
"""Return the aggregated loss over the auxiliary entropy bottleneck
module(s).
"""
aux_loss = sum(
m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck)
)
return aux_loss
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, *args):
raise NotImplementedError()
def update(self, force=False):
"""Updates the entropy bottleneck(s) CDF values.
Needs to be called once after training to be able to later perform the
evaluation with an actual entropy coder.
Args:
force (bool): overwrite previous values (default: False)
Returns:
updated (bool): True if one of the EntropyBottlenecks was updated.
"""
updated = False
for m in self.children():
if not isinstance(m, EntropyBottleneck):
continue
rv = m.update(force=force)
updated |= rv
return updated
def load_state_dict(self, state_dict):
# Dynamically update the entropy bottleneck buffers related to the CDFs
update_registered_buffers(
self.entropy_bottleneck,
"entropy_bottleneck",
["_quantized_cdf", "_offset", "_cdf_length"],
state_dict,
)
super().load_state_dict(state_dict)
class FactorizedPrior(CompressionModel):
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
<https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
(ICLR), 2018.
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
def __init__(self, N, M, **kwargs):
super().__init__(entropy_bottleneck_channels=M, **kwargs)
self.g_a = nn.Sequential(
conv(3, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, M),
)
self.g_s = nn.Sequential(
deconv(M, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, 3),
)
self.N = N
self.M = M
@property
def downsampling_factor(self) -> int:
return 2 ** 4
def forward(self, x):
y = self.g_a(x)
y_hat, y_likelihoods = self.entropy_bottleneck(y)
x_hat = self.g_s(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {
"y": y_likelihoods,
},
}
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["g_a.0.weight"].size(0)
M = state_dict["g_a.6.weight"].size(0)
net = cls(N, M)
net.load_state_dict(state_dict)
return net
def compress(self, x):
y = self.g_a(x)
y_strings = self.entropy_bottleneck.compress(y)
return {"strings": [y_strings], "shape": y.size()[-2:]}
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 1
y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
x_hat = self.g_s(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64
def get_scale_table(
min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS
): # pylint: disable=W0622
return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
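# Illustrative note (not in the original source): the table holds
# SCALES_LEVELS = 64 scales spaced uniformly in log-space between
# SCALES_MIN = 0.11 and SCALES_MAX = 256:
#
#   >>> t = get_scale_table()
#   >>> len(t), float(t[0]), float(t[-1])   # (64, ~0.11, ~256.0)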
class ScaleHyperprior(CompressionModel):
r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
<https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
(ICLR), 2018.
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
def __init__(self, N, M, **kwargs):
super().__init__(entropy_bottleneck_channels=N, **kwargs)
self.g_a = nn.Sequential(
conv(3, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, M),
)
self.g_s = nn.Sequential(
deconv(M, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, 3),
)
self.h_a = nn.Sequential(
conv(M, N, stride=1, kernel_size=3),
nn.ReLU(inplace=True),
conv(N, N),
nn.ReLU(inplace=True),
conv(N, N),
)
self.h_s = nn.Sequential(
deconv(N, N),
nn.ReLU(inplace=True),
deconv(N, N),
nn.ReLU(inplace=True),
conv(N, M, stride=1, kernel_size=3),
nn.ReLU(inplace=True),
)
self.gaussian_conditional = GaussianConditional(None)
self.N = int(N)
self.M = int(M)
@property
def downsampling_factor(self) -> int:
return 2 ** (4 + 2)
def forward(self, x):
y = self.g_a(x)
z = self.h_a(torch.abs(y))
z_hat, z_likelihoods = self.entropy_bottleneck(z)
scales_hat = self.h_s(z_hat)
y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
x_hat = self.g_s(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
}
def load_state_dict(self, state_dict):
update_registered_buffers(
self.gaussian_conditional,
"gaussian_conditional",
["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
state_dict,
)
super().load_state_dict(state_dict)
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["g_a.0.weight"].size(0)
M = state_dict["g_a.6.weight"].size(0)
net = cls(N, M)
net.load_state_dict(state_dict)
return net
def update(self, scale_table=None, force=False):
if scale_table is None:
scale_table = get_scale_table()
updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
updated |= super().update(force=force)
return updated
def compress(self, x):
y = self.g_a(x)
z = self.h_a(torch.abs(y))
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
scales_hat = self.h_s(z_hat)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_strings = self.gaussian_conditional.compress(y, indexes)
return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 2
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
scales_hat = self.h_s(z_hat)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
x_hat = self.g_s(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
class MeanScaleHyperprior(ScaleHyperprior):
r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D.
Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
def __init__(self, N, M, **kwargs):
super().__init__(N, M, **kwargs)
self.h_a = nn.Sequential(
conv(M, N, stride=1, kernel_size=3),
nn.LeakyReLU(inplace=True),
conv(N, N),
nn.LeakyReLU(inplace=True),
conv(N, N),
)
self.h_s = nn.Sequential(
deconv(N, M),
nn.LeakyReLU(inplace=True),
deconv(M, M * 3 // 2),
nn.LeakyReLU(inplace=True),
conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
)
def forward(self, x):
y = self.g_a(x)
z = self.h_a(y)
z_hat, z_likelihoods = self.entropy_bottleneck(z)
gaussian_params = self.h_s(z_hat)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
x_hat = self.g_s(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
}
def compress(self, x):
y = self.g_a(x)
z = self.h_a(y)
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
gaussian_params = self.h_s(z_hat)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat)
return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 2
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
gaussian_params = self.h_s(z_hat)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_hat = self.gaussian_conditional.decompress(
strings[0], indexes, means=means_hat
)
x_hat = self.g_s(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
class JointAutoregressiveHierarchicalPriors(MeanScaleHyperprior):
r"""Joint Autoregressive Hierarchical Priors model from D.
Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
def __init__(self, N=192, M=192, **kwargs):
super().__init__(N=N, M=M, **kwargs)
self.g_a = nn.Sequential(
conv(3, N, kernel_size=5, stride=2),
GDN(N),
conv(N, N, kernel_size=5, stride=2),
GDN(N),
conv(N, N, kernel_size=5, stride=2),
GDN(N),
conv(N, M, kernel_size=5, stride=2),
)
self.g_s = nn.Sequential(
deconv(M, N, kernel_size=5, stride=2),
GDN(N, inverse=True),
deconv(N, N, kernel_size=5, stride=2),
GDN(N, inverse=True),
deconv(N, N, kernel_size=5, stride=2),
GDN(N, inverse=True),
deconv(N, 3, kernel_size=5, stride=2),
)
self.h_a = nn.Sequential(
conv(M, N, stride=1, kernel_size=3),
nn.LeakyReLU(inplace=True),
conv(N, N, stride=2, kernel_size=5),
nn.LeakyReLU(inplace=True),
conv(N, N, stride=2, kernel_size=5),
)
self.h_s = nn.Sequential(
deconv(N, M, stride=2, kernel_size=5),
nn.LeakyReLU(inplace=True),
deconv(M, M * 3 // 2, stride=2, kernel_size=5),
nn.LeakyReLU(inplace=True),
conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
)
self.entropy_parameters = nn.Sequential(
nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
)
self.context_prediction = MaskedConv2d(
M, 2 * M, kernel_size=5, padding=2, stride=1
)
self.gaussian_conditional = GaussianConditional(None)
self.N = int(N)
self.M = int(M)
@property
def downsampling_factor(self) -> int:
return 2 ** (4 + 2)
def forward(self, x):
y = self.g_a(x)
z = self.h_a(y)
z_hat, z_likelihoods = self.entropy_bottleneck(z)
params = self.h_s(z_hat)
y_hat = self.gaussian_conditional.quantize(
y, "noise" if self.training else "dequantize"
)
ctx_params = self.context_prediction(y_hat)
gaussian_params = self.entropy_parameters(
torch.cat((params, ctx_params), dim=1)
)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
_, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
x_hat = self.g_s(y_hat)
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
}
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["g_a.0.weight"].size(0)
M = state_dict["g_a.6.weight"].size(0)
net = cls(N, M)
net.load_state_dict(state_dict)
return net
def compress(self, x):
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
y = self.g_a(x)
z = self.h_a(y)
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
y_hat = F.pad(y, (padding, padding, padding, padding))
y_strings = []
for i in range(y.size(0)):
string = self._compress_ar(
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_strings.append(string)
return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
def _compress_ar(self, y_hat, params, height, width, kernel_size, padding):
cdf = self.gaussian_conditional.quantized_cdf.tolist()
cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
offsets = self.gaussian_conditional.offset.tolist()
encoder = BufferedRansEncoder()
symbols_list = []
indexes_list = []
# Warning, this is slow...
# TODO: profile the calls to the bindings...
masked_weight = self.context_prediction.weight * self.context_prediction.mask
for h in range(height):
for w in range(width):
y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
ctx_p = F.conv2d(
y_crop,
masked_weight,
bias=self.context_prediction.bias,
)
# 1x1 conv for the entropy parameters prediction network, so
# we only keep the elements in the "center"
p = params[:, :, h : h + 1, w : w + 1]
gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
gaussian_params = gaussian_params.squeeze(3).squeeze(2)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_crop = y_crop[:, :, padding, padding]
y_q = self.gaussian_conditional.quantize(y_crop, "symbols", means_hat)
y_hat[:, :, h + padding, w + padding] = y_q + means_hat
symbols_list.extend(y_q.squeeze().tolist())
indexes_list.extend(indexes.squeeze().tolist())
encoder.encode_with_indexes(
symbols_list, indexes_list, cdf, cdf_lengths, offsets
)
string = encoder.flush()
return string
def decompress(self, strings, shape):
assert isinstance(strings, list) and len(strings) == 2
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU)."
)
# FIXME: we don't respect the default entropy coder and directly call the
# range ANS decoder
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
params = self.h_s(z_hat)
s = 4 # scaling factor between z and y
kernel_size = 5 # context prediction kernel size
padding = (kernel_size - 1) // 2
y_height = z_hat.size(2) * s
y_width = z_hat.size(3) * s
# initialize y_hat to zeros, and pad it so we can directly work with
# sub-tensors of size (N, C, kernel size, kernel_size)
y_hat = torch.zeros(
(z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
device=z_hat.device,
)
for i, y_string in enumerate(strings[0]):
self._decompress_ar(
y_string,
y_hat[i : i + 1],
params[i : i + 1],
y_height,
y_width,
kernel_size,
padding,
)
y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
x_hat = self.g_s(y_hat).clamp_(0, 1)
return {"x_hat": x_hat}
def _decompress_ar(
self, y_string, y_hat, params, height, width, kernel_size, padding
):
cdf = self.gaussian_conditional.quantized_cdf.tolist()
cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
offsets = self.gaussian_conditional.offset.tolist()
decoder = RansDecoder()
decoder.set_stream(y_string)
# Warning: this is slow due to the auto-regressive nature of the
# decoding... See more recent publication where they use an
# auto-regressive module on chunks of channels for faster decoding...
for h in range(height):
for w in range(width):
# only perform the 5x5 convolution on a cropped tensor
# centered in (h, w)
y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
ctx_p = F.conv2d(
y_crop,
self.context_prediction.weight,
bias=self.context_prediction.bias,
)
# 1x1 conv for the entropy parameters prediction network, so
# we only keep the elements in the "center"
p = params[:, :, h : h + 1, w : w + 1]
gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
scales_hat, means_hat = gaussian_params.chunk(2, 1)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
rv = decoder.decode_stream(
indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
)
rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
rv = self.gaussian_conditional.dequantize(rv, means_hat)
hp = h + padding
wp = w + padding
y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def find_named_module(module, query):
"""Helper function to find a named module. Returns a `nn.Module` or `None`
Args:
module (nn.Module): the root module
query (str): the module name to find
Returns:
nn.Module or None
"""
return next((m for n, m in module.named_modules() if n == query), None)
def find_named_buffer(module, query):
"""Helper function to find a named buffer. Returns a `torch.Tensor` or `None`
Args:
module (nn.Module): the root module
query (str): the buffer name to find
Returns:
torch.Tensor or None
"""
return next((b for n, b in module.named_buffers() if n == query), None)
def _update_registered_buffer(
module,
buffer_name,
state_dict_key,
state_dict,
policy="resize_if_empty",
dtype=torch.int,
):
new_size = state_dict[state_dict_key].size()
registered_buf = find_named_buffer(module, buffer_name)
if policy in ("resize_if_empty", "resize"):
if registered_buf is None:
raise RuntimeError(f'buffer "{buffer_name}" was not registered')
if policy == "resize" or registered_buf.numel() == 0:
registered_buf.resize_(new_size)
elif policy == "register":
if registered_buf is not None:
raise RuntimeError(f'buffer "{buffer_name}" was already registered')
module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
else:
raise ValueError(f'Invalid policy "{policy}"')
def update_registered_buffers(
module,
module_name,
buffer_names,
state_dict,
policy="resize_if_empty",
dtype=torch.int,
):
"""Update the registered buffers in a module according to the tensors sized
in a state_dict.
(There's no way in torch to directly load a buffer with a dynamic size)
Args:
module (nn.Module): the module
module_name (str): module name in the state dict
buffer_names (list(str)): list of the buffer names to resize in the module
state_dict (dict): the state dict
policy (str): Update policy, choose from
('resize_if_empty', 'resize', 'register')
dtype (dtype): Type of buffer to be registered (when policy is 'register')
"""
valid_buffer_names = [n for n, _ in module.named_buffers()]
for buffer_name in buffer_names:
if buffer_name not in valid_buffer_names:
raise ValueError(f'Invalid buffer name "{buffer_name}"')
for buffer_name in buffer_names:
_update_registered_buffer(
module,
buffer_name,
f"{module_name}.{buffer_name}",
state_dict,
policy,
dtype,
)
def conv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
)
def deconv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
output_padding=stride - 1,
padding=kernel_size // 2,
)
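# Shape sketch (not in the original source): with the default kernel_size=5
# and stride=2, `conv` halves the spatial dimensions and `deconv` restores
# them (output_padding=stride-1 makes even sizes round-trip exactly):
#
#   >>> x = torch.rand(1, 16, 32, 32)
#   >>> conv(16, 32)(x).shape                   # torch.Size([1, 32, 16, 16])
#   >>> deconv(32, 16)(conv(16, 32)(x)).shape   # torch.Size([1, 16, 32, 32])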
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from compressai.layers import (
AttentionBlock,
ResidualBlock,
ResidualBlockUpsample,
ResidualBlockWithStride,
conv3x3,
subpel_conv3x3,
)
from .priors import JointAutoregressiveHierarchicalPriors
class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
"""Anchor model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
convolutions for up-sampling.
Args:
N (int): Number of channels
"""
def __init__(self, N=192, **kwargs):
super().__init__(N=N, M=N, **kwargs)
self.g_a = nn.Sequential(
ResidualBlockWithStride(3, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
conv3x3(N, N, stride=2),
)
self.h_a = nn.Sequential(
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N, stride=2),
nn.LeakyReLU(inplace=True),
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
conv3x3(N, N, stride=2),
)
self.h_s = nn.Sequential(
conv3x3(N, N),
nn.LeakyReLU(inplace=True),
subpel_conv3x3(N, N, 2),
nn.LeakyReLU(inplace=True),
conv3x3(N, N * 3 // 2),
nn.LeakyReLU(inplace=True),
subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
nn.LeakyReLU(inplace=True),
conv3x3(N * 3 // 2, N * 2),
)
self.g_s = nn.Sequential(
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
subpel_conv3x3(N, 3, 2),
)
@classmethod
def from_state_dict(cls, state_dict):
"""Return a new model instance from `state_dict`."""
N = state_dict["g_a.0.conv1.weight"].size(0)
net = cls(N)
net.load_state_dict(state_dict)
return net
class Cheng2020Attention(Cheng2020Anchor):
"""Self-attention model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
and sub-pixel convolutions for up-sampling.
Args:
N (int): Number of channels
"""
def __init__(self, N=192, **kwargs):
super().__init__(N=N, **kwargs)
self.g_a = nn.Sequential(
ResidualBlockWithStride(3, N, stride=2),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockWithStride(N, N, stride=2),
ResidualBlock(N, N),
conv3x3(N, N, stride=2),
AttentionBlock(N),
)
self.g_s = nn.Sequential(
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
AttentionBlock(N),
ResidualBlock(N, N),
ResidualBlockUpsample(N, N, 2),
ResidualBlock(N, N),
subpel_conv3x3(N, 3, 2),
)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bound_ops import LowerBound
from .ops import ste_round
from .parametrizers import NonNegativeParametrizer
__all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"]
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class LowerBoundFunction(torch.autograd.Function):
"""Autograd function for the `LowerBound` operator."""
@staticmethod
def forward(ctx, input_, bound):
ctx.save_for_backward(input_, bound)
return torch.max(input_, bound)
@staticmethod
def backward(ctx, grad_output):
input_, bound = ctx.saved_tensors
pass_through_if = (input_ >= bound) | (grad_output < 0)
return pass_through_if.type(grad_output.dtype) * grad_output, None
class LowerBound(nn.Module):
"""Lower bound operator, computes `torch.max(x, bound)` with a custom
gradient.
The derivative is replaced by the identity function when `x` is moved
towards the `bound`, otherwise the gradient is kept to zero.
"""
def __init__(self, bound):
super().__init__()
self.register_buffer("bound", torch.Tensor([float(bound)]))
@torch.jit.unused
def lower_bound(self, x):
return LowerBoundFunction.apply(x, self.bound)
def forward(self, x):
if torch.jit.is_scripting():
return torch.max(x, self.bound)
return self.lower_bound(x)
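# Gradient sketch (not in the original source): entries above the bound get
# the usual gradient; clamped entries only receive gradient that would move
# them back up toward the bound (grad_output < 0):
#
#   >>> lb = LowerBound(1.0)
#   >>> x = torch.tensor([0.5, 2.0], requires_grad=True)
#   >>> lb(x).sum().backward()
#   >>> x.grad   # tensor([0., 1.]): the clamped entry's gradient is blocked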
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def ste_round(x):
"""
Rounding with non-zero gradients. Gradients are approximated by replacing
the derivative by the identity function.
Used in `"Lossy Image Compression with Compressive Autoencoders"
<https://arxiv.org/abs/1703.00395>`_
.. note::
Implemented with the pytorch `detach()` reparametrization trick:
`x_round = x_round - x.detach() + x`
"""
return torch.round(x) - x.detach() + x
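# Gradient sketch (not in the original source): the forward value equals
# torch.round(x), but the detach trick routes gradients straight through:
#
#   >>> x = torch.tensor([0.4, 1.7], requires_grad=True)
#   >>> ste_round(x).sum().backward()
#   >>> x.grad   # tensor([1., 1.]), although round() itself has zero gradient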
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .bound_ops import LowerBound
class NonNegativeParametrizer(nn.Module):
"""
Non negative reparametrization.
Used for stability during training.
"""
def __init__(self, minimum=0, reparam_offset=2 ** -18):
super().__init__()
self.minimum = float(minimum)
self.reparam_offset = float(reparam_offset)
pedestal = self.reparam_offset ** 2
self.register_buffer("pedestal", torch.Tensor([pedestal]))
bound = (self.minimum + self.reparam_offset ** 2) ** 0.5
self.lower_bound = LowerBound(bound)
def init(self, x):
return torch.sqrt(torch.max(x + self.pedestal, self.pedestal))
def forward(self, x):
out = self.lower_bound(x)
out = out ** 2 - self.pedestal
return out
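# Round-trip sketch (not in the original source): `init` maps a target value
# v >= 0 to sqrt(v + pedestal), so forward() recovers approximately v while
# the stored parameter stays bounded away from zero:
#
#   >>> p = NonNegativeParametrizer()
#   >>> p(p.init(torch.tensor([0.1])))   # ~tensor([0.1000])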
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
YCBCR_WEIGHTS = {
# Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
"ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}
def _check_input_tensor(tensor: Tensor) -> None:
if (
not isinstance(tensor, Tensor)
or not tensor.is_floating_point()
or not len(tensor.size()) in (3, 4)
or not tensor.size(-3) == 3
):
raise ValueError(
"Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
)
def rgb2ycbcr(rgb: Tensor) -> Tensor:
"""RGB to YCbCr conversion for torch Tensor.
Using ITU-R BT.709 coefficients.
Args:
rgb (torch.Tensor): 3D or 4D floating point RGB tensor
Returns:
ycbcr (torch.Tensor): converted tensor
"""
_check_input_tensor(rgb)
r, g, b = rgb.chunk(3, -3)
Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
y = Kr * r + Kg * g + Kb * b
cb = 0.5 * (b - y) / (1 - Kb) + 0.5
cr = 0.5 * (r - y) / (1 - Kr) + 0.5
ycbcr = torch.cat((y, cb, cr), dim=-3)
return ycbcr
def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
"""YCbCr to RGB conversion for torch Tensor.
Using ITU-R BT.709 coefficients.
Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor
Returns:
rgb (torch.Tensor): converted tensor
"""
_check_input_tensor(ycbcr)
y, cb, cr = ycbcr.chunk(3, -3)
Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
r = y + (2 - 2 * Kr) * (cr - 0.5)
b = y + (2 - 2 * Kb) * (cb - 0.5)
g = (y - Kr * r - Kb * b) / Kg
rgb = torch.cat((r, g, b), dim=-3)
return rgb
def yuv_444_to_420(
yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
"""Convert a 444 tensor to a 420 representation.
Args:
yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
of 3 (Nx1xHxW) tensors.
mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
``'avg_pool'``
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
"""
if mode not in ("avg_pool",):
raise ValueError(f'Invalid downsampling mode "{mode}".')
if mode == "avg_pool":
def _downsample(tensor):
return F.avg_pool2d(tensor, kernel_size=2, stride=2)
if isinstance(yuv, torch.Tensor):
y, u, v = yuv.chunk(3, 1)
else:
y, u, v = yuv
return (y, _downsample(u), _downsample(v))
def yuv_420_to_444(
yuv: Tuple[Tensor, Tensor, Tensor],
mode: str = "bilinear",
return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
"""Convert a 420 input to a 444 representation.
Args:
yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
(Nx1xHxW) format
mode (str): algorithm used for upsampling: ``'bilinear'`` |
``'nearest'`` Default ``'bilinear'``
return_tuple (bool): return input as tuple of tensors instead of a
concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
tensor (default: False)
Returns:
(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
444
"""
if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
raise ValueError("Expected a tuple of 3 torch tensors")
if mode not in ("bilinear", "nearest"):
raise ValueError(f'Invalid upsampling mode "{mode}".')
if mode in ("bilinear", "nearest"):
def _upsample(tensor):
return F.interpolate(tensor, scale_factor=2, mode=mode, align_corners=False)
y, u, v = yuv
u, v = _upsample(u), _upsample(v)
if return_tuple:
return y, u, v
return torch.cat((y, u, v), dim=1)
from . import functional as F_transforms
__all__ = [
"RGB2YCbCr",
"YCbCr2RGB",
"YUV444To420",
"YUV420To444",
]
class RGB2YCbCr:
"""Convert a RGB tensor to YCbCr.
The tensor is expected to be in the [0, 1] floating point range, with a
shape of (3xHxW) or (Nx3xHxW).
"""
def __call__(self, rgb):
"""
Args:
rgb (torch.Tensor): 3D or 4D floating point RGB tensor
Returns:
ycbcr(torch.Tensor): converted tensor
"""
return F_transforms.rgb2ycbcr(rgb)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YCbCr2RGB:
"""Convert a YCbCr tensor to RGB.
The tensor is expected to be in the [0, 1] floating point range, with a
shape of (3xHxW) or (Nx3xHxW).
"""
def __call__(self, ycbcr):
"""
Args:
            ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor
Returns:
rgb(torch.Tensor): converted tensor
"""
return F_transforms.ycbcr2rgb(ycbcr)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YUV444To420:
"""Convert a YUV 444 tensor to a 420 representation.
Args:
mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
``'avg_pool'``
Example:
>>> x = torch.rand(1, 3, 32, 32)
>>> y, u, v = YUV444To420()(x)
>>> y.size() # 1, 1, 32, 32
>>> u.size() # 1, 1, 16, 16
"""
def __init__(self, mode: str = "avg_pool"):
self.mode = str(mode)
def __call__(self, yuv):
"""
Args:
yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
a tuple of 3 (Nx1xHxW) tensors.
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
"""
return F_transforms.yuv_444_to_420(yuv, mode=self.mode)
def __repr__(self):
return f"{self.__class__.__name__}()"
class YUV420To444:
"""Convert a YUV 420 input to a 444 representation.
Args:
mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
Default ``'bilinear'``
return_tuple (bool): return input as tuple of tensors instead of a
concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
tensor (default: False)
Example:
>>> y = torch.rand(1, 1, 32, 32)
>>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
>>> x = YUV420To444()((y, u, v))
>>> x.size() # 1, 3, 32, 32
"""
def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
self.mode = str(mode)
self.return_tuple = bool(return_tuple)
def __call__(self, yuv):
"""
Args:
yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
(Nx1xHxW) format
Returns:
(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
444
"""
return F_transforms.yuv_420_to_444(yuv, return_tuple=self.return_tuple)
def __repr__(self):
return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collect performance metrics of published traditional or end-to-end image
codecs.
"""
import argparse
import json
import multiprocessing as mp
import os
import sys
from collections import defaultdict
from itertools import starmap
from typing import List
from .codecs import AV1, BPG, HM, JPEG, JPEG2000, TFCI, VTM, Codec, WebP
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]
# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
rv = codec.run(*args)
return i, rv
def collect(codec: Codec, dataset: str, qualities: List[int], num_jobs: int = 1):
filepaths = [
os.path.join(dataset, f)
for f in os.listdir(dataset)
if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
]
pool = mp.Pool(num_jobs) if num_jobs > 1 else None
if len(filepaths) == 0:
print("No images found in the dataset directory")
sys.exit(1)
args = [(codec, i, f, q) for i, q in enumerate(qualities) for f in filepaths]
if pool:
rv = pool.starmap(func, args)
else:
rv = list(starmap(func, args))
results = [defaultdict(float) for _ in range(len(qualities))]
for i, metrics in rv:
for k, v in metrics.items():
results[i][k] += v
for i, _ in enumerate(results):
for k, v in results[i].items():
results[i][k] = v / len(filepaths)
# list of dict -> dict of list
out = defaultdict(list)
for r in results:
for k, v in r.items():
out[k].append(v)
return out
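# Illustrative sketch (hypothetical paths/qualities): averaging JPEG metrics
# over a directory of images, one list entry per quality index.
#
#     codec = JPEG(argparse.Namespace())    # Pillow-based codecs need no args
#     out = collect(codec, "/path/to/images", qualities=[10, 50, 90], num_jobs=4)
#     # out is a dict of lists, e.g. {"bpp": [...], "psnr": [...], "ms-ssim": [...]}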
def setup_args():
description = "Collect codec metrics."
parser = argparse.ArgumentParser(description=description)
subparsers = parser.add_subparsers(dest="codec", help="Select codec")
subparsers.required = True
return parser, subparsers
def setup_common_args(parser):
parser.add_argument("dataset", type=str)
parser.add_argument(
"-j",
"--num-jobs",
type=int,
metavar="N",
default=1,
help="Number of parallel jobs (default: %(default)s)",
)
parser.add_argument(
"-q",
"--quality",
dest="qualities",
metavar="Q",
        default=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90],
nargs="*",
type=int,
help="quality parameter (default: %(default)s)",
)
    # alternative quality list, e.g. for codecs with QP-style scales:
    # [3, 5, 10, 15, 17, 20, 22, 25, 27, 30, 32, 35, 37, 40, 42, 45, 47, 50]
    # newly added option
    parser.add_argument(
        "--name",
        dest="name",
        default="ans",
        type=str,
        help="output name for the results JSON file",
    )
def main(argv):
import time
start = time.time()
parser, subparsers = setup_args()
for c in codecs:
cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
setup_common_args(cparser)
c.setup_args(cparser)
args = parser.parse_args(argv)
codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
codec = codec_cls(args)
results = collect(codec, args.dataset, args.qualities, args.num_jobs)
output = {
"name": codec.name,
"description": codec.description,
"results": results,
}
print(json.dumps(output, indent=2))
end = time.time()
    print("total time:", end - start)
    output_dir = "/home/felix/disk2/compressai_v2/codes/results/log"
    os.makedirs(output_dir, exist_ok=True)  # ensure the log directory exists
    output_json_path = os.path.join(output_dir, args.name + ".json")
    with open(output_json_path, "w") as f:
        json.dump(output, f, indent=2)
if __name__ == "__main__":
main(sys.argv[1:])
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import platform
import subprocess
import sys
import time
from tempfile import mkstemp
from typing import Tuple, Union
import numpy as np
import PIL
import PIL.Image as Image
import torch
from pytorch_msssim import ms_ssim
from compressai.transforms.functional import rgb2ycbcr, ycbcr2rgb
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
def filesize(filepath: str) -> int:
"""Return file size in bits of `filepath`."""
if not os.path.isfile(filepath):
raise ValueError(f'Invalid file "{filepath}".')
return os.stat(filepath).st_size
def read_image(filepath: str, mode: str = "RGB") -> Image.Image:
    """Return a PIL image converted to the specified `mode`."""
if not os.path.isfile(filepath):
raise ValueError(f'Invalid file "{filepath}".')
return Image.open(filepath).convert(mode)
def compute_metrics(
    a: Union[np.ndarray, Image.Image],
    b: Union[np.ndarray, Image.Image],
    max_val: float = 255.0,
) -> Tuple[float, float]:
    """Return PSNR and MS-SSIM between images `a` and `b`."""
if isinstance(a, Image.Image):
a = np.asarray(a)
if isinstance(b, Image.Image):
b = np.asarray(b)
a = torch.from_numpy(a.copy()).float().unsqueeze(0)
if a.size(3) == 3:
a = a.permute(0, 3, 1, 2)
b = torch.from_numpy(b.copy()).float().unsqueeze(0)
if b.size(3) == 3:
b = b.permute(0, 3, 1, 2)
mse = torch.mean((a - b) ** 2).item()
p = 20 * np.log10(max_val) - 10 * np.log10(mse)
m = ms_ssim(a, b, data_range=max_val).item()
return p, m
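# Illustrative sketch: `compute_metrics` accepts PIL images or HWC uint8
# arrays on the 0-255 scale.
#
#     a = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
#     b = np.clip(a.astype(np.int16) + 2, 0, 255).astype(np.uint8)  # mild distortion
#     psnr_val, msssim_val = compute_metrics(a, b)
#     # PSNR here is 20 * log10(255) - 10 * log10(mse)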
def run_command(cmd, ignore_returncodes=None):
cmd = [str(c) for c in cmd]
try:
rv = subprocess.check_output(cmd)
return rv.decode("ascii")
except subprocess.CalledProcessError as err:
if ignore_returncodes is not None and err.returncode in ignore_returncodes:
return err.output
print(err.output.decode("utf-8"))
sys.exit(1)
def _get_ffmpeg_version():
rv = run_command(["ffmpeg", "-version"])
return rv.split()[2]
def _get_bpg_version(encoder_path):
rv = run_command([encoder_path, "-h"], ignore_returncodes=[1])
return rv.split()[4]
class Codec:
"""Abstract base class"""
_description = None
def __init__(self, args):
self._set_args(args)
def _set_args(self, args):
return args
@classmethod
def setup_args(cls, parser):
pass
@property
def description(self):
return self._description
@property
def name(self):
raise NotImplementedError()
def _load_img(self, img):
return os.path.abspath(img)
def _run(self, img, quality, *args, **kwargs):
raise NotImplementedError()
def run(self, img, quality, *args, **kwargs):
img = self._load_img(img)
return self._run(img, quality, *args, **kwargs)
class PillowCodec(Codec):
"""Abastract codec based on Pillow bindings."""
fmt = None
@property
def name(self):
raise NotImplementedError()
def _load_img(self, img):
return read_image(img)
def _run(self, img, quality, return_rec=False, return_metrics=True):
start = time.time()
tmp = io.BytesIO()
img.save(tmp, format=self.fmt, quality=int(quality))
enc_time = time.time() - start
tmp.seek(0)
size = tmp.getbuffer().nbytes
start = time.time()
rec = Image.open(tmp)
rec.load()
dec_time = time.time() - start
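        # bits-per-pixel: compressed size in bytes * 8 over the pixel count
        # (PIL's `img.size` is (width, height))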
bpp_val = float(size) * 8 / (img.size[0] * img.size[1])
out = {
"bpp": bpp_val,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
if return_metrics:
psnr_val, msssim_val = compute_metrics(rec, img)
out["psnr"] = psnr_val
out["ms-ssim"] = msssim_val
if return_rec:
return out, rec
return out
class JPEG(PillowCodec):
"""Use libjpeg linked in Pillow"""
fmt = "jpeg"
_description = f"JPEG. Pillow version {PIL.__version__}"
@property
def name(self):
return "JPEG"
class WebP(PillowCodec):
"""Use libwebp linked in Pillow"""
fmt = "webp"
_description = f"WebP. Pillow version {PIL.__version__}"
@property
def name(self):
return "WebP"
class BinaryCodec(Codec):
"""Call a external binary."""
fmt = None
def _run(self, img, quality, return_rec=False, return_metrics=True):
fd0, png_filepath = mkstemp(suffix=".png")
fd1, out_filepath = mkstemp(suffix=self.fmt)
# Encode
start = time.time()
run_command(self._get_encode_cmd(img, quality, out_filepath))
enc_time = time.time() - start
size = filesize(out_filepath)
# Decode
start = time.time()
run_command(self._get_decode_cmd(out_filepath, png_filepath))
dec_time = time.time() - start
# Read image
img = read_image(img)
rec = read_image(png_filepath)
os.close(fd0)
os.remove(png_filepath)
os.close(fd1)
os.remove(out_filepath)
bpp_val = float(size) * 8 / (img.size[0] * img.size[1])
out = {
"bpp": bpp_val,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
if return_metrics:
psnr_val, msssim_val = compute_metrics(rec, img)
out["psnr"] = psnr_val
out["ms-ssim"] = msssim_val
if return_rec:
return out, rec
return out
def _get_encode_cmd(self, img, quality, out_filepath):
raise NotImplementedError()
def _get_decode_cmd(self, out_filepath, rec_filepath):
raise NotImplementedError()
class JPEG2000(BinaryCodec):
"""Use ffmpeg version.
(Not built-in support in default Pillow builds)
"""
fmt = ".jp2"
@property
def name(self):
return "JPEG2000"
@property
def description(self):
return f"JPEG2000. ffmpeg version {_get_ffmpeg_version()}"
def _get_encode_cmd(self, img, quality, out_filepath):
cmd = [
"ffmpeg",
"-loglevel",
"panic",
"-y",
"-i",
img,
"-vcodec",
"jpeg2000",
"-pix_fmt",
"yuv444p",
"-c:v",
"libopenjpeg",
"-compression_level",
quality,
out_filepath,
]
return cmd
def _get_decode_cmd(self, out_filepath, rec_filepath):
cmd = ["ffmpeg", "-loglevel", "panic", "-y", "-i", out_filepath, rec_filepath]
return cmd
class BPG(BinaryCodec):
"""BPG from Fabrice Bellard."""
fmt = ".bpg"
@property
def name(self):
return (
f"BPG {self.bitdepth}b {self.subsampling_mode} {self.encoder} "
f"{self.color_mode}"
)
@property
def description(self):
return f"BPG. BPG version {_get_bpg_version(self.encoder_path)}"
@classmethod
def setup_args(cls, parser):
super().setup_args(parser)
parser.add_argument(
"-m",
choices=["420", "444"],
default="444",
help="subsampling mode (default: %(default)s)",
)
parser.add_argument(
"-b",
choices=["8", "10"],
default="8",
help="bitdepth (default: %(default)s)",
)
parser.add_argument(
"-c",
choices=["rgb", "ycbcr"],
default="ycbcr",
help="colorspace (default: %(default)s)",
)
parser.add_argument(
"-e",
choices=["jctvc", "x265"],
default="x265",
help="HEVC implementation (default: %(default)s)",
)
parser.add_argument("--encoder-path", default="bpgenc", help="BPG encoder path")
parser.add_argument("--decoder-path", default="bpgdec", help="BPG decoder path")
def _set_args(self, args):
args = super()._set_args(args)
self.color_mode = args.c
self.encoder = args.e
self.subsampling_mode = args.m
self.bitdepth = args.b
self.encoder_path = "/home/felix/disk2/libbpg/bpgenc" #args.encoder_path
self.decoder_path = "/home/felix/disk2/libbpg/bpgdec"
return args
def _get_encode_cmd(self, img, quality, out_filepath):
if not 0 <= quality <= 51:
raise ValueError(f"Invalid quality value: {quality} (0,51)")
cmd = [
self.encoder_path,
"-o",
out_filepath,
"-q",
str(quality),
"-f",
self.subsampling_mode,
"-e",
self.encoder,
"-c",
self.color_mode,
"-b",
self.bitdepth,
img,
]
return cmd
def _get_decode_cmd(self, out_filepath, rec_filepath):
cmd = [self.decoder_path, "-o", rec_filepath, out_filepath]
return cmd
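# Example of the resulting command lines (paths are illustrative):
#   bpgenc -o /tmp/out.bpg -q 30 -f 444 -e x265 -c ycbcr -b 8 input.png
#   bpgdec -o /tmp/rec.png /tmp/out.bpg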
class TFCI(BinaryCodec):
"""Tensorflow image compression format from tensorflow/compression"""
fmt = ".tfci"
_models = [
"bmshj2018-factorized-mse",
"bmshj2018-hyperprior-mse",
"mbt2018-mean-mse",
]
@property
def description(self):
return "TFCI"
@property
def name(self):
return f"{self.model}"
@classmethod
def setup_args(cls, parser):
super().setup_args(parser)
parser.add_argument(
"-m",
"--model",
choices=cls._models,
default=cls._models[0],
help="model architecture (default: %(default)s)",
)
parser.add_argument(
"-p",
"--path",
required=True,
help="tfci python script path (default: %(default)s)",
)
def _set_args(self, args):
args = super()._set_args(args)
self.model = args.model
self.tfci_path = args.path
return args
def _get_encode_cmd(self, img, quality, out_filepath):
if not 1 <= quality <= 8:
raise ValueError(f"Invalid quality value: {quality} (1, 8)")
cmd = [
sys.executable,
self.tfci_path,
"compress",
f"{self.model}-{quality:d}",
img,
out_filepath,
]
return cmd
def _get_decode_cmd(self, out_filepath, rec_filepath):
cmd = [sys.executable, self.tfci_path, "decompress", out_filepath, rec_filepath]
return cmd
def get_vtm_encoder_path(build_dir):
system = platform.system()
try:
elfnames = {"Darwin": "EncoderApp", "Linux": "EncoderAppStatic"}
return os.path.join(build_dir, elfnames[system])
except KeyError as err:
raise RuntimeError(f'Unsupported platform "{system}"') from err
def get_vtm_decoder_path(build_dir):
system = platform.system()
try:
elfnames = {"Darwin": "DecoderApp", "Linux": "DecoderAppStatic"}
return os.path.join(build_dir, elfnames[system])
except KeyError as err:
raise RuntimeError(f'Unsupported platform "{system}"') from err
class VTM(Codec):
"""VTM: VVC reference software"""
fmt = ".bin"
@property
def description(self):
return "VTM"
@property
def name(self):
return "VTM"
@classmethod
def setup_args(cls, parser):
super().setup_args(parser)
parser.add_argument(
"-b",
"--build-dir",
type=str,
default = "/home/felix/disk2/VVCSoftware_VTM/bin",
help="VTM build dir",
)
parser.add_argument(
"-c",
"--config",
type=str,
default = "/home/felix/disk2/VVCSoftware_VTM/cfg/encoder_intra_vtm.cfg",
help="VTM config file",
)
parser.add_argument(
"--rgb", action="store_true", help="Use RGB color space (over YCbCr)"
)
def _set_args(self, args):
args = super()._set_args(args)
self.encoder_path = get_vtm_encoder_path(args.build_dir)
self.decoder_path = get_vtm_decoder_path(args.build_dir)
self.config_path = args.config
self.rgb = args.rgb
return args
def _run(self, img, quality, return_rec=False, return_metrics=True):
if not 0 <= quality <= 63:
raise ValueError(f"Invalid quality value: {quality} (0,63)")
        # Taking 8-bit input for now
bitdepth = 8
# Convert input image to yuv 444 file
arr = np.asarray(read_image(img))
fd, yuv_path = mkstemp(suffix=".yuv")
out_filepath = os.path.splitext(yuv_path)[0] + ".bin"
arr = arr.transpose((2, 0, 1)) # color channel first
if not self.rgb:
# convert rgb content to YCbCr
rgb = torch.from_numpy(arr.copy()).float() / (2 ** bitdepth - 1)
arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
arr = (arr * (2 ** bitdepth - 1)).astype(np.uint8)
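        # arr is (3, H, W); dumping it contiguously yields one planar YUV444
        # frame (full Y plane, then Cb, then Cr), which is what the reference
        # encoder expects with --InputChromaFormat=444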
with open(yuv_path, "wb") as f:
f.write(arr.tobytes())
# Encode
height, width = arr.shape[1:]
cmd = [
self.encoder_path,
"-i",
yuv_path,
"-c",
self.config_path,
"-q",
quality,
"-o",
"/dev/null",
"-b",
out_filepath,
"-wdt",
width,
"-hgt",
height,
"-fr",
"1",
"-f",
"1",
"--InputChromaFormat=444",
"--InputBitDepth=8",
"--ConformanceMode=1",
]
if self.rgb:
cmd += [
"--InputColourSpaceConvert=RGBtoGBR",
"--SNRInternalColourSpace=1",
"--OutputInternalColourSpace=0",
]
start = time.time()
run_command(cmd)
enc_time = time.time() - start
# cleanup encoder input
os.close(fd)
os.unlink(yuv_path)
# Decode
cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]
if self.rgb:
cmd.append("--OutputInternalColourSpace=GBRtoRGB")
start = time.time()
run_command(cmd)
dec_time = time.time() - start
# Compute PSNR
rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
rec_arr = rec_arr.reshape(arr.shape)
arr = arr.astype(np.float32) / (2 ** bitdepth - 1)
rec_arr = rec_arr.astype(np.float32) / (2 ** bitdepth - 1)
if not self.rgb:
arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
bpp = filesize(out_filepath) * 8.0 / (height * width)
# Cleanup
os.unlink(yuv_path)
os.unlink(out_filepath)
out = {
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
if return_metrics:
psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
out["psnr"] = psnr_val
out["ms-ssim"] = msssim_val
if return_rec:
rec = Image.fromarray(
(rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
)
return out, rec
return out
class HM(Codec):
"""HM: H.265/HEVC reference software"""
fmt = ".bin"
@property
def description(self):
return "HM"
@property
def name(self):
return "HM"
@classmethod
def setup_args(cls, parser):
super().setup_args(parser)
parser.add_argument(
"-b",
"--build-dir",
type=str,
required=True,
help="HM build dir",
)
parser.add_argument(
"-c", "--config", type=str, required=True, help="HM config file"
)
parser.add_argument(
"--rgb", action="store_true", help="Use RGB color space (over YCbCr)"
)
def _set_args(self, args):
args = super()._set_args(args)
self.encoder_path = os.path.join(args.build_dir, "TAppEncoderStatic")
self.decoder_path = os.path.join(args.build_dir, "TAppDecoderStatic")
self.config_path = args.config
self.rgb = args.rgb
return args
def _run(self, img, quality, return_rec=False, return_metrics=True):
if not 0 <= quality <= 51:
raise ValueError(f"Invalid quality value: {quality} (0,51)")
# Convert input image to yuv 444 file
arr = np.asarray(read_image(img))
fd, yuv_path = mkstemp(suffix=".yuv")
out_filepath = os.path.splitext(yuv_path)[0] + ".bin"
bitdepth = 8
arr = arr.transpose((2, 0, 1)) # color channel first
if not self.rgb:
# convert rgb content to YCbCr
rgb = torch.from_numpy(arr.copy()).float() / (2 ** bitdepth - 1)
arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
arr = (arr * (2 ** bitdepth - 1)).astype(np.uint8)
with open(yuv_path, "wb") as f:
f.write(arr.tobytes())
# Encode
height, width = arr.shape[1:]
cmd = [
self.encoder_path,
"-i",
yuv_path,
"-c",
self.config_path,
"-q",
quality,
"-o",
"/dev/null",
"-b",
out_filepath,
"-wdt",
width,
"-hgt",
height,
"-fr",
"1",
"-f",
"1",
"--InputChromaFormat=444",
"--InputBitDepth=8",
"--SEIDecodedPictureHash",
"--Level=5.1",
"--CUNoSplitIntraACT=0",
"--ConformanceMode=1",
]
if self.rgb:
cmd += [
"--InputColourSpaceConvert=RGBtoGBR",
"--SNRInternalColourSpace=1",
"--OutputInternalColourSpace=0",
]
start = time.time()
run_command(cmd)
enc_time = time.time() - start
# cleanup encoder input
os.close(fd)
os.unlink(yuv_path)
# Decode
cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]
if self.rgb:
cmd.append("--OutputInternalColourSpace=GBRtoRGB")
start = time.time()
run_command(cmd)
dec_time = time.time() - start
# Compute PSNR
rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
rec_arr = rec_arr.reshape(arr.shape)
arr = arr.astype(np.float32) / (2 ** bitdepth - 1)
rec_arr = rec_arr.astype(np.float32) / (2 ** bitdepth - 1)
if not self.rgb:
arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
bpp = filesize(out_filepath) * 8.0 / (height * width)
# Cleanup
os.unlink(yuv_path)
os.unlink(out_filepath)
out = {
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
if return_metrics:
psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
out["psnr"] = psnr_val
out["ms-ssim"] = msssim_val
if return_rec:
rec = Image.fromarray(
(rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
)
return out, rec
return out
class AV1(Codec):
"""AV1: AOM reference software"""
fmt = ".webm"
@property
def description(self):
return "AV1"
@property
def name(self):
return "AV1"
@classmethod
def setup_args(cls, parser):
super().setup_args(parser)
parser.add_argument(
"-b",
"--build-dir",
type=str,
required=True,
help="AOM binaries dir",
)
def _set_args(self, args):
args = super()._set_args(args)
self.encoder_path = os.path.join(args.build_dir, "aomenc")
self.decoder_path = os.path.join(args.build_dir, "aomdec")
return args
def _run(self, img, quality, return_rec=False, return_metrics=True):
if not 0 <= quality <= 63:
raise ValueError(f"Invalid quality value: {quality} (0,63)")
# Convert input image to yuv 444 file
arr = np.asarray(read_image(img))
fd, yuv_path = mkstemp(suffix=".yuv")
out_filepath = os.path.splitext(yuv_path)[0] + ".webm"
bitdepth = 8
arr = arr.transpose((2, 0, 1)) # color channel first
# convert rgb content to YCbCr
rgb = torch.from_numpy(arr.copy()).float() / (2 ** bitdepth - 1)
arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
arr = (arr * (2 ** bitdepth - 1)).astype(np.uint8)
with open(yuv_path, "wb") as f:
f.write(arr.tobytes())
# Encode
height, width = arr.shape[1:]
cmd = [
self.encoder_path,
"-w",
width,
"-h",
height,
"--fps=1/1",
"--limit=1",
"--input-bit-depth=8",
"--cpu-used=0",
"--threads=1",
"--passes=2",
"--end-usage=q",
"--cq-level=" + str(quality),
"--i444",
"--skip=0",
"--tune=psnr",
"--psnr",
"--bit-depth=8",
"-o",
out_filepath,
yuv_path,
]
start = time.time()
run_command(cmd)
enc_time = time.time() - start
# cleanup encoder input
os.close(fd)
os.unlink(yuv_path)
# Decode
cmd = [
self.decoder_path,
out_filepath,
"-o",
yuv_path,
"--rawvideo",
"--output-bit-depth=8",
]
start = time.time()
run_command(cmd)
dec_time = time.time() - start
# Compute PSNR
rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
rec_arr = rec_arr.reshape(arr.shape)
arr = arr.astype(np.float32) / (2 ** bitdepth - 1)
rec_arr = rec_arr.astype(np.float32) / (2 ** bitdepth - 1)
arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
bpp = filesize(out_filepath) * 8.0 / (height * width)
# Cleanup
os.unlink(yuv_path)
os.unlink(out_filepath)
out = {
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
if return_metrics:
psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
out["psnr"] = psnr_val
out["ms-ssim"] = msssim_val
if return_rec:
rec = Image.fromarray(
(rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
)
return out, rec
return out
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate an end-to-end compression model on an image dataset.
"""
import argparse
import json
import math
import os
import sys
import time
from collections import defaultdict
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_msssim import ms_ssim
from torchvision import transforms
import compressai
from compressai.zoo import models as pretrained_models
from compressai.zoo.image import model_architectures as architectures
# from torchvision.datasets.folder
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
def collect_images(rootpath: str) -> List[str]:
return [
os.path.join(rootpath, f)
for f in os.listdir(rootpath)
if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
]
def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
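    # assumes inputs normalized to [0, 1], so max_val = 1 and
    # PSNR = 20 * log10(1.0) - 10 * log10(mse) = -10 * log10(mse)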
mse = F.mse_loss(a, b).item()
return -10 * math.log10(mse)
def read_image(filepath: str) -> torch.Tensor:
assert os.path.isfile(filepath)
img = Image.open(filepath).convert("RGB")
return transforms.ToTensor()(img)
@torch.no_grad()
def inference(model, x, savedir="", idx=1):
x = x.unsqueeze(0)
h, w = x.size(2), x.size(3)
p = 64 # maximum 6 strides of 2
new_h = (h + p - 1) // p * p
new_w = (w + p - 1) // p * p
padding_left = (new_w - w) // 2
padding_right = new_w - w - padding_left
padding_top = (new_h - h) // 2
padding_bottom = new_h - h - padding_top
x_padded = F.pad(
x,
(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
start = time.time()
out_enc = model.compress(x_padded)
enc_time = time.time() - start
start = time.time()
out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
dec_time = time.time() - start
out_dec["x_hat"] = F.pad(
out_dec["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom)
)
num_pixels = x.size(0) * x.size(2) * x.size(3)
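    # each entry of out_enc["strings"] holds one byte string per image in the
    # batch; total bytes * 8 over the pixel count gives the bitrate in bpp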
bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
if savedir != "":
if not os.path.exists(savedir):
os.makedirs(savedir)
cur_psnr = psnr(x, out_dec["x_hat"])
cur_ssim = ms_ssim(x, out_dec["x_hat"], data_range=1.0).item()
tran1 = transforms.ToPILImage()
cur_img = tran1(out_dec["x_hat"][0])
cur_img.save(os.path.join(savedir,'{:02d}'.format(idx)+"_"+'{:.2f}'.format(cur_psnr)+"_"+'{:.3f}'.format(bpp)+"_"+'{:.3f}'.format(cur_ssim)+".png"))
return {
"psnr": psnr(x, out_dec["x_hat"]),
"ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(),
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
@torch.no_grad()
def inference_entropy_estimation(model, x):
x = x.unsqueeze(0)
start = time.time()
out_net = model.forward(x)
elapsed_time = time.time() - start
num_pixels = x.size(0) * x.size(2) * x.size(3)
bpp = sum(
(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
for likelihoods in out_net["likelihoods"].values()
)
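    # this is the model's cross-entropy in bits per pixel: sum(-log2 p) / N,
    # i.e. the rate an ideal entropy coder would approach without coding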
    return {
        "psnr": psnr(x, out_net["x_hat"]),
        "ms-ssim": ms_ssim(x, out_net["x_hat"], data_range=1.0).item(),
        "bpp": bpp.item(),
        "encoding_time": elapsed_time / 2.0,  # broad estimate: split forward time
        "decoding_time": elapsed_time / 2.0,
    }
def load_pretrained(model: str, metric: str, quality: int) -> nn.Module:
return pretrained_models[model](
quality=quality, metric=metric, pretrained=True
).eval()
def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
return architectures[arch].from_state_dict(torch.load(checkpoint_path)).eval()
def eval_model(model, filepaths, entropy_estimation=False, half=False, savedir=""):
device = next(model.parameters()).device
metrics = defaultdict(float)
for idx, f in enumerate(sorted(filepaths)):
x = read_image(f).to(device)
if not entropy_estimation:
print('evaluating index', idx)
if half:
model = model.half()
x = x.half()
rv = inference(model, x, savedir, idx)
else:
rv = inference_entropy_estimation(model, x)
print('bpp', rv['bpp'])
print('psnr', rv['psnr'])
print('ms-ssim', rv['ms-ssim'])
print()
for k, v in rv.items():
metrics[k] += v
for k, v in metrics.items():
metrics[k] = v / len(filepaths)
return metrics
def setup_args():
parent_parser = argparse.ArgumentParser(
add_help=False,
)
# Common options.
parent_parser.add_argument("dataset", type=str, help="dataset path")
parent_parser.add_argument(
"-a",
"--arch",
type=str,
choices=pretrained_models.keys(),
help="model architecture",
required=True,
)
parent_parser.add_argument(
"-c",
"--entropy-coder",
choices=compressai.available_entropy_coders(),
default=compressai.available_entropy_coders()[0],
help="entropy coder (default: %(default)s)",
)
parent_parser.add_argument(
"--cuda",
action="store_true",
help="enable CUDA",
)
parent_parser.add_argument(
"--half",
action="store_true",
help="convert model to half floating point (fp16)",
)
parent_parser.add_argument(
"--entropy-estimation",
action="store_true",
help="use evaluated entropy estimation (no entropy coding)",
)
parent_parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="verbose mode",
)
parent_parser.add_argument(
"-s",
"--savedir",
type=str,
default="",
)
parent_parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
parser = argparse.ArgumentParser(
description="Evaluate a model on an image dataset.", add_help=True
)
    subparsers = parser.add_subparsers(help="model source", dest="source")
# Options for pretrained models
pretrained_parser = subparsers.add_parser("pretrained", parents=[parent_parser])
pretrained_parser.add_argument(
"-m",
"--metric",
type=str,
choices=["mse", "ms-ssim"],
default="mse",
help="metric trained against (default: %(default)s)",
)
pretrained_parser.add_argument(
"-q",
"--quality",
dest="qualities",
nargs="+",
type=int,
default=(1,),
)
checkpoint_parser = subparsers.add_parser("checkpoint", parents=[parent_parser])
checkpoint_parser.add_argument("-exp", "--experiment", type=str, required=True, help="Experiment name")
return parser
def main(argv):
args = setup_args().parse_args(argv)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
filepaths = collect_images(args.dataset)
if len(filepaths) == 0:
print("No images found in directory.")
sys.exit(1)
compressai.set_entropy_coder(args.entropy_coder)
if args.source == "pretrained":
runs = sorted(args.qualities)
opts = (args.arch, args.metric)
load_func = load_pretrained
log_fmt = "\rEvaluating {0} | {run:d}"
elif args.source == "checkpoint":
        # pick the (single) updated checkpoint exported for this experiment
        checkpoint_updated_dir = os.path.join(
            "../experiments", args.experiment, "checkpoint_updated"
        )
        checkpoint_updated = os.path.join(
            checkpoint_updated_dir, os.listdir(checkpoint_updated_dir)[0]
        )
runs = [checkpoint_updated]
opts = (args.arch,)
load_func = load_checkpoint
log_fmt = "\rEvaluating {run:s}"
results = defaultdict(list)
for run in runs:
if args.verbose:
sys.stderr.write(log_fmt.format(*opts, run=run))
sys.stderr.flush()
model = load_func(*opts, run)
if args.cuda and torch.cuda.is_available():
model = model.to("cuda")
        metrics = eval_model(
            model, filepaths, args.entropy_estimation, args.half, args.savedir
        )
for k, v in metrics.items():
results[k].append(v)
if args.verbose:
sys.stderr.write("\n")
sys.stderr.flush()
description = (
"entropy estimation" if args.entropy_estimation else args.entropy_coder
)
output = {
"name": args.arch,
"description": f"Inference ({description})",
"results": results,
}
print(json.dumps(output, indent=2))
if __name__ == "__main__":
main(sys.argv[1:])