Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
compressai
Commits
9e459ea3
Commit
9e459ea3
authored
Nov 05, 2024
by
jerrrrry
Browse files
Initial commit
parents
Changes
120
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3518 additions
and
0 deletions
+3518
-0
codes/compressai/layers/layers.py
codes/compressai/layers/layers.py
+213
-0
codes/compressai/models/__init__.py
codes/compressai/models/__init__.py
+17
-0
codes/compressai/models/our_utils.py
codes/compressai/models/our_utils.py
+385
-0
codes/compressai/models/ours.py
codes/compressai/models/ours.py
+140
-0
codes/compressai/models/priors.py
codes/compressai/models/priors.py
+653
-0
codes/compressai/models/utils.py
codes/compressai/models/utils.py
+130
-0
codes/compressai/models/waseda.py
codes/compressai/models/waseda.py
+138
-0
codes/compressai/ops/__init__.py
codes/compressai/ops/__init__.py
+19
-0
codes/compressai/ops/bound_ops.py
codes/compressai/ops/bound_ops.py
+53
-0
codes/compressai/ops/ops.py
codes/compressai/ops/ops.py
+32
-0
codes/compressai/ops/parametrizers.py
codes/compressai/ops/parametrizers.py
+45
-0
codes/compressai/transforms/__init__.py
codes/compressai/transforms/__init__.py
+15
-0
codes/compressai/transforms/functional.py
codes/compressai/transforms/functional.py
+135
-0
codes/compressai/transforms/transforms.py
codes/compressai/transforms/transforms.py
+118
-0
codes/compressai/utils/__init__.py
codes/compressai/utils/__init__.py
+13
-0
codes/compressai/utils/bench/__init__.py
codes/compressai/utils/bench/__init__.py
+13
-0
codes/compressai/utils/bench/__main__.py
codes/compressai/utils/bench/__main__.py
+166
-0
codes/compressai/utils/bench/codecs.py
codes/compressai/utils/bench/codecs.py
+884
-0
codes/compressai/utils/eval_model/__init__.py
codes/compressai/utils/eval_model/__init__.py
+13
-0
codes/compressai/utils/eval_model/__main__.py
codes/compressai/utils/eval_model/__main__.py
+336
-0
No files found.
codes/compressai/layers/layers.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
from
.gdn
import
GDN
class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution implementation, mask future "unseen" pixels.
    Useful for building auto-regressive network components.

    Introduced in `"Conditional Image Generation with PixelCNN Decoders"
    <https://arxiv.org/abs/1606.05328>`_.

    Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
    first layer (which also masks the "current pixel"), `mask_type='B'` for the
    following layers.
    """

    def __init__(self, *args, mask_type="A", **kwargs):
        super().__init__(*args, **kwargs)
        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
        mask = torch.ones_like(self.weight.data)
        _, _, kh, kw = mask.size()
        # Zero out everything at and after (type A) / strictly after (type B)
        # the kernel center on the center row, plus every row below it.
        center_col = kw // 2 + (1 if mask_type == "B" else 0)
        mask[:, :, kh // 2, center_col:] = 0
        mask[:, :, kh // 2 + 1:] = 0
        self.register_buffer("mask", mask)

    def forward(self, x):
        # TODO(begaintj): weight assigment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)
def conv3x3(in_ch, out_ch, stride=1):
    """Return a 3x3 convolution with "same" padding.

    Args:
        in_ch (int): input channel count.
        out_ch (int): output channel count.
        stride (int): convolution stride (default: 1).
    """
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
def subpel_conv3x3(in_ch, out_ch, r=1):
    """Return a 3x3 sub-pixel convolution for up-sampling by factor `r`.

    A padded 3x3 conv expands channels by r**2, then PixelShuffle rearranges
    them into an r-times larger spatial grid.
    """
    expanded = out_ch * r ** 2
    return nn.Sequential(
        nn.Conv2d(in_ch, expanded, kernel_size=3, padding=1),
        nn.PixelShuffle(r),
    )
def conv1x1(in_ch, out_ch, stride=1):
    """Return a 1x1 (pointwise) convolution.

    Args:
        in_ch (int): input channel count.
        out_ch (int): output channel count.
        stride (int): convolution stride (default: 1).
    """
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=1,
        stride=stride,
    )
class ResidualBlockWithStride(nn.Module):
    """Residual block with a stride on the first convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        # A 1x1 projection shortcut is required whenever the main branch
        # changes the spatial size or the channel count.
        needs_projection = stride != 1 or in_ch != out_ch
        self.skip = conv1x1(in_ch, out_ch, stride=stride) if needs_projection else None

    def forward(self, x):
        shortcut = self.skip(x) if self.skip is not None else x
        out = self.conv1(x)
        out = self.leaky_relu(out)
        out = self.gdn(self.conv2(out))
        return out + shortcut
class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on the last convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch, out_ch, upsample=2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        # The shortcut must be upsampled as well, so it gets its own
        # sub-pixel convolution.
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x):
        shortcut = self.upsample(x)
        out = self.subpel_conv(x)
        out = self.leaky_relu(out)
        out = self.igdn(self.conv(out))
        return out + shortcut
class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # 1x1 projection only when the channel count changes.
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None

    def forward(self, x):
        shortcut = x if self.skip is None else self.skip(x)
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.conv2(out))
        return out + shortcut
class AttentionBlock(nn.Module):
    """Self attention block.

    Simplified variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        N (int): Number of channels)
    """

    def __init__(self, N):
        super().__init__()

        class ResidualUnit(nn.Module):
            """1x1 -> 3x3 -> 1x1 bottleneck with a residual connection."""

            def __init__(self):
                super().__init__()
                half = N // 2
                self.conv = nn.Sequential(
                    conv1x1(N, half),
                    nn.ReLU(inplace=True),
                    conv3x3(half, half),
                    nn.ReLU(inplace=True),
                    conv1x1(half, N),
                )
                self.relu = nn.ReLU(inplace=True)

            def forward(self, x):
                out = x + self.conv(x)
                return self.relu(out)

        # Trunk branch (a) and mask branch (b); the mask branch gates the
        # trunk through a sigmoid.
        self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
        self.conv_b = nn.Sequential(
            ResidualUnit(),
            ResidualUnit(),
            ResidualUnit(),
            conv1x1(N, N),
        )

    def forward(self, x):
        trunk = self.conv_a(x)
        gate = self.conv_b(x)
        return x + trunk * torch.sigmoid(gate)
codes/compressai/models/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.priors
import
*
from
.waseda
import
*
from
.ours
import
*
codes/compressai/models/our_utils.py
0 → 100644
View file @
9e459ea3
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.init
as
init
import
torch.nn.functional
as
F
import
warnings
from
compressai.layers
import
*
class AttModule(nn.Module):
    """Direction-aware attention: one AttentionBlock per flow direction."""

    def __init__(self, N):
        super(AttModule, self).__init__()
        self.forw_att = AttentionBlock(N)
        self.back_att = AttentionBlock(N)

    def forward(self, x, rev=False):
        # Pick the block matching the requested direction.
        attention = self.back_att if rev else self.forw_att
        return attention(x)
class EnhModule(nn.Module):
    """Direction-aware enhancement: one EnhBlock per flow direction."""

    def __init__(self, nf):
        super(EnhModule, self).__init__()
        self.forw_enh = EnhBlock(nf)
        self.back_enh = EnhBlock(nf)

    def forward(self, x, rev=False):
        # Pick the block matching the requested direction.
        enhance = self.back_enh if rev else self.forw_enh
        return enhance(x)
class EnhBlock(nn.Module):
    """Residual enhancement block: dense in/out stages around 1-3-1 convs.

    Operates on 3-channel images; `nf` is the internal feature width.
    """

    def __init__(self, nf):
        super(EnhBlock, self).__init__()
        self.layers = nn.Sequential(
            DenseBlock(3, nf),
            nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
            DenseBlock(nf, 3),
        )

    def forward(self, x):
        # Residual scaling by 0.2 keeps the enhancement close to identity.
        residual = self.layers(x)
        return x + residual * 0.2
class InvComp(nn.Module):
    """Invertible transform: four levels of squeeze -> invertible 1x1 conv
    -> three coupling layers.

    Forward maps an image to `out_nc` latent channels by averaging channel
    groups; reverse tiles the latent back to `in_nc` channels and runs every
    operation backwards.

    Args:
        M: number of output (latent) channels.
    """

    def __init__(self, M):
        super(InvComp, self).__init__()
        self.in_nc = 3
        self.out_nc = M
        self.operations = nn.ModuleList()
        # The original code spelled the four levels out by copy-paste; a loop
        # builds the exact same sequence. The first two levels use 5x5
        # coupling kernels, the deeper two use 3x3. Registration order must
        # stay stable so state_dict keys match older checkpoints.
        for kernel_size in (5, 5, 3, 3):
            self.operations.append(SqueezeLayer(2))
            self.in_nc *= 4
            self.operations.append(InvertibleConv1x1(self.in_nc))
            for _ in range(3):
                self.operations.append(
                    CouplingLayer(self.in_nc // 4, 3 * self.in_nc // 4, kernel_size)
                )

    def forward(self, x, rev=False):
        """Run the flow; `rev=False` encodes, `rev=True` decodes."""
        if not rev:
            for op in self.operations:
                x = op.forward(x, False)
            b, c, h, w = x.size()
            # Collapse c = k * out_nc channels down to out_nc via group mean.
            x = torch.mean(x.view(b, c // self.out_nc, self.out_nc, h, w), dim=1)
        else:
            # Replicate the latent to the full channel width, then invert
            # each operation in reverse order.
            times = self.in_nc // self.out_nc
            x = x.repeat(1, times, 1, 1)
            for op in reversed(self.operations):
                x = op.forward(x, True)
        return x
class CouplingLayer(nn.Module):
    """Affine coupling layer (RealNVP-style) over a channel split.

    The input is split along channels into two parts of sizes ``split_len1``
    and ``split_len2``; each part is scaled and shifted by small networks
    conditioned on the other part, which keeps the mapping invertible.

    Args:
        split_len1: channel count of the first split.
        split_len2: channel count of the second split.
        kernal_size: kernel size forwarded to the Bottleneck sub-networks
            (misspelled name kept as-is for caller compatibility).
        clamp: bound on the log-scale; scale factors lie within
            [exp(-clamp), exp(clamp)].
    """

    def __init__(self, split_len1, split_len2, kernal_size, clamp=1.0):
        super(CouplingLayer, self).__init__()
        self.split_len1 = split_len1
        self.split_len2 = split_len2
        self.clamp = clamp
        # G* produce log-scales, H* produce shifts; the digit marks which
        # split the network output applies to.
        self.G1 = Bottleneck(self.split_len1, self.split_len2, kernal_size)
        self.G2 = Bottleneck(self.split_len2, self.split_len1, kernal_size)
        self.H1 = Bottleneck(self.split_len1, self.split_len2, kernal_size)
        self.H2 = Bottleneck(self.split_len2, self.split_len1, kernal_size)

    def forward(self, x, rev=False):
        # Split channels: x1 = first split_len1 channels, x2 = the remainder.
        x1, x2 = (x.narrow(1, 0, self.split_len1),
                  x.narrow(1, self.split_len1, self.split_len2))
        if not rev:
            # Forward: y1 = x1 * exp(s(x2)) + H2(x2); then y2 is transformed
            # conditioned on the already-updated y1.
            # sigmoid(.)*2-1 maps the raw net output into (-1, 1) before the
            # clamp factor, bounding the scale.
            y1 = x1.mul(torch.exp(self.clamp * (torch.sigmoid(self.G2(x2)) * 2 - 1))) + self.H2(x2)
            y2 = x2.mul(torch.exp(self.clamp * (torch.sigmoid(self.G1(y1)) * 2 - 1))) + self.H1(y1)
        else:
            # Inverse: undo the two affine steps in the opposite order.
            y2 = (x2 - self.H1(x1)).div(torch.exp(self.clamp * (torch.sigmoid(self.G1(x1)) * 2 - 1)))
            y1 = (x1 - self.H2(y2)).div(torch.exp(self.clamp * (torch.sigmoid(self.G2(y2)) * 2 - 1)))
        return torch.cat((y1, y2), 1)
class Bottleneck(nn.Module):
    """conv(k) -> conv(1) -> conv(k) bottleneck with leaky-ReLU activations."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super(Bottleneck, self).__init__()
        # "same" padding: P = ((S-1)*W-S+F)/2, with F = filter size, S = stride
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=1)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=kernel_size, padding=padding)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Residual-style init: damped first layers, zeroed last layer.
        initialize_weights_xavier([self.conv1, self.conv2], 0.1)
        initialize_weights(self.conv3, 0)

    def forward(self, x):
        hidden = self.lrelu(self.conv1(x))
        hidden = self.lrelu(self.conv2(hidden))
        return self.conv3(hidden)
class SqueezeLayer(nn.Module):
    """Invertible space-to-depth: trades spatial resolution for channels."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, reverse=False):
        if reverse:
            return self.unsqueeze2d(input, self.factor)
        # Squeeze in forward
        return self.squeeze2d(input, self.factor)

    def jacobian(self, x, rev=False):
        # Pure permutation of elements: log-determinant is zero.
        return 0

    @staticmethod
    def squeeze2d(input, factor=2):
        assert factor >= 1 and isinstance(factor, int)
        if factor == 1:
            return input
        b, c, h, w = input.size()
        assert h % factor == 0 and w % factor == 0, "{}".format((h, w, factor))
        out = input.view(b, c, h // factor, factor, w // factor, factor)
        out = out.permute(0, 3, 5, 1, 2, 4).contiguous()
        return out.view(b, factor * factor * c, h // factor, w // factor)

    @staticmethod
    def unsqueeze2d(input, factor=2):
        assert factor >= 1 and isinstance(factor, int)
        factor2 = factor ** 2
        if factor == 1:
            return input
        b, c, h, w = input.size()
        assert c % factor2 == 0, "{}".format(c)
        out = input.view(b, factor, factor, c // factor2, h, w)
        out = out.permute(0, 3, 4, 1, 5, 2).contiguous()
        return out.view(b, c // factor2, h * factor, w * factor)
class InvertibleConv1x1(nn.Module):
    """1x1 convolution with a square learned weight, invertible by design."""

    def __init__(self, num_channels):
        super().__init__()
        shape = [num_channels, num_channels]
        # Initialize from a random rotation (orthogonal => well conditioned).
        rotation = np.linalg.qr(np.random.randn(*shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(rotation)))
        self.w_shape = shape

    def get_weight(self, input, reverse):
        rows, cols = self.w_shape[0], self.w_shape[1]
        if reverse:
            # Invert in double precision for numerical stability, then cast
            # back to float for the conv.
            matrix = torch.inverse(self.weight.double()).float()
        else:
            matrix = self.weight
        return matrix.view(rows, cols, 1, 1)

    def forward(self, input, reverse=False):
        weight = self.get_weight(input, reverse)
        return F.conv2d(input, weight)
class DenseBlock(nn.Module):
    """Five-layer dense block: each conv sees the input plus all previous
    feature maps (growth channels `gc`), last conv maps to `channel_out`.
    """

    def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
        super(DenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Damped init on the growth convs, zero init on the final projection.
        growth_convs = [self.conv1, self.conv2, self.conv3, self.conv4]
        if init == 'xavier':
            initialize_weights_xavier(growth_convs, 0.1)
        else:
            initialize_weights(growth_convs, 0.1)
        initialize_weights(self.conv5, 0)

    def forward(self, x):
        # Each stage consumes the concatenation of everything before it.
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        return self.conv5(torch.cat(features, 1))
def initialize_weights(net_l, scale=1):
    """Kaiming-normal init, in place, for every Conv2d/Linear in `net_l`.

    Args:
        net_l: a module or a list of modules.
        scale: factor applied to the freshly drawn weights (use <1 to damp
            residual branches; 0 zeroes the layer).
    """
    nets = net_l if isinstance(net_l, list) else [net_l]
    for net in nets:
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                module.weight.data *= scale  # for residual block
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
def initialize_weights_xavier(net_l, scale=1):
    """Xavier-normal init, in place, for every Conv2d/Linear in `net_l`.

    Args:
        net_l: a module or a list of modules.
        scale: factor applied to the freshly drawn weights (use <1 to damp
            residual branches; 0 zeroes the layer).
    """
    nets = net_l if isinstance(net_l, list) else [net_l]
    for net in nets:
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_normal_(module.weight)
                module.weight.data *= scale  # for residual block
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
    # NOTE(review): this method (and `forward` / `_wrapper_spectral_norm`
    # below) appears to belong to a class `conv_iresnet_block_simplified`
    # (see the super() call), but no enclosing `class` statement is visible
    # here, and the names IRes_Squeeze / MaxMinGroup / ActNorm2D /
    # power_series_matrix_logarithm_trace / spectral_norm_fc /
    # spectral_norm_conv are not defined or imported in this module —
    # confirm the class header and its imports were not lost.
    def __init__(self, in_shape, int_ch, numTraceSamples=0, numSeriesTerms=0,
                 stride=1, coeff=.97, input_nonlin=True,
                 actnorm=True, n_power_iter=5, nonlin="elu", train=False):
        """
        build invertible bottleneck block
        :param in_shape: shape of the input (channels, height, width)
        :param int_ch: dimension of intermediate layers
        :param stride: 1 if no downsample 2 if downsample
        :param coeff: desired lipschitz constant
        :param input_nonlin: if true applies a nonlinearity on the input
        :param actnorm: if true uses actnorm like GLOW
        :param n_power_iter: number of iterations for spectral normalization
        :param nonlin: the nonlinearity to use
        """
        super(conv_iresnet_block_simplified, self).__init__()
        assert stride in (1, 2)
        self.stride = stride
        self.squeeze = IRes_Squeeze(stride)
        self.coeff = coeff
        self.numTraceSamples = numTraceSamples
        self.numSeriesTerms = numSeriesTerms
        self.n_power_iter = n_power_iter
        # Map the nonlinearity name to its constructor; rebinds the `nonlin`
        # argument from a string to a callable.
        nonlin = {
            "relu": nn.ReLU,
            "elu": nn.ELU,
            "softplus": nn.Softplus,
            "sorting": lambda: MaxMinGroup(group_size=2, axis=1)
        }[nonlin]
        # set shapes for spectral norm conv
        in_ch, h, w = in_shape
        layers = []
        if input_nonlin:
            layers.append(nonlin())
        # Squeezing by `stride` multiplies the channel count by stride**2.
        in_ch = in_ch * stride ** 2
        kernel_size1 = 1
        layers.append(self._wrapper_spectral_norm(nn.Conv2d(in_ch, int_ch, kernel_size=kernel_size1, padding=0),
                                                  (in_ch, h, w), kernel_size1))
        layers.append(nonlin())
        kernel_size3 = 1
        layers.append(self._wrapper_spectral_norm(nn.Conv2d(int_ch, in_ch, kernel_size=kernel_size3, padding=0),
                                                  (int_ch, h, w), kernel_size3))
        self.bottleneck_block = nn.Sequential(*layers)
        if actnorm:
            self.actnorm = ActNorm2D(in_ch, train=train)
        else:
            self.actnorm = None
    def forward(self, x, rev=False, ignore_logdet=False, maxIter=25):
        # Forward: y = x + F(x) (residual, invertible when F is contractive);
        # reverse: fixed-point iteration x_{k+1} = y - F(x_k) for maxIter steps.
        if not rev:
            """ bijective or injective block forward """
            if self.stride == 2:
                x = self.squeeze.forward(x)
            if self.actnorm is not None:
                x, an_logdet = self.actnorm(x)
            else:
                an_logdet = 0.0
            Fx = self.bottleneck_block(x)
            # Skip the (expensive) trace estimate when no samples/terms are
            # requested or the caller explicitly ignores the log-determinant.
            if (self.numTraceSamples == 0 and self.numSeriesTerms == 0) or ignore_logdet:
                trace = torch.tensor(0.)
            else:
                trace = power_series_matrix_logarithm_trace(Fx, x, self.numSeriesTerms, self.numTraceSamples)
            y = Fx + x
            return y, trace + an_logdet
        else:
            # Inverse via fixed-point iteration; assumes the residual branch
            # is contractive (Lipschitz constant < 1) so iteration converges.
            y = x
            for iter_index in range(maxIter):
                summand = self.bottleneck_block(x)
                x = y - summand
            if self.actnorm is not None:
                x = self.actnorm.inverse(x)
            if self.stride == 2:
                x = self.squeeze.inverse(x)
            return x
    def _wrapper_spectral_norm(self, layer, shapes, kernel_size):
        # Wrap a conv layer with the spectral-norm variant suited to its
        # kernel size, constraining its Lipschitz constant to self.coeff.
        if kernel_size == 1:
            # use spectral norm fc, because bound are tight for 1x1 convolutions
            return spectral_norm_fc(layer, self.coeff,
                                    n_power_iterations=self.n_power_iter)
        else:
            # use spectral norm based on conv, because bound not tight
            return spectral_norm_conv(layer, self.coeff, shapes,
                                      n_power_iterations=self.n_power_iter)
\ No newline at end of file
codes/compressai/models/ours.py
0 → 100644
View file @
9e459ea3
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.init
as
init
import
torch.nn.functional
as
F
import
warnings
from
.priors
import
JointAutoregressiveHierarchicalPriors
from
.our_utils
import
*
from
compressai.layers
import
*
from
.waseda
import
Cheng2020Anchor
class InvCompress(Cheng2020Anchor):
    """Invertible-transform image compression model.

    Replaces the parent's convolutional analysis/synthesis transforms
    (`g_a` / `g_s`) with an enhancement module, an invertible flow, and an
    attention module; the hyperprior and autoregressive context machinery
    are inherited from `Cheng2020Anchor`.
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N)
        # Parent-defined transforms are discarded; g_a_func/g_s_func below
        # are used instead.
        self.g_a = None
        self.g_s = None
        self.enh = EnhModule(64)
        self.inv = InvComp(M=N)
        self.attention = AttModule(N)

    def g_a_func(self, x):
        # Analysis transform: enhance -> invertible flow -> attention.
        x = self.enh(x)
        x = self.inv(x)
        x = self.attention(x)
        return x

    def g_s_func(self, x):
        # Synthesis transform: run the same modules in reverse order with
        # rev=True, mirroring g_a_func.
        x = self.attention(x, rev=True)
        x = self.inv(x, rev=True)
        x = self.enh(x, rev=True)
        return x

    def forward(self, x):
        """Training/eval forward pass; returns reconstruction + likelihoods."""
        y = self.g_a_func(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)
        # NOTE(review): quantization here does not pass `means=means_hat`
        # (the mean offset is only used for the likelihoods below) —
        # confirm this matches the intended rate estimate.
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s_func(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Infer N from the first hyper-analysis conv's output channels.
        N = state_dict["h_a.0.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Encode `x` to entropy-coded bitstrings (autoregressive, CPU-bound)."""
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )
        y = self.g_a_func(x)
        z = self.h_a(y)
        z_strings = self.entropy_bottleneck.compress(z)
        # Decompress z immediately so the encoder and decoder condition on
        # identical z_hat values.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
        params = self.h_s(z_hat)
        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2
        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s
        y_hat = F.pad(y, (padding, padding, padding, padding))
        y_strings = []
        # One autoregressive pass per batch element.
        for i in range(y.size(0)):
            string = self._compress_ar(
                y_hat[i: i + 1],
                params[i: i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
            y_strings.append(string)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:], "y": y}

    def decompress(self, strings, shape):
        """Decode bitstrings (as produced by `compress`) to a reconstruction."""
        assert isinstance(strings, list) and len(strings) == 2
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)
        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2
        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s
        # Padded buffer filled in-place by the autoregressive decoder.
        y_hat = torch.zeros(
            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device,
        )
        for i, y_string in enumerate(strings[0]):
            self._decompress_ar(
                y_string,
                y_hat[i: i + 1],
                params[i: i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
        # Negative padding crops the decode margin back off.
        y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
        x_hat = self.g_s_func(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
codes/compressai/models/priors.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
warnings
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
# pylint: disable=E0611,E0401
from
compressai.ans
import
BufferedRansEncoder
,
RansDecoder
from
compressai.entropy_models
import
EntropyBottleneck
,
GaussianConditional
from
compressai.layers
import
GDN
,
MaskedConv2d
from
.utils
import
conv
,
deconv
,
update_registered_buffers
# pylint: enable=E0611,E0401
__all__
=
[
"CompressionModel"
,
"FactorizedPrior"
,
"ScaleHyperprior"
,
"MeanScaleHyperprior"
,
"JointAutoregressiveHierarchicalPriors"
,
]
class CompressionModel(nn.Module):
    """Base class for constructing an auto-encoder with at least one entropy
    bottleneck module.

    Args:
        entropy_bottleneck_channels (int): Number of channels of the entropy
            bottleneck
    """

    def __init__(self, entropy_bottleneck_channels, init_weights=True):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)
        if init_weights:
            self._initialize_weights()

    def aux_loss(self):
        """Return the aggregated loss over the auxiliary entropy bottleneck
        module(s).
        """
        return sum(
            module.loss()
            for module in self.modules()
            if isinstance(module, EntropyBottleneck)
        )

    def _initialize_weights(self):
        # Kaiming init for every (transposed) convolution, zero biases.
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                continue
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, *args):
        raise NotImplementedError()

    def update(self, force=False):
        """Updates the entropy bottleneck(s) CDF values.

        Needs to be called once after training to be able to later perform the
        evaluation with an actual entropy coder.

        Args:
            force (bool): overwrite previous values (default: False)

        Returns:
            updated (bool): True if one of the EntropyBottlenecks was updated.
        """
        updated = False
        for module in self.children():
            if isinstance(module, EntropyBottleneck):
                updated |= module.update(force=force)
        return updated

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        super().load_state_dict(state_dict)
class FactorizedPrior(CompressionModel):
    r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(entropy_bottleneck_channels=M, **kwargs)
        # Analysis transform: strided convs interleaved with GDN.
        self.g_a = nn.Sequential(
            conv(3, N), GDN(N),
            conv(N, N), GDN(N),
            conv(N, N), GDN(N),
            conv(N, M),
        )
        # Synthesis transform: mirror of g_a with inverse GDN.
        self.g_s = nn.Sequential(
            deconv(M, N), GDN(N, inverse=True),
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, 3),
        )
        self.N = N
        self.M = M

    @property
    def downsampling_factor(self) -> int:
        # Four stride-2 stages in g_a.
        return 2 ** 4

    def forward(self, x):
        y = self.g_a(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.g_s(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": {
                "y": y_likelihoods,
            },
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Infer channel counts from the first and last encoder convolutions.
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        y = self.g_a(x)
        y_strings = self.entropy_bottleneck.compress(y)
        return {"strings": [y_strings], "shape": y.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 1
        y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):  # pylint: disable=W0622
    """Return `levels` scales spaced geometrically between `min` and `max`."""
    log_min, log_max = math.log(min), math.log(max)
    return torch.exp(torch.linspace(log_min, log_max, levels))
class ScaleHyperprior(CompressionModel):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        # The entropy bottleneck codes the hyper-latent z, which has N channels.
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        # Analysis transform g_a: image (3 ch) -> latent y (M ch), 4 stride-2 convs.
        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        # Synthesis transform g_s: latent y (M ch) -> image (3 ch), mirror of g_a.
        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        # Hyper-analysis h_a: |y| -> hyper-latent z (two stride-2 downsamplings).
        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        # Hyper-synthesis h_s: z_hat -> per-element scales for y (M channels).
        # Final ReLU keeps predicted scales non-negative.
        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        # Conditional Gaussian entropy model for y; scale table is filled in
        # later via update()/load_state_dict().
        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    @property
    def downsampling_factor(self) -> int:
        # 4 stride-2 convs in g_a plus 2 in h_a => total factor 2**6 = 64.
        return 2 ** (4 + 2)

    def forward(self, x):
        """Training/evaluation pass: returns reconstruction and likelihoods
        for both the latent y and the hyper-latent z."""
        y = self.g_a(x)
        # The hyperprior models the scale of y, hence the abs().
        z = self.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def load_state_dict(self, state_dict):
        # Dynamically-sized CDF buffers must be resized to match the checkpoint
        # before the regular state-dict load.
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        super().load_state_dict(state_dict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Channel counts are inferred from the first/last g_a conv weights.
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        """Refresh the entropy-coder CDF tables; must be called before
        compress()/decompress(). Returns True if anything was updated."""
        if scale_table is None:
            scale_table = get_scale_table()
        updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
        updated |= super().update(force=force)
        return updated

    def compress(self, x):
        """Encode an image batch to bitstreams for y and z."""
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        z_strings = self.entropy_bottleneck.compress(z)
        # Decode z locally so the scale prediction matches the decoder side.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        """Decode [y_strings, z_strings] back to an image batch."""
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
class MeanScaleHyperprior(ScaleHyperprior):
    r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(N, M, **kwargs)

        # Hyper-analysis now consumes y directly (not |y|) and uses LeakyReLU.
        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N),
            nn.LeakyReLU(inplace=True),
            conv(N, N),
        )

        # Hyper-synthesis outputs 2*M channels: per-element (scale, mean) for y.
        self.h_s = nn.Sequential(
            deconv(N, M),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        gaussian_params = self.h_s(z_hat)
        # Channel-wise split: first half = scales, second half = means.
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def compress(self, x):
        """Encode an image batch; y is coded conditioned on predicted means."""
        y = self.g_a(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        # Decode z locally so encoder and decoder share the same parameters.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        gaussian_params = self.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        """Decode [y_strings, z_strings] back to an image batch."""
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        gaussian_params = self.h_s(z_hat)
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes, means=means_hat)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
class JointAutoregressiveHierarchicalPriors(MeanScaleHyperprior):
    r"""Joint Autoregressive Hierarchical Priors model from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N=192, M=192, **kwargs):
        super().__init__(N=N, M=M, **kwargs)

        # Analysis transform: 4 stride-2 5x5 convs with GDN, image -> y.
        self.g_a = nn.Sequential(
            conv(3, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, M, kernel_size=5, stride=2),
        )

        # Synthesis transform: mirror of g_a with inverse GDN.
        self.g_s = nn.Sequential(
            deconv(M, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, 3, kernel_size=5, stride=2),
        )

        # Hyper-analysis: y -> z, two additional stride-2 downsamplings.
        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
        )

        # Hyper-synthesis: z_hat -> 2*M channels of hyperprior features.
        self.h_s = nn.Sequential(
            deconv(N, M, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )

        # Fuses hyperprior (2M ch) and context (2M ch) features into
        # (scale, mean) pairs (2M ch) with 1x1 convolutions.
        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )

        # Causal (masked) 5x5 conv: context from previously decoded y elements.
        self.context_prediction = MaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    @property
    def downsampling_factor(self) -> int:
        # 4 stride-2 convs in g_a plus 2 in h_a => factor 2**6 = 64.
        return 2 ** (4 + 2)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        # Additive-noise proxy during training, hard quantization at eval time.
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        # Only the likelihoods are needed here; y_hat was computed above.
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Encode an image batch; y is coded autoregressively per sample."""
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        y = self.g_a(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Pad so each (h, w) can be coded from a full kernel_size crop.
        y_hat = F.pad(y, (padding, padding, padding, padding))

        y_strings = []
        for i in range(y.size(0)):
            string = self._compress_ar(
                y_hat[i : i + 1],
                params[i : i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
            y_strings.append(string)

        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def _compress_ar(self, y_hat, params, height, width, kernel_size, padding):
        """Autoregressively encode one padded latent sample to a bitstream."""
        cdf = self.gaussian_conditional.quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
        offsets = self.gaussian_conditional.offset.tolist()

        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []

        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        # Apply the causal mask once instead of inside MaskedConv2d per step.
        masked_weight = self.context_prediction.weight * self.context_prediction.mask
        for h in range(height):
            for w in range(width):
                y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
                ctx_p = F.conv2d(
                    y_crop,
                    masked_weight,
                    bias=self.context_prediction.bias,
                )

                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[:, :, h : h + 1, w : w + 1]
                gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                gaussian_params = gaussian_params.squeeze(3).squeeze(2)
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)

                # Center element of the crop is the one being coded.
                y_crop = y_crop[:, :, padding, padding]
                y_q = self.gaussian_conditional.quantize(y_crop, "symbols", means_hat)
                # Write back the dequantized value so later context sees it.
                y_hat[:, :, h + padding, w + padding] = y_q + means_hat

                symbols_list.extend(y_q.squeeze().tolist())
                indexes_list.extend(indexes.squeeze().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        return string

    def decompress(self, strings, shape):
        """Decode [y_strings, z_strings]; y is decoded autoregressively."""
        assert isinstance(strings, list) and len(strings) == 2

        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        # FIXME: we don't respect the default entropy coder and directly call the
        # range ANS decoder
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # initialize y_hat to zeros, and pad it so we can directly work with
        # sub-tensors of size (N, C, kernel size, kernel_size)
        y_hat = torch.zeros(
            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device,
        )

        for i, y_string in enumerate(strings[0]):
            self._decompress_ar(
                y_string,
                y_hat[i : i + 1],
                params[i : i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )

        # Remove the padding (negative pad crops).
        y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}

    def _decompress_ar(
        self, y_string, y_hat, params, height, width, kernel_size, padding
    ):
        """Autoregressively decode one bitstream into `y_hat` (in place)."""
        cdf = self.gaussian_conditional.quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
        offsets = self.gaussian_conditional.offset.tolist()

        decoder = RansDecoder()
        decoder.set_stream(y_string)

        # Warning: this is slow due to the auto-regressive nature of the
        # decoding... See more recent publication where they use an
        # auto-regressive module on chunks of channels for faster decoding...
        for h in range(height):
            for w in range(width):
                # only perform the 5x5 convolution on a cropped tensor
                # centered in (h, w)
                y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
                # NOTE(review): unmasked weight here — presumably safe because
                # not-yet-decoded positions in y_hat are still zero; confirm.
                ctx_p = F.conv2d(
                    y_crop,
                    self.context_prediction.weight,
                    bias=self.context_prediction.bias,
                )
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[:, :, h : h + 1, w : w + 1]
                gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                rv = decoder.decode_stream(
                    indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
                )
                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                rv = self.gaussian_conditional.dequantize(rv, means_hat)

                hp = h + padding
                wp = w + padding
                y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv
\ No newline at end of file
codes/compressai/models/utils.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
def find_named_module(module, query):
    """Helper function to find a named module. Returns a `nn.Module` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the module name to find

    Returns:
        nn.Module or None
    """
    # Walk the (recursive) module tree and stop at the first exact name match.
    for name, candidate in module.named_modules():
        if name == query:
            return candidate
    return None
def find_named_buffer(module, query):
    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the buffer name to find

    Returns:
        torch.Tensor or None
    """
    # Scan all registered buffers (recursively) for an exact name match.
    for name, buf in module.named_buffers():
        if name == query:
            return buf
    return None
def _update_registered_buffer(
    module,
    buffer_name,
    state_dict_key,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Resize or register a single buffer on `module` so that a subsequent
    `load_state_dict` call can fill it from `state_dict[state_dict_key]`.

    Args:
        module (nn.Module): module owning the buffer
        buffer_name (str): buffer attribute name on `module`
        state_dict_key (str): fully-qualified key in `state_dict`
        state_dict (dict): the checkpoint being loaded
        policy (str): one of 'resize_if_empty', 'resize', 'register'
        dtype: dtype used when a new buffer is registered

    Raises:
        RuntimeError: buffer missing (resize policies) or already present
            ('register' policy)
        ValueError: unknown policy
    """
    target_size = state_dict[state_dict_key].size()
    existing = find_named_buffer(module, buffer_name)

    if policy == "register":
        if existing is not None:
            raise RuntimeError(f'buffer "{buffer_name}" was already registered')
        # Fresh zero-filled buffer shaped like the checkpoint tensor.
        module.register_buffer(buffer_name, torch.empty(target_size, dtype=dtype).fill_(0))
        return

    if policy in ("resize_if_empty", "resize"):
        if existing is None:
            raise RuntimeError(f'buffer "{buffer_name}" was not registered')
        # 'resize' always reshapes; 'resize_if_empty' only touches empty buffers.
        if policy == "resize" or existing.numel() == 0:
            existing.resize_(target_size)
        return

    raise ValueError(f'Invalid policy "{policy}"')
def update_registered_buffers(
    module,
    module_name,
    buffer_names,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Update the registered buffers in a module according to the tensors sized
    in a state_dict.

    (There's no way in torch to directly load a buffer with a dynamic size)

    Args:
        module (nn.Module): the module
        module_name (str): module name in the state dict
        buffer_names (list(str)): list of the buffer names to resize in the module
        state_dict (dict): the state dict
        policy (str): Update policy, choose from
            ('resize_if_empty', 'resize', 'register')
        dtype (dtype): Type of buffer to be registered (when policy is 'register')
    """
    known = {name for name, _ in module.named_buffers()}

    # Validate everything first so we never apply a partial update.
    for name in buffer_names:
        if name not in known:
            raise ValueError(f'Invalid buffer name "{name}"')

    for name in buffer_names:
        _update_registered_buffer(
            module,
            name,
            f"{module_name}.{name}",
            state_dict,
            policy,
            dtype,
        )
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a 2D convolution with 'same'-style padding (kernel_size // 2)."""
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a 2D transposed convolution that exactly inverts `conv`'s
    spatial downsampling (output_padding = stride - 1)."""
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
    )
codes/compressai/models/waseda.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch.nn
as
nn
from
compressai.layers
import
(
AttentionBlock
,
ResidualBlock
,
ResidualBlockUpsample
,
ResidualBlockWithStride
,
conv3x3
,
subpel_conv3x3
,
)
from
.priors
import
JointAutoregressiveHierarchicalPriors
class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
    """Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
    convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        # M == N in this architecture (latent and expansion widths match).
        super().__init__(N=N, M=N, **kwargs)

        # Analysis: residual blocks replace the plain conv+GDN stages of the
        # parent; still 4 stride-2 downsamplings in total.
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
        )

        # Hyper-analysis: 3x3 convs with two stride-2 downsamplings.
        self.h_a = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
        )

        # Hyper-synthesis: sub-pixel upsampling; outputs 2*N channels
        # (per-element scale and mean parameters).
        self.h_s = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N, N, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N * 3 // 2),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N * 3 // 2, N * 2),
        )

        # Synthesis: residual blocks + sub-pixel upsampling back to RGB.
        self.g_s = nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # N is inferred from the first residual block's conv weight.
        N = state_dict["g_a.0.conv1.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net
class Cheng2020Attention(Cheng2020Anchor):
    """Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
    and sub-pixel convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, **kwargs)

        # Same backbone as the anchor, with AttentionBlocks inserted after the
        # second downsampling stage and at the end of the encoder.
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
            AttentionBlock(N),
        )

        # Decoder mirrors the encoder's attention placement.
        self.g_s = nn.Sequential(
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )
codes/compressai/ops/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.bound_ops
import
LowerBound
from
.ops
import
ste_round
from
.parametrizers
import
NonNegativeParametrizer
__all__
=
[
"ste_round"
,
"LowerBound"
,
"NonNegativeParametrizer"
]
codes/compressai/ops/bound_ops.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator.

    Forward computes ``torch.max(input_, bound)``. Backward passes the
    gradient through wherever the input is above the bound, or wherever the
    gradient would push the input upward (back toward the bound).
    """

    @staticmethod
    def forward(ctx, input_, bound):
        ctx.save_for_backward(input_, bound)
        return torch.max(input_, bound)

    @staticmethod
    def backward(ctx, grad_output):
        input_, bound = ctx.saved_tensors
        # Keep the gradient if we are already above the bound, or if the
        # gradient direction (negative grad => increase input) moves the
        # clamped input back toward the feasible region.
        keep = (input_ >= bound) | (grad_output < 0)
        grad_input = keep.type(grad_output.dtype) * grad_output
        # No gradient w.r.t. the bound itself.
        return grad_input, None
class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    def __init__(self, bound):
        super().__init__()
        # Stored as a buffer so it follows the module across devices/dtypes.
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        # Custom autograd path (not scriptable, hence @torch.jit.unused).
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        # Under TorchScript fall back to plain max (no custom gradient).
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)
codes/compressai/ops/ops.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
def ste_round(x):
    """
    Rounding with non-zero gradients. Gradients are approximated by replacing
    the derivative by the identity function.

    Used in `"Lossy Image Compression with Compressive Autoencoders"
    <https://arxiv.org/abs/1703.00395>`_

    .. note::

        Implemented with the pytorch `detach()` reparametrization trick:

        `x_round = x_round - x.detach() + x`
    """
    # Forward value is round(x); the detached delta contributes no gradient,
    # so d(out)/dx is the identity.
    return x + (torch.round(x) - x).detach()
codes/compressai/ops/parametrizers.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
from
.bound_ops
import
LowerBound
class NonNegativeParametrizer(nn.Module):
    """
    Non negative reparametrization.

    Used for stability during training.
    """

    def __init__(self, minimum=0, reparam_offset=2 ** -18):
        super().__init__()

        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        # Small positive floor added before the sqrt/square round-trip.
        offset_sq = self.reparam_offset ** 2
        self.register_buffer("pedestal", torch.Tensor([offset_sq]))
        self.lower_bound = LowerBound((self.minimum + offset_sq) ** 0.5)

    def init(self, x):
        # Inverse transform: map a raw value into the reparametrized domain.
        return torch.sqrt(torch.max(x + self.pedestal, self.pedestal))

    def forward(self, x):
        bounded = self.lower_bound(x)
        return bounded ** 2 - self.pedestal
codes/compressai/transforms/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.transforms
import
*
codes/compressai/transforms/functional.py
0 → 100644
View file @
9e459ea3
from
typing
import
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
# Luma weights used by the rgb2ycbcr/ycbcr2rgb conversions below.
YCBCR_WEIGHTS = {
    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}
def
_check_input_tensor
(
tensor
:
Tensor
)
->
None
:
if
(
not
isinstance
(
tensor
,
Tensor
)
or
not
tensor
.
is_floating_point
()
or
not
len
(
tensor
.
size
())
in
(
3
,
4
)
or
not
tensor
.
size
(
-
3
)
==
3
):
raise
ValueError
(
"Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
)
def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """RGB to YCbCr conversion for torch Tensor.

    Using ITU-R BT.709 coefficients.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor
    """
    _check_input_tensor(rgb)

    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    r, g, b = rgb.chunk(3, -3)
    # Luma is the weighted sum; chroma channels are scaled differences,
    # recentered into [0, 1] with the +0.5 offset.
    luma = Kr * r + Kg * g + Kb * b
    chroma_b = 0.5 * (b - luma) / (1 - Kb) + 0.5
    chroma_r = 0.5 * (r - luma) / (1 - Kr) + 0.5
    return torch.cat((luma, chroma_b, chroma_r), dim=-3)
def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """YCbCr to RGB conversion for torch Tensor.

    Using ITU-R BT.709 coefficients.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        rgb (torch.Tensor): converted tensor
    """
    _check_input_tensor(ycbcr)

    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    y, cb, cr = ycbcr.chunk(3, -3)
    # Invert the rgb2ycbcr transform: recover r and b from the recentered
    # chroma, then g from the luma equation.
    r = y + (2 - 2 * Kr) * (cr - 0.5)
    b = y + (2 - 2 * Kb) * (cb - 0.5)
    g = (y - Kr * r - Kb * b) / Kg
    return torch.cat((r, g, b), dim=-3)
def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert a 444 tensor to a 420 representation.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
            of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
    """
    if mode not in ("avg_pool",):
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    def _downsample(tensor):
        # 2x2 average pooling halves both spatial dimensions.
        return F.avg_pool2d(tensor, kernel_size=2, stride=2)

    # Accept either a packed (N,3,H,W) tensor or an already-split tuple.
    if isinstance(yuv, torch.Tensor):
        y, u, v = yuv.chunk(3, 1)
    else:
        y, u, v = yuv

    # Luma stays full resolution; only the chroma planes are downsampled.
    return (y, _downsample(u), _downsample(v))
def yuv_420_to_444(
    yuv: Tuple[Tensor, Tensor, Tensor],
    mode: str = "bilinear",
    return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Convert a 420 input to a 444 representation.

    Args:
        yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
            (Nx1xHxW) format
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'nearest'`` Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Returns:
        (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
            444

    Raises:
        ValueError: if `yuv` is not a tuple of 3 tensors or `mode` is invalid.
    """
    if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
        raise ValueError("Expected a tuple of 3 torch tensors")

    if mode not in ("bilinear", "nearest"):
        raise ValueError(f'Invalid upsampling mode "{mode}".')

    # Bug fix: `align_corners` may only be set (even to False) for the
    # interpolating modes; F.interpolate raises a ValueError when it is
    # combined with mode="nearest". Pass None in that case.
    align_corners = False if mode == "bilinear" else None

    def _upsample(tensor):
        return F.interpolate(
            tensor, scale_factor=2, mode=mode, align_corners=align_corners
        )

    y, u, v = yuv
    # Luma is already full resolution; only chroma is upsampled.
    u, v = _upsample(u), _upsample(v)
    if return_tuple:
        return y, u, v
    return torch.cat((y, u, v), dim=1)
codes/compressai/transforms/transforms.py
0 → 100644
View file @
9e459ea3
from
.
import
functional
as
F_transforms
# Public transform classes re-exported via `from .transforms import *`.
__all__ = [
    "RGB2YCbCr",
    "YCbCr2RGB",
    "YUV444To420",
    "YUV420To444",
]
class RGB2YCbCr:
    """Convert a RGB tensor to YCbCr.

    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, rgb):
        """
        Args:
            rgb (torch.Tensor): 3D or 4D floating point RGB tensor

        Returns:
            ycbcr(torch.Tensor): converted tensor
        """
        # Thin stateless wrapper over the functional conversion.
        return F_transforms.rgb2ycbcr(rgb)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YCbCr2RGB:
    """Convert a YCbCr tensor to RGB.

    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, ycbcr):
        """
        Args:
            ycbcr(torch.Tensor): 3D or 4D floating point RGB tensor

        Returns:
            rgb(torch.Tensor): converted tensor
        """
        # Thin stateless wrapper over the functional conversion.
        return F_transforms.ycbcr2rgb(ycbcr)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YUV444To420:
    """Convert a YUV 444 tensor to a 420 representation.

    Args:
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Example:
        >>> x = torch.rand(1, 3, 32, 32)
        >>> y, u, v = YUV444To420()(x)
        >>> y.size()  # 1, 1, 32, 32
        >>> u.size()  # 1, 1, 16, 16
    """

    def __init__(self, mode: str = "avg_pool"):
        # Downsampling algorithm, validated lazily by the functional call.
        self.mode = str(mode)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
                a tuple of 3 (Nx1xHxW) tensors.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
        """
        return F_transforms.yuv_444_to_420(yuv, mode=self.mode)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YUV420To444:
    """Convert a YUV 420 input to a 444 representation.

    Args:
        mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
            Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Example:
        >>> y = torch.rand(1, 1, 32, 32)
        >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
        >>> x = YUV420To444()((y, u, v))
        >>> x.size()  # 1, 3, 32, 32
    """

    def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
        self.mode = str(mode)
        self.return_tuple = bool(return_tuple)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
                (Nx1xHxW) format

        Returns:
            (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
                444
        """
        # Bug fix: `self.mode` was stored by __init__ but never forwarded
        # (unlike the sibling YUV444To420), so an instance built with
        # mode="nearest" silently upsampled with the default "bilinear".
        return F_transforms.yuv_420_to_444(
            yuv, mode=self.mode, return_tuple=self.return_tuple
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
codes/compressai/utils/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/bench/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/bench/__main__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collect performance metrics of published traditional or end-to-end image
codecs.
"""
import
argparse
import
json
import
multiprocessing
as
mp
import
os
import
sys
from
collections
import
defaultdict
from
itertools
import
starmap
from
typing
import
List
from
.codecs
import
AV1
,
BPG
,
HM
,
JPEG
,
JPEG2000
,
TFCI
,
VTM
,
Codec
,
WebP
# from torchvision.datasets.folder
# Image file extensions recognized when scanning the dataset directory.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
# Codec classes exposed as CLI subcommands (one subparser per class; the
# subcommand name is the lowercased class name).
codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]
# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
    """Run `codec` on `args`, tagging the result with the quality index `i`
    so results can be matched back even when pool workers finish out of
    order."""
    return i, codec.run(*args)
def collect(codec: Codec, dataset: str, qualities: List[int], num_jobs: int = 1):
    """Run `codec` over every image in `dataset` at each quality value.

    Returns a dict mapping each metric name to a list of per-quality values,
    where each value is the average over all images in the dataset. Exits the
    process with status 1 if the dataset directory contains no images.
    """
    filepaths = [
        os.path.join(dataset, f)
        for f in os.listdir(dataset)
        if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
    ]
    # print(filepaths)
    pool = mp.Pool(num_jobs) if num_jobs > 1 else None

    if len(filepaths) == 0:
        print("No images found in the dataset directory")
        sys.exit(1)

    # One job per (quality, image) pair; `i` is the quality *index* so the
    # accumulation below works even when workers return out of order.
    args = [(codec, i, f, q) for i, q in enumerate(qualities) for f in filepaths]

    if pool:
        rv = pool.starmap(func, args)
    else:
        rv = list(starmap(func, args))

    # Sum each metric per quality index ...
    results = [defaultdict(float) for _ in range(len(qualities))]

    for i, metrics in rv:
        for k, v in metrics.items():
            results[i][k] += v

    # ... then average over the number of images.
    for i, _ in enumerate(results):
        for k, v in results[i].items():
            results[i][k] = v / len(filepaths)

    # list of dict -> dict of list
    out = defaultdict(list)
    for r in results:
        for k, v in r.items():
            out[k].append(v)
    return out
def setup_args():
    """Build the top-level parser plus the (mandatory) codec sub-parser
    group; codec-specific options are added by the caller."""
    parser = argparse.ArgumentParser(description="Collect codec metrics.")
    subparsers = parser.add_subparsers(dest="codec", help="Select codec")
    # A codec subcommand must always be supplied.
    subparsers.required = True
    return parser, subparsers
def setup_common_args(parser):
    """Register the options shared by every codec subcommand."""
    parser.add_argument("dataset", type=str)
    parser.add_argument(
        "-j",
        "--num-jobs",
        type=int,
        metavar="N",
        default=1,
        help="Number of parallel jobs (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        dest="qualities",
        metavar="Q",
        default=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        nargs="*",
        type=int,
        help="quality parameter (default: %(default)s)",
    )
    # [3,5,10,15,17,20,22,25,27,30,32,35,37,40,42,45,47,50],
    # new added
    parser.add_argument(
        "--name",
        dest="name",
        default="ans",
        type=str,
        help="name for json",
    )
def main(argv):
    """CLI entry point: build per-codec subcommands, run the benchmark over
    the dataset, print the aggregated metrics as JSON, and write the same
    JSON to a log file named after --name."""
    import time

    start = time.time()
    parser, subparsers = setup_args()
    # One subcommand per supported codec: common options plus the codec's
    # own flags.
    for c in codecs:
        cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
        setup_common_args(cparser)
        c.setup_args(cparser)
    args = parser.parse_args(argv)

    # Map the chosen subcommand name back to its codec class.
    codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
    codec = codec_cls(args)
    results = collect(codec, args.dataset, args.qualities, args.num_jobs)

    output = {
        "name": codec.name,
        "description": codec.description,
        "results": results,
    }

    print(json.dumps(output, indent=2))
    end = time.time()
    print('total time:', end - start)

    # NOTE(review): hardcoded, machine-specific output directory — the write
    # below fails on any other host or when the directory does not exist.
    output_dir = '/home/felix/disk2/compressai_v2/codes/results/log'
    output_json_path = os.path.join(output_dir, args.name + '.json')
    with open(output_json_path, 'w') as f:
        json.dump(output, f, indent=2)
# Script entry point: forward CLI arguments (minus the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
codes/compressai/utils/bench/codecs.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
io
import
os
import
platform
import
subprocess
import
sys
import
time
from
tempfile
import
mkstemp
from
typing
import
Tuple
,
Union
import
numpy
as
np
import
PIL
import
PIL.Image
as
Image
import
torch
from
pytorch_msssim
import
ms_ssim
from
compressai.transforms.functional
import
rgb2ycbcr
,
ycbcr2rgb
# from torchvision.datasets.folder
# Image file extensions this module treats as readable input images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
def filesize(filepath: str) -> int:
    """Return the size of `filepath` in bytes.

    (The previous docstring said "bits", but ``os.stat().st_size`` is in
    bytes — callers multiply by 8 when they need bits.)

    Raises:
        ValueError: if `filepath` is not a regular file.
    """
    if not os.path.isfile(filepath):
        raise ValueError(f'Invalid file "{filepath}".')
    return os.stat(filepath).st_size
def read_image(filepath: str, mode: str = "RGB") -> Image.Image:
    """Return PIL image in the specified `mode` format.

    (Return annotation fixed: the function returns a ``PIL.Image.Image``,
    not an ``np.array`` — its own docstring and body say so.)

    Raises:
        ValueError: if `filepath` is not a regular file.
    """
    if not os.path.isfile(filepath):
        raise ValueError(f'Invalid file "{filepath}".')
    return Image.open(filepath).convert(mode)
def compute_metrics(
    a: Union[np.array, Image.Image],
    b: Union[np.array, Image.Image],
    max_val: float = 255.0,
) -> Tuple[float, float]:
    """Returns PSNR and MS-SSIM between images `a` and `b`. """

    def _to_nchw_tensor(img):
        # Accept PIL images or numpy arrays; produce a float (1,C,H,W) tensor.
        if isinstance(img, Image.Image):
            img = np.asarray(img)
        t = torch.from_numpy(img.copy()).float().unsqueeze(0)
        if t.size(3) == 3:
            # HWC layout -> CHW
            t = t.permute(0, 3, 1, 2)
        return t

    a = _to_nchw_tensor(a)
    b = _to_nchw_tensor(b)

    mse = torch.mean((a - b) ** 2).item()
    p = 20 * np.log10(max_val) - 10 * np.log10(mse)
    m = ms_ssim(a, b, data_range=max_val).item()
    return p, m
def run_command(cmd, ignore_returncodes=None):
    """Run `cmd` and return its stdout decoded as ASCII.

    On failure, return the raw output bytes if the return code is listed in
    `ignore_returncodes`; otherwise print the output and exit the process.
    """
    cmd = [str(c) for c in cmd]
    try:
        rv = subprocess.check_output(cmd)
    except subprocess.CalledProcessError as err:
        if ignore_returncodes is not None and err.returncode in ignore_returncodes:
            # Note: this path returns raw bytes, not a decoded string.
            return err.output
        print(err.output.decode("utf-8"))
        sys.exit(1)
    return rv.decode("ascii")
def _get_ffmpeg_version():
    """Return the version token from `ffmpeg -version` output."""
    banner = run_command(["ffmpeg", "-version"])
    # Output starts with: "ffmpeg version <X> ..."; token index 2 is <X>.
    return banner.split()[2]
def _get_bpg_version(encoder_path):
    """Return the version token from the BPG encoder's help banner.

    `bpgenc -h` exits with status 1, which is expected and ignored.
    """
    banner = run_command([encoder_path, "-h"], ignore_returncodes=[1])
    return banner.split()[4]
class Codec:
    """Abstract base class"""

    _description = None

    def __init__(self, args):
        self._set_args(args)

    def _set_args(self, args):
        # Subclasses pull their options off `args` here; the base keeps none.
        return args

    @classmethod
    def setup_args(cls, parser):
        # Subclasses register codec-specific CLI options here.
        pass

    @property
    def description(self):
        return self._description

    @property
    def name(self):
        raise NotImplementedError()

    def _load_img(self, img):
        # Default: treat `img` as a filesystem path.
        return os.path.abspath(img)

    def _run(self, img, quality, *args, **kwargs):
        raise NotImplementedError()

    def run(self, img, quality, *args, **kwargs):
        """Load the input, then delegate to the subclass `_run`."""
        return self._run(self._load_img(img), quality, *args, **kwargs)
class PillowCodec(Codec):
    """Abstract codec backed by Pillow's built-in encoders."""

    # Pillow format string; set by subclasses (e.g. "jpeg", "webp").
    fmt = None

    @property
    def name(self):
        raise NotImplementedError()

    def _load_img(self, img):
        return read_image(img)

    def _run(self, img, quality, return_rec=False, return_metrics=True):
        # Encode into an in-memory buffer, timing the save call.
        buffer = io.BytesIO()
        tic = time.time()
        img.save(buffer, format=self.fmt, quality=int(quality))
        enc_time = time.time() - tic

        buffer.seek(0)
        n_bytes = buffer.getbuffer().nbytes

        # Decode; load() forces pixel materialization so the timing is real.
        tic = time.time()
        rec = Image.open(buffer)
        rec.load()
        dec_time = time.time() - tic

        out = {
            "bpp": float(n_bytes) * 8 / (img.size[0] * img.size[1]),
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }
        if return_metrics:
            out["psnr"], out["ms-ssim"] = compute_metrics(rec, img)
        if return_rec:
            return out, rec
        return out
class JPEG(PillowCodec):
    """Use libjpeg linked in Pillow"""

    # Pillow format string consumed by PillowCodec._run().
    fmt = "jpeg"
    _description = f"JPEG. Pillow version {PIL.__version__}"

    @property
    def name(self):
        # Label used in the benchmark JSON output.
        return "JPEG"
class WebP(PillowCodec):
    """Use libwebp linked in Pillow"""

    # Pillow format string consumed by PillowCodec._run().
    fmt = "webp"
    _description = f"WebP. Pillow version {PIL.__version__}"

    @property
    def name(self):
        # Label used in the benchmark JSON output.
        return "WebP"
class BinaryCodec(Codec):
    """Call a external binary."""

    # Output file suffix; set by subclasses (e.g. ".jp2", ".bpg").
    fmt = None

    def _run(self, img, quality, return_rec=False, return_metrics=True):
        """Encode `img` with the external binary, decode it back to PNG, and
        return bpp/timing (and optionally PSNR/MS-SSIM and the reconstruction).
        Temporary files are created up front and removed before returning."""
        fd0, png_filepath = mkstemp(suffix=".png")
        fd1, out_filepath = mkstemp(suffix=self.fmt)

        # Encode
        start = time.time()
        run_command(self._get_encode_cmd(img, quality, out_filepath))
        enc_time = time.time() - start
        size = filesize(out_filepath)

        # Decode
        start = time.time()
        run_command(self._get_decode_cmd(out_filepath, png_filepath))
        dec_time = time.time() - start

        # Read image
        img = read_image(img)
        rec = read_image(png_filepath)
        # Close and remove both temp files (the binaries wrote to them by path).
        os.close(fd0)
        os.remove(png_filepath)
        os.close(fd1)
        os.remove(out_filepath)

        bpp_val = float(size) * 8 / (img.size[0] * img.size[1])

        out = {
            "bpp": bpp_val,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }
        if return_metrics:
            psnr_val, msssim_val = compute_metrics(rec, img)
            out["psnr"] = psnr_val
            out["ms-ssim"] = msssim_val
        if return_rec:
            return out, rec
        return out

    def _get_encode_cmd(self, img, quality, out_filepath):
        # Subclasses return the encoder command line as a list of arguments.
        raise NotImplementedError()

    def _get_decode_cmd(self, out_filepath, rec_filepath):
        # Subclasses return the decoder command line as a list of arguments.
        raise NotImplementedError()
class JPEG2000(BinaryCodec):
    """Use ffmpeg version.
    (Not built-in support in default Pillow builds)
    """

    fmt = ".jp2"

    @property
    def name(self):
        return "JPEG2000"

    @property
    def description(self):
        return f"JPEG2000. ffmpeg version {_get_ffmpeg_version()}"

    def _get_encode_cmd(self, img, quality, out_filepath):
        # Encode via ffmpeg's libopenjpeg backend in 4:4:4; `quality` maps to
        # the codec's compression_level option.
        cmd = [
            "ffmpeg",
            "-loglevel",
            "panic",
            "-y",
            "-i",
            img,
            "-vcodec",
            "jpeg2000",
            "-pix_fmt",
            "yuv444p",
            "-c:v",
            "libopenjpeg",
            "-compression_level",
            quality,
            out_filepath,
        ]
        return cmd

    # jpeg2000
    def _get_decode_cmd(self, out_filepath, rec_filepath):
        # ffmpeg infers both formats from the file extensions.
        cmd = ["ffmpeg", "-loglevel", "panic", "-y", "-i", out_filepath, rec_filepath]
        return cmd
class BPG(BinaryCodec):
    """BPG from Fabrice Bellard."""

    fmt = ".bpg"

    @property
    def name(self):
        return (
            f"BPG {self.bitdepth}b {self.subsampling_mode} {self.encoder} "
            f"{self.color_mode}"
        )

    @property
    def description(self):
        return f"BPG. BPG version {_get_bpg_version(self.encoder_path)}"

    @classmethod
    def setup_args(cls, parser):
        super().setup_args(parser)
        parser.add_argument(
            "-m",
            choices=["420", "444"],
            default="444",
            help="subsampling mode (default: %(default)s)",
        )
        parser.add_argument(
            "-b",
            choices=["8", "10"],
            default="8",
            help="bitdepth (default: %(default)s)",
        )
        parser.add_argument(
            "-c",
            choices=["rgb", "ycbcr"],
            default="ycbcr",
            help="colorspace  (default: %(default)s)",
        )
        parser.add_argument(
            "-e",
            choices=["jctvc", "x265"],
            default="x265",
            help="HEVC implementation (default: %(default)s)",
        )
        parser.add_argument("--encoder-path", default="bpgenc", help="BPG encoder path")
        parser.add_argument("--decoder-path", default="bpgdec", help="BPG decoder path")

    def _set_args(self, args):
        args = super()._set_args(args)
        self.color_mode = args.c
        self.encoder = args.e
        self.subsampling_mode = args.m
        self.bitdepth = args.b
        # Fix: honor the --encoder-path/--decoder-path options declared in
        # setup_args (defaults "bpgenc"/"bpgdec" on PATH). The previous code
        # hardcoded machine-specific absolute paths ("/home/felix/..."),
        # silently ignoring the CLI options.
        self.encoder_path = args.encoder_path
        self.decoder_path = args.decoder_path
        return args

    def _get_encode_cmd(self, img, quality, out_filepath):
        if not 0 <= quality <= 51:
            raise ValueError(f"Invalid quality value: {quality} (0,51)")
        cmd = [
            self.encoder_path,
            "-o",
            out_filepath,
            "-q",
            str(quality),
            "-f",
            self.subsampling_mode,
            "-e",
            self.encoder,
            "-c",
            self.color_mode,
            "-b",
            self.bitdepth,
            img,
        ]
        return cmd

    def _get_decode_cmd(self, out_filepath, rec_filepath):
        cmd = [self.decoder_path, "-o", rec_filepath, out_filepath]
        return cmd
class TFCI(BinaryCodec):
    """Tensorflow image compression format from tensorflow/compression"""

    fmt = ".tfci"
    # Model identifiers understood by the tfci.py script.
    _models = [
        "bmshj2018-factorized-mse",
        "bmshj2018-hyperprior-mse",
        "mbt2018-mean-mse",
    ]

    @property
    def description(self):
        return "TFCI"

    @property
    def name(self):
        return f"{self.model}"

    @classmethod
    def setup_args(cls, parser):
        super().setup_args(parser)
        parser.add_argument(
            "-m",
            "--model",
            choices=cls._models,
            default=cls._models[0],
            help="model architecture (default: %(default)s)",
        )
        parser.add_argument(
            "-p",
            "--path",
            required=True,
            help="tfci python script path (default: %(default)s)",
        )

    def _set_args(self, args):
        args = super()._set_args(args)
        self.model = args.model
        self.tfci_path = args.path
        return args

    def _get_encode_cmd(self, img, quality, out_filepath):
        # tfci.py encodes a "<model>-<quality>" identifier; quality is 1..8.
        if not 1 <= quality <= 8:
            raise ValueError(f"Invalid quality value: {quality} (1, 8)")
        cmd = [
            sys.executable,
            self.tfci_path,
            "compress",
            f"{self.model}-{quality:d}",
            img,
            out_filepath,
        ]
        return cmd

    def _get_decode_cmd(self, out_filepath, rec_filepath):
        cmd = [sys.executable, self.tfci_path, "decompress", out_filepath, rec_filepath]
        return cmd
def get_vtm_encoder_path(build_dir):
    """Return the platform-specific VTM encoder binary path under `build_dir`.

    Raises:
        RuntimeError: on platforms other than Darwin/Linux.
    """
    system = platform.system()
    elfnames = {"Darwin": "EncoderApp", "Linux": "EncoderAppStatic"}
    try:
        return os.path.join(build_dir, elfnames[system])
    except KeyError as err:
        raise RuntimeError(f'Unsupported platform "{system}"') from err
def get_vtm_decoder_path(build_dir):
    """Return the platform-specific VTM decoder binary path under `build_dir`.

    Raises:
        RuntimeError: on platforms other than Darwin/Linux.
    """
    system = platform.system()
    elfnames = {"Darwin": "DecoderApp", "Linux": "DecoderAppStatic"}
    try:
        return os.path.join(build_dir, elfnames[system])
    except KeyError as err:
        raise RuntimeError(f'Unsupported platform "{system}"') from err
class VTM(Codec):
    """VTM: VVC reference software"""

    # Bitstream container extension produced by the encoder.
    fmt = ".bin"

    @property
    def description(self):
        return "VTM"

    @property
    def name(self):
        return "VTM"

    @classmethod
    def setup_args(cls, parser):
        super().setup_args(parser)
        # NOTE(review): machine-specific defaults; other hosts must override.
        parser.add_argument(
            "-b",
            "--build-dir",
            type=str,
            default="/home/felix/disk2/VVCSoftware_VTM/bin",
            help="VTM build dir",
        )
        parser.add_argument(
            "-c",
            "--config",
            type=str,
            default="/home/felix/disk2/VVCSoftware_VTM/cfg/encoder_intra_vtm.cfg",
            help="VTM config file",
        )
        parser.add_argument("--rgb", action="store_true", help="Use RGB color space (over YCbCr)")

    def _set_args(self, args):
        args = super()._set_args(args)
        self.encoder_path = get_vtm_encoder_path(args.build_dir)
        self.decoder_path = get_vtm_decoder_path(args.build_dir)
        self.config_path = args.config
        self.rgb = args.rgb
        return args

    def _run(self, img, quality, return_rec=False, return_metrics=True):
        """Encode `img` at QP `quality` with VTM, decode it back, and return
        bpp/timing (and optionally PSNR/MS-SSIM and the reconstruction)."""
        if not 0 <= quality <= 63:
            raise ValueError(f"Invalid quality value: {quality} (0,63)")

        # Taking 8bit input for now
        bitdepth = 8

        # Convert input image to yuv 444 file
        arr = np.asarray(read_image(img))
        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".bin"

        arr = arr.transpose((2, 0, 1))  # color channel first

        if not self.rgb:
            # convert rgb content to YCbCr
            rgb = torch.from_numpy(arr.copy()).float() / (2 ** bitdepth - 1)
            arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
            arr = (arr * (2 ** bitdepth - 1)).astype(np.uint8)

        with open(yuv_path, "wb") as f:
            f.write(arr.tobytes())

        # Encode
        height, width = arr.shape[1:]
        cmd = [
            self.encoder_path,
            "-i",
            yuv_path,
            "-c",
            self.config_path,
            "-q",
            quality,
            "-o",
            "/dev/null",
            "-b",
            out_filepath,
            "-wdt",
            width,
            "-hgt",
            height,
            "-fr",
            "1",
            "-f",
            "1",
            "--InputChromaFormat=444",
            "--InputBitDepth=8",
            "--ConformanceMode=1",
        ]

        if self.rgb:
            cmd += [
                "--InputColourSpaceConvert=RGBtoGBR",
                "--SNRInternalColourSpace=1",
                "--OutputInternalColourSpace=0",
            ]
        start = time.time()
        run_command(cmd)
        enc_time = time.time() - start

        # cleanup encoder input
        os.close(fd)
        os.unlink(yuv_path)

        # Decode
        cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]
        if self.rgb:
            cmd.append("--OutputInternalColourSpace=GBRtoRGB")

        start = time.time()
        run_command(cmd)
        dec_time = time.time() - start

        # Compute PSNR
        rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
        rec_arr = rec_arr.reshape(arr.shape)

        arr = arr.astype(np.float32) / (2 ** bitdepth - 1)
        rec_arr = rec_arr.astype(np.float32) / (2 ** bitdepth - 1)
        if not self.rgb:
            # Metrics are computed in RGB regardless of the coding colorspace.
            arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
            rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()

        bpp = filesize(out_filepath) * 8.0 / (height * width)

        # Cleanup
        os.unlink(yuv_path)
        os.unlink(out_filepath)

        out = {
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }

        if return_metrics:
            psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
            out["psnr"] = psnr_val
            out["ms-ssim"] = msssim_val

        if return_rec:
            rec = Image.fromarray(
                (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
            )
            return out, rec
        return out
class HM(Codec):
    """HM: H.265/HEVC reference software"""

    # Bitstream container extension produced by the encoder.
    fmt = ".bin"

    @property
    def description(self):
        return "HM"

    @property
    def name(self):
        return "HM"

    @classmethod
    def setup_args(cls, parser):
        super().setup_args(parser)
        parser.add_argument(
            "-b",
            "--build-dir",
            type=str,
            required=True,
            help="HM build dir",
        )
        parser.add_argument("-c", "--config", type=str, required=True, help="HM config file")
        parser.add_argument("--rgb", action="store_true", help="Use RGB color space (over YCbCr)")

    def _set_args(self, args):
        args = super()._set_args(args)
        # Static encoder/decoder binaries expected inside the build dir.
        self.encoder_path = os.path.join(args.build_dir, "TAppEncoderStatic")
        self.decoder_path = os.path.join(args.build_dir, "TAppDecoderStatic")
        self.config_path = args.config
        self.rgb = args.rgb
        return args

    def _run(self, img, quality, return_rec=False, return_metrics=True):
        """Encode `img` at QP `quality` with HM, decode it back, and return
        bpp/timing (and optionally PSNR/MS-SSIM and the reconstruction)."""
        if not 0 <= quality <= 51:
            raise ValueError(f"Invalid quality value: {quality} (0,51)")

        # Convert input image to yuv 444 file
        arr = np.asarray(read_image(img))
        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".bin"
        bitdepth = 8

        arr = arr.transpose((2, 0, 1))  # color channel first

        if not self.rgb:
            # convert rgb content to YCbCr
            rgb = torch.from_numpy(arr.copy()).float() / (2 ** bitdepth - 1)
            arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
            arr = (arr * (2 ** bitdepth - 1)).astype(np.uint8)

        with open(yuv_path, "wb") as f:
            f.write(arr.tobytes())

        # Encode
        height, width = arr.shape[1:]
        cmd = [
            self.encoder_path,
            "-i",
            yuv_path,
            "-c",
            self.config_path,
            "-q",
            quality,
            "-o",
            "/dev/null",
            "-b",
            out_filepath,
            "-wdt",
            width,
            "-hgt",
            height,
            "-fr",
            "1",
            "-f",
            "1",
            "--InputChromaFormat=444",
            "--InputBitDepth=8",
            "--SEIDecodedPictureHash",
            "--Level=5.1",
            "--CUNoSplitIntraACT=0",
            "--ConformanceMode=1",
        ]

        if self.rgb:
            cmd += [
                "--InputColourSpaceConvert=RGBtoGBR",
                "--SNRInternalColourSpace=1",
                "--OutputInternalColourSpace=0",
            ]
        start = time.time()
        run_command(cmd)
        enc_time = time.time() - start

        # cleanup encoder input
        os.close(fd)
        os.unlink(yuv_path)

        # Decode
        cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]

        if self.rgb:
            cmd.append("--OutputInternalColourSpace=GBRtoRGB")
        start = time.time()
        run_command(cmd)
        dec_time = time.time() - start

        # Compute PSNR
        rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
        rec_arr = rec_arr.reshape(arr.shape)
        arr = arr.astype(np.float32) / (2 ** bitdepth - 1)
        rec_arr = rec_arr.astype(np.float32) / (2 ** bitdepth - 1)
        if not self.rgb:
            # Metrics are computed in RGB regardless of the coding colorspace.
            arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
            rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()

        bpp = filesize(out_filepath) * 8.0 / (height * width)

        # Cleanup
        os.unlink(yuv_path)
        os.unlink(out_filepath)

        out = {
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }

        if return_metrics:
            psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
            out["psnr"] = psnr_val
            out["ms-ssim"] = msssim_val

        if return_rec:
            rec = Image.fromarray(
                (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
            )
            return out, rec
        return out
class AV1(Codec):
    """AV1: AOM reference software"""

    # Bitstream container suffix produced by aomenc.
    fmt = ".webm"

    @property
    def description(self):
        return "AV1"

    @property
    def name(self):
        return "AV1"

    @classmethod
    def setup_args(cls, parser):
        # Extend the base Codec CLI with the location of the AOM binaries.
        super().setup_args(parser)
        parser.add_argument(
            "-b",
            "--build-dir",
            type=str,
            required=True,
            help="AOM binaries dir",
        )

    def _set_args(self, args):
        # Resolve encoder/decoder executables inside the user-supplied build dir.
        args = super()._set_args(args)
        self.encoder_path = os.path.join(args.build_dir, "aomenc")
        self.decoder_path = os.path.join(args.build_dir, "aomdec")
        return args

    def _run(self, img, quality, return_rec=False, return_metrics=True):
        """Encode *img* with aomenc at the given cq-level, decode it back, and
        report bpp / timings (plus PSNR and MS-SSIM when *return_metrics*).

        ``quality`` is the AV1 constrained-quality level and must lie in [0, 63].
        Raises ValueError otherwise.
        """
        if not 0 <= quality <= 63:
            raise ValueError(f"Invalid quality value: {quality} (0,63)")
        # Convert input image to yuv 444 file
        arr = np.asarray(read_image(img))
        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".webm"
        bitdepth = 8
        arr = arr.transpose((2, 0, 1))  # color channel first
        # convert rgb content to YCbCr
        rgb = torch.from_numpy(arr.copy()).float() / (2**bitdepth - 1)
        arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
        arr = (arr * (2**bitdepth - 1)).astype(np.uint8)
        # Raw planar YUV 4:4:4 dump consumed by aomenc below.
        with open(yuv_path, "wb") as f:
            f.write(arr.tobytes())
        # Encode
        height, width = arr.shape[1:]
        # Single-frame, two-pass, quality-targeted encode tuned for PSNR.
        # run_command is expected to stringify the int width/height arguments.
        cmd = [
            self.encoder_path,
            "-w",
            width,
            "-h",
            height,
            "--fps=1/1",
            "--limit=1",
            "--input-bit-depth=8",
            "--cpu-used=0",
            "--threads=1",
            "--passes=2",
            "--end-usage=q",
            "--cq-level=" + str(quality),
            "--i444",
            "--skip=0",
            "--tune=psnr",
            "--psnr",
            "--bit-depth=8",
            "-o",
            out_filepath,
            yuv_path,
        ]
        start = time.time()
        run_command(cmd)
        enc_time = time.time() - start
        # cleanup encoder input
        os.close(fd)
        os.unlink(yuv_path)
        # Decode
        # Decoder reuses yuv_path as its raw output destination.
        cmd = [
            self.decoder_path,
            out_filepath,
            "-o",
            yuv_path,
            "--rawvideo",
            "--output-bit-depth=8",
        ]
        start = time.time()
        run_command(cmd)
        dec_time = time.time() - start
        # Compute PSNR
        rec_arr = np.fromfile(yuv_path, dtype=np.uint8)
        rec_arr = rec_arr.reshape(arr.shape)
        # Normalize both arrays to [0, 1] before color-space conversion.
        arr = arr.astype(np.float32) / (2**bitdepth - 1)
        rec_arr = rec_arr.astype(np.float32) / (2**bitdepth - 1)
        # Metrics are always computed in RGB space.
        arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
        rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
        bpp = filesize(out_filepath) * 8.0 / (height * width)
        # Cleanup
        os.unlink(yuv_path)
        os.unlink(out_filepath)
        out = {
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }
        if return_metrics:
            psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)
            out["psnr"] = psnr_val
            out["ms-ssim"] = msssim_val
        if return_rec:
            # Reconstruction as an 8-bit HWC PIL image.
            rec = Image.fromarray(
                (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
            )
            return out, rec
        return out
codes/compressai/utils/eval_model/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/eval_model/__main__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate an end-to-end compression model on an image dataset.
"""
import
argparse
import
json
import
math
import
os
import
sys
import
time
from
collections
import
defaultdict
from
typing
import
List
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
pytorch_msssim
import
ms_ssim
from
torchvision
import
transforms
import
compressai
from
compressai.zoo
import
models
as
pretrained_models
from
compressai.zoo.image
import
model_architectures
as
architectures
# from torchvision.datasets.folder
# File suffixes (lowercase, compared after .lower()) accepted by collect_images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
def collect_images(rootpath: str) -> List[str]:
    """Return the full paths of all image files directly under *rootpath*.

    A file counts as an image when its extension (case-insensitive) is one
    of ``IMG_EXTENSIONS``. Subdirectories are not traversed.
    """
    image_paths = []
    for entry in os.listdir(rootpath):
        extension = os.path.splitext(entry)[-1].lower()
        if extension in IMG_EXTENSIONS:
            image_paths.append(os.path.join(rootpath, entry))
    return image_paths
def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the PSNR (in dB) between two tensors with values in [0, 1].

    Returns ``float("inf")`` when the tensors are identical (zero MSE);
    the original code raised ``ValueError: math domain error`` from
    ``math.log10(0)`` in that case.
    """
    mse = F.mse_loss(a, b).item()
    if mse == 0:
        # Perfect reconstruction: PSNR is unbounded.
        return float("inf")
    return -10 * math.log10(mse)
def read_image(filepath: str) -> torch.Tensor:
    """Load *filepath* as an RGB image and return it as a CHW float tensor
    with values in [0, 1]."""
    assert os.path.isfile(filepath)
    to_tensor = transforms.ToTensor()
    rgb_img = Image.open(filepath).convert("RGB")
    return to_tensor(rgb_img)
@torch.no_grad()
def inference(model, x, savedir="", idx=1):
    """Compress and decompress a single image with *model*'s entropy coder.

    ``x`` is a CHW image tensor in [0, 1]; it is batched, padded to a
    multiple of 64 on each side, coded, and the reconstruction is cropped
    back before metrics are computed.  When *savedir* is non-empty, the
    reconstruction is saved there as ``<idx>_<psnr>_<bpp>_<msssim>.png``.

    Returns a dict with psnr, ms-ssim, bpp, and wall-clock coding times.
    """
    x = x.unsqueeze(0)
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2
    # Round spatial dims up to the next multiple of p and split the padding
    # as evenly as possible between the two sides.
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )
    start = time.time()
    out_enc = model.compress(x_padded)
    enc_time = time.time() - start
    start = time.time()
    out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
    dec_time = time.time() - start
    # Negative padding crops the reconstruction back to the original h x w.
    out_dec["x_hat"] = F.pad(
        out_dec["x_hat"],
        (-padding_left, -padding_right, -padding_top, -padding_bottom)
    )
    # bpp is measured against the ORIGINAL (unpadded) pixel count.
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
    if savedir != "":
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        cur_psnr = psnr(x, out_dec["x_hat"])
        cur_ssim = ms_ssim(x, out_dec["x_hat"], data_range=1.0).item()
        tran1 = transforms.ToPILImage()
        cur_img = tran1(out_dec["x_hat"][0])
        # Encode the per-image metrics into the output file name.
        cur_img.save(os.path.join(
            savedir,
            '{:02d}'.format(idx) + "_" + '{:.2f}'.format(cur_psnr) + "_"
            + '{:.3f}'.format(bpp) + "_" + '{:.3f}'.format(cur_ssim) + ".png"
        ))
    return {
        "psnr": psnr(x, out_dec["x_hat"]),
        "ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(),
        "bpp": bpp,
        "encoding_time": enc_time,
        "decoding_time": dec_time,
    }
@torch.no_grad()
def inference_entropy_estimation(model, x):
    """Estimate rate-distortion for one image without running entropy coding.

    A single forward pass yields the reconstruction and the per-latent
    likelihoods; the rate is the total negative log2-likelihood per pixel.
    The elapsed time is split evenly between the reported encoding and
    decoding times, since only one combined pass is performed.
    """
    x = x.unsqueeze(0)
    tic = time.time()
    out_net = model.forward(x)
    elapsed_time = time.time() - tic
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    # bits = -sum(log2 p) ; divide by pixel count for bpp.
    denom = -math.log(2) * num_pixels
    bpp = sum(
        torch.log(lk).sum() / denom
        for lk in out_net["likelihoods"].values()
    )
    half_time = elapsed_time / 2.0  # broad estimation
    return {
        "psnr": psnr(x, out_net["x_hat"]),
        "bpp": bpp.item(),
        "encoding_time": half_time,
        "decoding_time": half_time,
    }
def load_pretrained(model: str, metric: str, quality: int) -> nn.Module:
    """Fetch a pretrained zoo model for the given metric/quality, in eval mode."""
    net = pretrained_models[model](quality=quality, metric=metric, pretrained=True)
    return net.eval()
def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    """Build architecture *arch* from a checkpoint's state dict, in eval mode.

    The checkpoint is loaded with ``map_location="cpu"`` so that weights
    saved on a CUDA machine can still be restored on a CPU-only host
    (the original call crashed there); ``main`` moves the model to CUDA
    afterwards when requested.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return architectures[arch].from_state_dict(state_dict).eval()
def eval_model(model, filepaths, entropy_estimation=False, half=False, savedir=""):
    """Evaluate *model* on every image in *filepaths*.

    Runs real compression (``inference``) by default, or a forward-pass
    rate estimate when *entropy_estimation* is set.  Per-image metrics are
    printed as they are produced and the function returns a dict mapping
    each metric name to its average over all images.
    """
    device = next(model.parameters()).device
    metrics = defaultdict(float)
    for idx, f in enumerate(sorted(filepaths)):
        x = read_image(f).to(device)
        if not entropy_estimation:
            print('evaluating index', idx)
            if half:
                # .half() is idempotent, so re-converting each iteration is safe.
                model = model.half()
                x = x.half()
            rv = inference(model, x, savedir, idx)
        else:
            rv = inference_entropy_estimation(model, x)
        print('bpp', rv['bpp'])
        print('psnr', rv['psnr'])
        # BUGFIX: the entropy-estimation result has no "ms-ssim" entry, so
        # printing it unconditionally raised KeyError in that mode.
        if 'ms-ssim' in rv:
            print('ms-ssim', rv['ms-ssim'])
        print()
        for k, v in rv.items():
            metrics[k] += v
    # Average the accumulated sums over the number of evaluated images.
    for k, v in metrics.items():
        metrics[k] = v / len(filepaths)
    return metrics
def setup_args():
    """Build the CLI parser: shared options plus ``pretrained``/``checkpoint``
    sub-commands, each inheriting the common options via *parent_parser*.

    Returns the top-level ``argparse.ArgumentParser``.
    """
    # Options shared by both sub-commands; add_help=False so the parent can
    # be reused via `parents=` without a duplicate -h flag.
    parent_parser = argparse.ArgumentParser(
        add_help=False,
    )
    # Common options.
    parent_parser.add_argument("dataset", type=str, help="dataset path")
    parent_parser.add_argument(
        "-a",
        "--arch",
        type=str,
        choices=pretrained_models.keys(),
        help="model architecture",
        required=True,
    )
    parent_parser.add_argument(
        "-c",
        "--entropy-coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="entropy coder (default: %(default)s)",
    )
    parent_parser.add_argument(
        "--cuda",
        action="store_true",
        help="enable CUDA",
    )
    parent_parser.add_argument(
        "--half",
        action="store_true",
        help="convert model to half floating point (fp16)",
    )
    parent_parser.add_argument(
        "--entropy-estimation",
        action="store_true",
        help="use evaluated entropy estimation (no entropy coding)",
    )
    parent_parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="verbose mode",
    )
    # Directory where reconstructions are saved; empty string disables saving.
    parent_parser.add_argument(
        "-s",
        "--savedir",
        type=str,
        default="",
    )
    parent_parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    parser = argparse.ArgumentParser(
        description="Evaluate a model on an image dataset.", add_help=True
    )
    # NOTE(review): the sub-command is not marked required, so running with
    # no sub-command leaves args.source as None (main would then hit an
    # unbound `runs`) — the required flag was deliberately commented out.
    subparsers = parser.add_subparsers(help="model source", dest="source")
    #, required=True
    # )
    # Options for pretrained models
    pretrained_parser = subparsers.add_parser("pretrained", parents=[parent_parser])
    pretrained_parser.add_argument(
        "-m",
        "--metric",
        type=str,
        choices=["mse", "ms-ssim"],
        default="mse",
        help="metric trained against (default: %(default)s)",
    )
    pretrained_parser.add_argument(
        "-q",
        "--quality",
        dest="qualities",
        nargs="+",
        type=int,
        default=(1,),
    )
    checkpoint_parser = subparsers.add_parser("checkpoint", parents=[parent_parser])
    # checkpoint_parser.add_argument(
    #     "-p",
    #     "--path",
    #     dest="paths",
    #     type=str,
    #     nargs="*",
    #     required=True,
    #     help="checkpoint path",
    # )
    checkpoint_parser.add_argument(
        "-exp", "--experiment", type=str, required=True, help="Experiment name"
    )
    return parser
def main(argv):
    """Entry point: parse CLI args, load the model(s), evaluate, print JSON.

    For the ``pretrained`` sub-command one model per requested quality is
    evaluated; for ``checkpoint`` the first file found under
    ``../experiments/<experiment>/checkpoint_updated`` is used.
    """
    args = setup_args().parse_args(argv)
    # NOTE(review): set before any CUDA context is created below; only
    # effective if torch.cuda was not already initialized by an importer.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # NOTE(review): `device` is computed but never used — the code below
    # calls model.to("cuda") directly.
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    filepaths = collect_images(args.dataset)
    if len(filepaths) == 0:
        print("No images found in directory.")
        sys.exit(1)
    compressai.set_entropy_coder(args.entropy_coder)
    if args.source == "pretrained":
        # One run per quality level; load_pretrained(arch, metric, quality).
        runs = sorted(args.qualities)
        opts = (args.arch, args.metric)
        load_func = load_pretrained
        log_fmt = "\rEvaluating {0} | {run:d}"
    elif args.source == "checkpoint":
        # runs = args.paths
        # Single run: the first entry in the experiment's updated-checkpoint dir.
        checkpoint_updated_dir = os.path.join(
            '../experiments', args.experiment, 'checkpoint_updated')
        checkpoint_updated = os.path.join(
            checkpoint_updated_dir, os.listdir(checkpoint_updated_dir)[0])
        runs = [checkpoint_updated]
        opts = (args.arch,)
        load_func = load_checkpoint
        log_fmt = "\rEvaluating {run:s}"
    results = defaultdict(list)
    for run in runs:
        if args.verbose:
            sys.stderr.write(log_fmt.format(*opts, run=run))
            sys.stderr.flush()
        model = load_func(*opts, run)
        if args.cuda and torch.cuda.is_available():
            model = model.to("cuda")
        metrics = eval_model(
            model, filepaths, args.entropy_estimation, args.half, args.savedir)
        # Collect one value per run for each metric.
        for k, v in metrics.items():
            results[k].append(v)
    if args.verbose:
        sys.stderr.write("\n")
        sys.stderr.flush()
    description = (
        "entropy estimation" if args.entropy_estimation else args.entropy_coder
    )
    output = {
        "name": args.arch,
        "description": f"Inference ({description})",
        "results": results,
    }
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main(sys.argv[1:])
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment