Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
compressai
Commits
9e459ea3
Commit
9e459ea3
authored
Nov 05, 2024
by
jerrrrry
Browse files
Initial commit
parents
Changes
120
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1981 additions
and
0 deletions
+1981
-0
codes/compressai/layers/layers.py
codes/compressai/layers/layers.py
+213
-0
codes/compressai/models/__init__.py
codes/compressai/models/__init__.py
+17
-0
codes/compressai/models/our_utils.py
codes/compressai/models/our_utils.py
+385
-0
codes/compressai/models/ours.py
codes/compressai/models/ours.py
+140
-0
codes/compressai/models/priors.py
codes/compressai/models/priors.py
+0
-0
codes/compressai/models/utils.py
codes/compressai/models/utils.py
+130
-0
codes/compressai/models/waseda.py
codes/compressai/models/waseda.py
+138
-0
codes/compressai/ops/__init__.py
codes/compressai/ops/__init__.py
+19
-0
codes/compressai/ops/bound_ops.py
codes/compressai/ops/bound_ops.py
+53
-0
codes/compressai/ops/ops.py
codes/compressai/ops/ops.py
+32
-0
codes/compressai/ops/parametrizers.py
codes/compressai/ops/parametrizers.py
+45
-0
codes/compressai/transforms/__init__.py
codes/compressai/transforms/__init__.py
+15
-0
codes/compressai/transforms/functional.py
codes/compressai/transforms/functional.py
+135
-0
codes/compressai/transforms/transforms.py
codes/compressai/transforms/transforms.py
+118
-0
codes/compressai/utils/__init__.py
codes/compressai/utils/__init__.py
+13
-0
codes/compressai/utils/bench/__init__.py
codes/compressai/utils/bench/__init__.py
+13
-0
codes/compressai/utils/bench/__main__.py
codes/compressai/utils/bench/__main__.py
+166
-0
codes/compressai/utils/bench/codecs.py
codes/compressai/utils/bench/codecs.py
+0
-0
codes/compressai/utils/eval_model/__init__.py
codes/compressai/utils/eval_model/__init__.py
+13
-0
codes/compressai/utils/eval_model/__main__.py
codes/compressai/utils/eval_model/__main__.py
+336
-0
No files found.
codes/compressai/layers/layers.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
from
.gdn
import
GDN
class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution implementation, mask future "unseen" pixels.
    Useful for building auto-regressive network components.

    Introduced in `"Conditional Image Generation with PixelCNN Decoders"
    <https://arxiv.org/abs/1606.05328>`_.

    Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
    first layer (which also masks the "current pixel"), `mask_type='B'` for the
    following layers.
    """

    def __init__(self, *args, mask_type="A", **kwargs):
        super().__init__(*args, **kwargs)

        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')

        # Build the causal mask once, then register it so it moves with the
        # module (to(), state_dict, ...) but is not trained.
        mask = torch.ones_like(self.weight.data)
        _, _, h, w = mask.size()
        keep_current = 1 if mask_type == "B" else 0
        # Zero the centre row from the current pixel onwards ('A' also masks
        # the centre pixel itself), and every row below the centre.
        mask[:, :, h // 2, w // 2 + keep_current:] = 0
        mask[:, :, h // 2 + 1:] = 0
        self.register_buffer("mask", mask)

    def forward(self, x):
        # TODO(begaintj): weight assigment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)
def conv3x3(in_ch, out_ch, stride=1):
    """Return a 3x3 convolution with padding 1 (spatial size preserved at
    stride 1).

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): convolution stride (default: 1)
    """
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
def subpel_conv3x3(in_ch, out_ch, r=1):
    """Return a 3x3 sub-pixel convolution that up-samples by factor ``r``.

    The convolution produces ``out_ch * r**2`` channels which a
    `nn.PixelShuffle` then rearranges into ``out_ch`` channels at ``r`` times
    the spatial resolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        r (int): up-sampling factor (default: 1)
    """
    expand = nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1)
    shuffle = nn.PixelShuffle(r)
    return nn.Sequential(expand, shuffle)
def conv1x1(in_ch, out_ch, stride=1):
    """Return a 1x1 (pointwise) convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): convolution stride (default: 1)
    """
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=1,
        stride=stride,
    )
class ResidualBlockWithStride(nn.Module):
    """Residual block with a stride on the first convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        # A strided 1x1 projection is required whenever the main path changes
        # the spatial size or the channel count; otherwise the identity is
        # used directly.
        if stride != 1 or in_ch != out_ch:
            self.skip = conv1x1(in_ch, out_ch, stride=stride)
        else:
            self.skip = None

    def forward(self, x):
        identity = x if self.skip is None else self.skip(x)
        out = self.leaky_relu(self.conv1(x))
        out = self.conv2(out)
        out = self.gdn(out)
        out += identity
        return out
class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on the last convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch, out_ch, upsample=2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        # The skip branch uses its own sub-pixel conv (independent weights)
        # so both branches reach the up-sampled resolution.
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x):
        identity = self.upsample(x)
        out = self.leaky_relu(self.subpel_conv(x))
        out = self.igdn(self.conv(out))
        out += identity
        return out
class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # Project the identity with a 1x1 conv only when the channel count
        # changes; otherwise add the input directly.
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None

    def forward(self, x):
        identity = x if self.skip is None else self.skip(x)
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.conv2(out))
        return out + identity
class AttentionBlock(nn.Module):
    """Self attention block.

    Simplified variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N):
        super().__init__()

        class ResidualUnit(nn.Module):
            """Simple residual unit (1x1 -> 3x3 -> 1x1 bottleneck)."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Sequential(
                    conv1x1(N, N // 2),
                    nn.ReLU(inplace=True),
                    conv3x3(N // 2, N // 2),
                    nn.ReLU(inplace=True),
                    conv1x1(N // 2, N),
                )
                self.relu = nn.ReLU(inplace=True)

            def forward(self, x):
                return self.relu(x + self.conv(x))

        # Trunk branch.
        self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
        # Mask branch; its output is squashed with a sigmoid and used to gate
        # the trunk branch.
        self.conv_b = nn.Sequential(
            ResidualUnit(),
            ResidualUnit(),
            ResidualUnit(),
            conv1x1(N, N),
        )

    def forward(self, x):
        a = self.conv_a(x)
        b = self.conv_b(x)
        return x + a * torch.sigmoid(b)
codes/compressai/models/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.priors
import
*
from
.waseda
import
*
from
.ours
import
*
codes/compressai/models/our_utils.py
0 → 100644
View file @
9e459ea3
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.init
as
init
import
torch.nn.functional
as
F
import
warnings
from
compressai.layers
import
*
class AttModule(nn.Module):
    """Pair of independent attention blocks: one used in the forward
    (analysis) direction and one in the reverse (synthesis) direction.

    Args:
        N (int): number of channels for each `AttentionBlock`
    """

    def __init__(self, N):
        super(AttModule, self).__init__()
        self.forw_att = AttentionBlock(N)
        self.back_att = AttentionBlock(N)

    def forward(self, x, rev=False):
        # Select the branch by direction; the two share no weights.
        att = self.back_att if rev else self.forw_att
        return att(x)
class EnhModule(nn.Module):
    """Pair of independent enhancement blocks: one used in the forward
    (analysis) direction and one in the reverse (synthesis) direction.

    Args:
        nf (int): number of feature channels for each `EnhBlock`
    """

    def __init__(self, nf):
        super(EnhModule, self).__init__()
        self.forw_enh = EnhBlock(nf)
        self.back_enh = EnhBlock(nf)

    def forward(self, x, rev=False):
        # Select the branch by direction; the two share no weights.
        enh = self.back_enh if rev else self.forw_enh
        return enh(x)
class EnhBlock(nn.Module):
    """Residual image-enhancement block.

    Expands a 3-channel image to ``nf`` features with a dense block, mixes
    them with a 1x1 -> 3x3 -> 1x1 stack, maps back to 3 channels, and adds
    the (damped) result to the input.

    Args:
        nf (int): number of intermediate feature channels
    """

    def __init__(self, nf):
        super(EnhBlock, self).__init__()
        self.layers = nn.Sequential(
            DenseBlock(3, nf),
            nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Conv2d(nf, nf, kernel_size=1, stride=1, padding=0, bias=True),
            DenseBlock(nf, 3),
        )

    def forward(self, x):
        # Residual correction, damped by 0.2 to keep the block close to
        # identity early in training.
        return x + self.layers(x) * 0.2
class InvComp(nn.Module):
    """Invertible analysis/synthesis transform built from four squeeze levels.

    Each level applies, in order: a 2x `SqueezeLayer` (quadrupling the channel
    count), an `InvertibleConv1x1`, and three `CouplingLayer`s. The first two
    (higher-resolution) levels use 5x5 coupling kernels, the last two 3x3.

    Forward maps a 3-channel image to ``out_nc`` latent channels by averaging
    the replicated channel groups; reverse replicates the latent back to the
    full channel count and runs every operation backwards.

    Args:
        M (int): number of output latent channels
    """

    def __init__(self, M):
        super(InvComp, self).__init__()
        self.in_nc = 3
        self.out_nc = M
        self.operations = nn.ModuleList()

        # One iteration per pyramid level. The append order below exactly
        # reproduces the original hand-unrolled construction, so ModuleList
        # indices (and therefore state_dict keys) are unchanged.
        for kernel_size in (5, 5, 3, 3):
            self.operations.append(SqueezeLayer(2))
            self.in_nc *= 4
            self.operations.append(InvertibleConv1x1(self.in_nc))
            for _ in range(3):
                self.operations.append(
                    CouplingLayer(
                        self.in_nc // 4, 3 * self.in_nc // 4, kernel_size
                    )
                )

    def forward(self, x, rev=False):
        if not rev:
            for op in self.operations:
                x = op.forward(x, False)
            # Collapse the expanded channels down to out_nc by averaging the
            # replicated groups.
            b, c, h, w = x.size()
            x = torch.mean(
                x.view(b, c // self.out_nc, self.out_nc, h, w), dim=1
            )
        else:
            # Replicate the latent to the full channel count, then invert
            # every operation in reverse order.
            times = self.in_nc // self.out_nc
            x = x.repeat(1, times, 1, 1)
            for op in reversed(self.operations):
                x = op.forward(x, True)
        return x
class CouplingLayer(nn.Module):
    """Invertible affine coupling layer.

    The input is split along the channel axis into two parts of sizes
    ``split_len1`` and ``split_len2``. Each part is scaled and shifted by
    small `Bottleneck` networks conditioned on the other part, which makes
    the transform exactly invertible.

    Args:
        split_len1 (int): channels in the first split
        split_len2 (int): channels in the second split
        kernal_size (int): kernel size of the bottleneck sub-networks
            (parameter name kept as-is to preserve the keyword interface)
        clamp (float): bound on the log-scale, scale is in (e^-clamp, e^clamp)
    """

    def __init__(self, split_len1, split_len2, kernal_size, clamp=1.0):
        super(CouplingLayer, self).__init__()
        self.split_len1 = split_len1
        self.split_len2 = split_len2
        self.clamp = clamp

        # G* produce the (clamped) log-scales, H* produce the shifts.
        self.G1 = Bottleneck(self.split_len1, self.split_len2, kernal_size)
        self.G2 = Bottleneck(self.split_len2, self.split_len1, kernal_size)
        self.H1 = Bottleneck(self.split_len1, self.split_len2, kernal_size)
        self.H2 = Bottleneck(self.split_len2, self.split_len1, kernal_size)

    def _scale(self, net, t):
        # exp(clamp * (2*sigmoid(net(t)) - 1)), bounded in (e^-clamp, e^clamp).
        return torch.exp(self.clamp * (torch.sigmoid(net(t)) * 2 - 1))

    def forward(self, x, rev=False):
        x1 = x.narrow(1, 0, self.split_len1)
        x2 = x.narrow(1, self.split_len1, self.split_len2)
        if not rev:
            y1 = x1.mul(self._scale(self.G2, x2)) + self.H2(x2)
            y2 = x2.mul(self._scale(self.G1, y1)) + self.H1(y1)
        else:
            # Exact algebraic inverse of the branch above.
            y2 = (x2 - self.H1(x1)).div(self._scale(self.G1, x1))
            y1 = (x1 - self.H2(y2)).div(self._scale(self.G2, y2))
        return torch.cat((y1, y2), 1)
class Bottleneck(nn.Module):
    """Three-layer conv bottleneck (kxk -> 1x1 -> kxk) with leaky-ReLUs.

    The first two convolutions are Xavier-initialised (scaled by 0.1); the
    last is zero-initialised so the whole block starts as the zero function
    (useful inside coupling layers).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): kernel size of the first and last convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(Bottleneck, self).__init__()
        # "same" padding for odd kernels: P = ((S-1)*W - S + F) / 2 with S=1.
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        initialize_weights_xavier([self.conv1, self.conv2], 0.1)
        initialize_weights(self.conv3, 0)

    def forward(self, x):
        out = self.lrelu(self.conv1(x))
        out = self.lrelu(self.conv2(out))
        return self.conv3(out)
class SqueezeLayer(nn.Module):
    """Invertible space-to-depth rearrangement.

    Forward trades spatial resolution for channels (each ``factor x factor``
    patch becomes ``factor**2`` channels); reverse undoes it exactly. Being a
    pure permutation, its log-Jacobian is zero.

    Args:
        factor (int): squeeze factor (>= 1)
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, reverse=False):
        if reverse:
            return self.unsqueeze2d(input, self.factor)
        return self.squeeze2d(input, self.factor)

    def jacobian(self, x, rev=False):
        # A permutation has unit determinant, so log|det J| = 0.
        return 0

    @staticmethod
    def squeeze2d(input, factor=2):
        assert factor >= 1 and isinstance(factor, int)
        if factor == 1:
            return input
        size = input.size()
        B, C, H, W = size[0], size[1], size[2], size[3]
        assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
        x = input.view(B, C, H // factor, factor, W // factor, factor)
        # Move the intra-patch offsets in front of the channel axis.
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
        return x.view(B, factor * factor * C, H // factor, W // factor)

    @staticmethod
    def unsqueeze2d(input, factor=2):
        assert factor >= 1 and isinstance(factor, int)
        factor2 = factor ** 2
        if factor == 1:
            return input
        size = input.size()
        B, C, H, W = size[0], size[1], size[2], size[3]
        assert C % (factor2) == 0, "{}".format(C)
        x = input.view(B, factor, factor, C // factor2, H, W)
        # Inverse permutation of squeeze2d.
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
        return x.view(B, C // (factor2), H * factor, W * factor)
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution (channel mixing) as in GLOW.

    The kernel is a learnable ``num_channels x num_channels`` matrix,
    initialised to a random rotation (Q factor of a Gaussian matrix), so it
    is orthogonal — and hence trivially invertible — at the start of
    training. The reverse direction convolves with the matrix inverse.

    Args:
        num_channels (int): number of input/output channels
    """

    def __init__(self, num_channels):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        self.w_shape = w_shape

    def get_weight(self, input, reverse):
        """Return the (possibly inverted) kernel reshaped for `F.conv2d`."""
        w_shape = self.w_shape
        if not reverse:
            weight = self.weight
        else:
            # Invert in double precision for numerical stability, then cast
            # back to float.
            weight = torch.inverse(self.weight.double()).float()
        return weight.view(w_shape[0], w_shape[1], 1, 1)

    def forward(self, input, reverse=False):
        # Both directions are the same plain 1x1 convolution; only the kernel
        # differs. (The original had two byte-identical branches here.)
        weight = self.get_weight(input, reverse)
        return F.conv2d(input, weight)
class DenseBlock(nn.Module):
    """Five-layer densely connected convolution block (RDB-style).

    Each 3x3 convolution sees the concatenation of the input and all earlier
    layer outputs. The final convolution is zero-initialised so the block
    starts as the zero function.

    Args:
        channel_in (int): number of input channels
        channel_out (int): number of output channels
        init (str): 'xavier' for Xavier init of the first four convs,
            anything else for Kaiming init
        gc (int): growth channels per intermediate layer (default: 32)
        bias (bool): whether the convolutions use a bias term
    """

    def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
        super(DenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        hidden = [self.conv1, self.conv2, self.conv3, self.conv4]
        if init == 'xavier':
            initialize_weights_xavier(hidden, 0.1)
        else:
            initialize_weights(hidden, 0.1)
        # Zero-init the output conv so the block is initially the zero map.
        initialize_weights(self.conv5, 0)

    def forward(self, x):
        features = [x]
        # Each stage consumes the concatenation of everything before it.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        # Final conv: no activation.
        return self.conv5(torch.cat(features, 1))
def initialize_weights(net_l, scale=1):
    """Kaiming-initialise every Conv2d/Linear in the given module(s).

    Weights are scaled by ``scale`` afterwards (use a small value for
    residual blocks, or 0 to zero a layer); biases are zeroed. BatchNorm2d
    layers get weight 1 and bias 0.

    Args:
        net_l (nn.Module or list[nn.Module]): module(s) to initialise
        scale (float): multiplier applied to the initialised weights
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            # Conv2d and Linear receive identical treatment.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)
def initialize_weights_xavier(net_l, scale=1):
    """Xavier-initialise every Conv2d/Linear in the given module(s).

    Weights are scaled by ``scale`` afterwards (use a small value for
    residual blocks, or 0 to zero a layer); biases are zeroed. BatchNorm2d
    layers get weight 1 and bias 0.

    Args:
        net_l (nn.Module or list[nn.Module]): module(s) to initialise
        scale (float): multiplier applied to the initialised weights
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            # Conv2d and Linear receive identical treatment.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.xavier_normal_(m.weight)
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)
    def __init__(self, in_shape, int_ch, numTraceSamples=0, numSeriesTerms=0,
                 stride=1, coeff=.97, input_nonlin=True,
                 actnorm=True, n_power_iter=5, nonlin="elu", train=False):
        """
        buid invertible bottleneck block
        :param in_shape: shape of the input (channels, height, width)
        :param int_ch: dimension of intermediate layers
        :param stride: 1 if no downsample 2 if downsample
        :param coeff: desired lipschitz constant
        :param input_nonlin: if true applies a nonlinearity on the input
        :param actnorm: if true uses actnorm like GLOW
        :param n_power_iter: number of iterations for spectral normalization
        :param nonlin: the nonlinearity to use
        """
        # NOTE(review): the enclosing `class conv_iresnet_block_simplified`
        # header is not visible in this file as captured; this method appears
        # detached. Verify the class definition exists upstream.
        super(conv_iresnet_block_simplified, self).__init__()
        assert stride in (1, 2)
        self.stride = stride
        # NOTE(review): IRes_Squeeze is not defined in this file — presumably
        # provided by `from compressai.layers import *`; confirm.
        self.squeeze = IRes_Squeeze(stride)
        self.coeff = coeff
        self.numTraceSamples = numTraceSamples
        self.numSeriesTerms = numSeriesTerms
        self.n_power_iter = n_power_iter
        # Map the configuration string to a nonlinearity factory; raises
        # KeyError for unknown names.
        nonlin = {
            "relu": nn.ReLU,
            "elu": nn.ELU,
            "softplus": nn.Softplus,
            "sorting": lambda: MaxMinGroup(group_size=2, axis=1),
        }[nonlin]

        # set shapes for spectral norm conv
        in_ch, h, w = in_shape
        layers = []
        if input_nonlin:
            layers.append(nonlin())

        # Channels grow by stride**2 because of the squeeze applied in forward.
        in_ch = in_ch * stride ** 2
        kernel_size1 = 1
        layers.append(self._wrapper_spectral_norm(
            nn.Conv2d(in_ch, int_ch, kernel_size=kernel_size1, padding=0),
            (in_ch, h, w), kernel_size1))
        layers.append(nonlin())
        kernel_size3 = 1
        layers.append(self._wrapper_spectral_norm(
            nn.Conv2d(int_ch, in_ch, kernel_size=kernel_size3, padding=0),
            (int_ch, h, w), kernel_size3))
        self.bottleneck_block = nn.Sequential(*layers)
        if actnorm:
            # NOTE(review): ActNorm2D is not defined in this file — confirm
            # its import source.
            self.actnorm = ActNorm2D(in_ch, train=train)
        else:
            self.actnorm = None
    def forward(self, x, rev=False, ignore_logdet=False, maxIter=25):
        # Forward: y = x + F(x) plus log-det estimate; reverse: fixed-point
        # iteration to invert the residual map.
        if not rev:
            """ bijective or injective block forward """
            if self.stride == 2:
                x = self.squeeze.forward(x)

            if self.actnorm is not None:
                x, an_logdet = self.actnorm(x)
            else:
                an_logdet = 0.0

            Fx = self.bottleneck_block(x)
            # Estimate the log-determinant via a truncated power series with
            # stochastic trace estimation, unless disabled.
            if (self.numTraceSamples == 0 and self.numSeriesTerms == 0) or ignore_logdet:
                trace = torch.tensor(0.)
            else:
                # NOTE(review): power_series_matrix_logarithm_trace is not
                # defined in this file — confirm its import source.
                trace = power_series_matrix_logarithm_trace(
                    Fx, x, self.numSeriesTerms, self.numTraceSamples)

            # Residual connection makes the block invertible (for Lipschitz
            # constant < 1 of the bottleneck).
            y = Fx + x
            return y, trace + an_logdet
        else:
            # Fixed-point iteration x_{k+1} = y - F(x_k), started from x_0 = y.
            y = x
            for iter_index in range(maxIter):
                summand = self.bottleneck_block(x)
                x = y - summand

            if self.actnorm is not None:
                x = self.actnorm.inverse(x)
            if self.stride == 2:
                x = self.squeeze.inverse(x)
            return x
    def _wrapper_spectral_norm(self, layer, shapes, kernel_size):
        # Wrap a conv layer with spectral normalisation targeting Lipschitz
        # constant self.coeff.
        # NOTE(review): spectral_norm_fc / spectral_norm_conv are not defined
        # in this file — confirm their import source.
        if kernel_size == 1:
            # use spectral norm fc, because bound are tight for 1x1 convolutions
            return spectral_norm_fc(layer, self.coeff,
                                    n_power_iterations=self.n_power_iter)
        else:
            # use spectral norm based on conv, because bound not tight
            return spectral_norm_conv(layer, self.coeff, shapes,
                                      n_power_iterations=self.n_power_iter)
\ No newline at end of file
codes/compressai/models/ours.py
0 → 100644
View file @
9e459ea3
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.init
as
init
import
torch.nn.functional
as
F
import
warnings
from
.priors
import
JointAutoregressiveHierarchicalPriors
from
.our_utils
import
*
from
compressai.layers
import
*
from
.waseda
import
Cheng2020Anchor
class InvCompress(Cheng2020Anchor):
    """Learned image compression model with an invertible analysis/synthesis
    transform.

    Replaces the parent's feed-forward `g_a`/`g_s` transforms with a shared
    invertible pipeline (enhancement -> invertible transform -> attention),
    run forward for analysis and in reverse for synthesis. The hyperprior and
    autoregressive entropy model (`h_a`, `h_s`, `entropy_bottleneck`,
    `gaussian_conditional`, `context_prediction`, `entropy_parameters`,
    `_compress_ar`, `_decompress_ar`, `M`) are inherited — presumably from
    JointAutoregressiveHierarchicalPriors via Cheng2020Anchor; verify in
    priors.py.

    Args:
        N (int): number of channels (default: 192)
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N)
        # The parent's transforms are discarded; analysis/synthesis go
        # through g_a_func / g_s_func instead.
        self.g_a = None
        self.g_s = None
        self.enh = EnhModule(64)
        self.inv = InvComp(M=N)
        self.attention = AttModule(N)

    def g_a_func(self, x):
        # Analysis transform: enhancement -> invertible transform -> attention.
        x = self.enh(x)
        x = self.inv(x)
        x = self.attention(x)
        return x

    def g_s_func(self, x):
        # Synthesis transform: exact reverse order of g_a_func with rev=True.
        x = self.attention(x, rev=True)
        x = self.inv(x, rev=True)
        x = self.enh(x, rev=True)
        return x

    def forward(self, x):
        """Training/validation forward pass.

        Returns a dict with the reconstruction ``x_hat`` and the ``y``/``z``
        likelihoods used for the rate term of the loss.
        """
        y = self.g_a_func(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        # Additive-noise quantisation proxy during training, hard rounding
        # at evaluation time.
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s_func(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Infer N from the width of the first hyper-analysis convolution.
        N = state_dict["h_a.0.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Compress a batch of images to entropy-coded strings.

        Returns a dict with ``strings`` ([y_strings, z_strings]), the hyper
        latent spatial ``shape``, and the unquantised latent ``y``.
        """
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        y = self.g_a_func(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        # Decode z again so encoder and decoder condition on identical z_hat.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Pad so the context-prediction window is valid at the borders.
        y_hat = F.pad(y, (padding, padding, padding, padding))

        y_strings = []
        # The autoregressive coder runs per-image, sequentially.
        for i in range(y.size(0)):
            string = self._compress_ar(
                y_hat[i : i + 1],
                params[i : i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
            y_strings.append(string)

        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:], "y": y}

    def decompress(self, strings, shape):
        """Reconstruct images from entropy-coded strings.

        Args:
            strings (list): [y_strings, z_strings] as produced by `compress`
            shape: spatial size of the hyper latent z

        Returns a dict with the reconstruction ``x_hat`` clamped to [0, 1].
        """
        assert isinstance(strings, list) and len(strings) == 2

        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Decode into a zero-initialised, padded buffer that the AR decoder
        # fills in raster order.
        y_hat = torch.zeros(
            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device,
        )

        for i, y_string in enumerate(strings[0]):
            self._decompress_ar(
                y_string,
                y_hat[i : i + 1],
                params[i : i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )

        # Negative padding crops the border back off.
        y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
        x_hat = self.g_s_func(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
codes/compressai/models/priors.py
0 → 100644
View file @
9e459ea3
This diff is collapsed.
Click to expand it.
codes/compressai/models/utils.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
def find_named_module(module, query):
    """Helper function to find a named module. Returns a `nn.Module` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the module name to find

    Returns:
        nn.Module or None
    """
    for name, mod in module.named_modules():
        if name == query:
            return mod
    return None
def find_named_buffer(module, query):
    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the buffer name to find

    Returns:
        torch.Tensor or None
    """
    for name, buf in module.named_buffers():
        if name == query:
            return buf
    return None
def _update_registered_buffer(
    module,
    buffer_name,
    state_dict_key,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Resize or register a single buffer to match a state_dict tensor.

    Args:
        module (nn.Module): module owning the buffer
        buffer_name (str): buffer attribute name on `module`
        state_dict_key (str): key of the reference tensor in `state_dict`
        state_dict (dict): the state dict
        policy (str): 'resize_if_empty', 'resize' or 'register'
        dtype: dtype used when a new buffer is registered
    """
    new_size = state_dict[state_dict_key].size()
    registered_buf = find_named_buffer(module, buffer_name)

    # Guard-clause style: handle 'register' first, then validate the policy,
    # then the resize variants.
    if policy == "register":
        if registered_buf is not None:
            raise RuntimeError(f'buffer "{buffer_name}" was already registered')
        module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
        return

    if policy not in ("resize_if_empty", "resize"):
        raise ValueError(f'Invalid policy "{policy}"')

    if registered_buf is None:
        raise RuntimeError(f'buffer "{buffer_name}" was not registered')

    if policy == "resize" or registered_buf.numel() == 0:
        registered_buf.resize_(new_size)
def update_registered_buffers(
    module,
    module_name,
    buffer_names,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Update the registered buffers in a module according to the tensors sized
    in a state_dict.

    (There's no way in torch to directly load a buffer with a dynamic size)

    Args:
        module (nn.Module): the module
        module_name (str): module name in the state dict
        buffer_names (list(str)): list of the buffer names to resize in the module
        state_dict (dict): the state dict
        policy (str): Update policy, choose from
            ('resize_if_empty', 'resize', 'register')
        dtype (dtype): Type of buffer to be registered (when policy is 'register')
    """
    valid_buffer_names = {n for n, _ in module.named_buffers()}

    # Validate every name up front so nothing is partially updated.
    for buffer_name in buffer_names:
        if buffer_name not in valid_buffer_names:
            raise ValueError(f'Invalid buffer name "{buffer_name}"')

    for buffer_name in buffer_names:
        _update_registered_buffer(
            module,
            buffer_name,
            f"{module_name}.{buffer_name}",
            state_dict,
            policy,
            dtype,
        )
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Return a strided `nn.Conv2d` with "same"-style padding.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): kernel size (default: 5)
        stride (int): stride (default: 2)
    """
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Return a `nn.ConvTranspose2d` that exactly inverts the spatial
    down-sampling of `conv` (output_padding = stride - 1).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): kernel size (default: 5)
        stride (int): stride / up-sampling factor (default: 2)
    """
    padding = kernel_size // 2
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=padding,
    )
codes/compressai/models/waseda.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch.nn
as
nn
from
compressai.layers
import
(
AttentionBlock
,
ResidualBlock
,
ResidualBlockUpsample
,
ResidualBlockWithStride
,
conv3x3
,
subpel_conv3x3
,
)
from
.priors
import
JointAutoregressiveHierarchicalPriors
class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
    """Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
    convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, M=N, **kwargs)
        # NOTE: the module ordering below is load-bearing — it determines the
        # state_dict keys (e.g. "g_a.0.conv1.weight" used by from_state_dict),
        # so it must not be reordered.

        # g_a: analysis transform, four stride-2 stages (16x down-sampling).
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
        )

        # h_a: hyper-analysis, a further 4x down-sampling of the latent.
        self.h_a = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
        )

        # h_s: hyper-synthesis, 4x up-sampling; outputs 2N channels
        # (presumably scale/mean parameters for the conditional entropy
        # model defined in the parent class — verify in priors.py).
        self.h_s = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N, N, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N * 3 // 2),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N * 3 // 2, N * 2),
        )

        # g_s: synthesis transform, mirrors g_a with sub-pixel up-sampling
        # back to 3 image channels.
        self.g_s = nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Infer N from the output width of the first analysis convolution.
        N = state_dict["g_a.0.conv1.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net
class Cheng2020Attention(Cheng2020Anchor):
    """Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
    and sub-pixel convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, **kwargs)

        # Same analysis transform as the anchor, with AttentionBlocks inserted
        # after the second downsampling stage and at the output.
        # NOTE: module order is load-bearing for checkpoint state_dict keys.
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
            AttentionBlock(N),
        )

        # Mirror synthesis transform; h_a and h_s are inherited from the anchor.
        self.g_s = nn.Sequential(
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )
codes/compressai/ops/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.bound_ops
import
LowerBound
from
.ops
import
ste_round
from
.parametrizers
import
NonNegativeParametrizer
# Public names re-exported from `compressai.ops`.
__all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"]
codes/compressai/ops/bound_ops.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, input_, bound):
        # Keep both operands: the backward pass needs them for its mask.
        ctx.save_for_backward(input_, bound)
        return torch.max(input_, bound)

    @staticmethod
    def backward(ctx, grad_output):
        input_, bound = ctx.saved_tensors
        # Gradients pass through when the input is above the bound, or when
        # the gradient would push a clamped value upward (grad < 0).
        above = input_ >= bound
        pushing_up = grad_output < 0
        mask = above | pushing_up
        grad_input = mask.type(grad_output.dtype) * grad_output
        return grad_input, None  # the bound itself receives no gradient


class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    def __init__(self, bound):
        super().__init__()
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        # Custom autograd functions are not scriptable; under TorchScript
        # fall back to a plain max, which has the same forward value.
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)
codes/compressai/ops/ops.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
def ste_round(x):
    """
    Rounding with non-zero gradients. Gradients are approximated by replacing
    the derivative by the identity function.

    Used in `"Lossy Image Compression with Compressive Autoencoders"
    <https://arxiv.org/abs/1703.00395>`_

    .. note::
        Implemented with the pytorch `detach()` reparametrization trick:
        `x_round = x_round - x.detach() + x`
    """
    # The rounding residual is detached, so the backward pass sees identity.
    return (torch.round(x) - x).detach() + x
codes/compressai/ops/parametrizers.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
from
.bound_ops
import
LowerBound
class NonNegativeParametrizer(nn.Module):
    """
    Non negative reparametrization.

    Used for stability during training: values are stored in a square-root
    domain and squared on the forward pass, lower-bounded so the squared
    output stays at or above `minimum`.
    """

    def __init__(self, minimum=0, reparam_offset=2**-18):
        super().__init__()

        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        # Small positive pedestal keeps the sqrt/square round-trip away from 0.
        pedestal = self.reparam_offset**2
        self.register_buffer("pedestal", torch.Tensor([pedestal]))
        bound = (self.minimum + self.reparam_offset**2) ** 0.5
        self.lower_bound = LowerBound(bound)

    def init(self, x):
        """Map an initial value to its reparametrized (sqrt-domain) form."""
        return torch.sqrt(torch.max(x + self.pedestal, self.pedestal))

    def forward(self, x):
        bounded = self.lower_bound(x)
        return bounded**2 - self.pedestal
codes/compressai/transforms/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.transforms
import
*
codes/compressai/transforms/functional.py
0 → 100644
View file @
9e459ea3
from
typing
import
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
YCBCR_WEIGHTS = {
    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}


def _check_input_tensor(tensor: Tensor) -> None:
    """Raise ValueError unless `tensor` is a floating (Nx3xHxW)/(3xHxW) tensor."""
    is_valid = (
        isinstance(tensor, Tensor)
        and tensor.is_floating_point()
        and len(tensor.size()) in (3, 4)
        and tensor.size(-3) == 3
    )
    if not is_valid:
        raise ValueError(
            "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
        )


def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """RGB to YCbCr conversion for torch Tensor.

    Using ITU-R BT.709 coefficients.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor
    """
    _check_input_tensor(rgb)

    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    r, g, b = rgb.chunk(3, -3)
    luma = Kr * r + Kg * g + Kb * b
    # Chroma components are offset to the [0, 1] range.
    blue_diff = 0.5 * (b - luma) / (1 - Kb) + 0.5
    red_diff = 0.5 * (r - luma) / (1 - Kr) + 0.5
    return torch.cat((luma, blue_diff, red_diff), dim=-3)


def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """YCbCr to RGB conversion for torch Tensor.

    Using ITU-R BT.709 coefficients.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

    Returns:
        rgb (torch.Tensor): converted tensor
    """
    _check_input_tensor(ycbcr)

    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    y, cb, cr = ycbcr.chunk(3, -3)
    r = y + (2 - 2 * Kr) * (cr - 0.5)
    b = y + (2 - 2 * Kb) * (cb - 0.5)
    # Green is recovered from luma and the other two primaries.
    g = (y - Kr * r - Kb * b) / Kg
    return torch.cat((r, g, b), dim=-3)
def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert a 444 tensor to a 420 representation.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
            of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
    """
    if mode not in ("avg_pool",):
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    if mode == "avg_pool":

        def _downsample(component):
            # Halve both spatial dimensions via 2x2 mean pooling.
            return F.avg_pool2d(component, kernel_size=2, stride=2)

    # Accept either a packed (Nx3xHxW) tensor or an already-split tuple.
    if isinstance(yuv, torch.Tensor):
        y, u, v = yuv.chunk(3, 1)
    else:
        y, u, v = yuv

    return y, _downsample(u), _downsample(v)
def yuv_420_to_444(
    yuv: Tuple[Tensor, Tensor, Tensor],
    mode: str = "bilinear",
    return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Convert a 420 input to a 444 representation.

    Args:
        yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
            (Nx1xHxW) format
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'nearest'`` Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Returns:
        (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
            444
    """
    if len(yuv) != 3 or not all(isinstance(c, torch.Tensor) for c in yuv):
        raise ValueError("Expected a tuple of 3 torch tensors")

    if mode not in ("bilinear", "nearest"):
        raise ValueError(f'Invalid upsampling mode "{mode}".')

    if mode in ("bilinear", "nearest"):

        def _upsample(component):
            # Double both spatial dimensions of a chroma plane.
            return F.interpolate(
                component, scale_factor=2, mode=mode, align_corners=False
            )

    y, u, v = yuv
    u, v = _upsample(u), _upsample(v)
    if return_tuple:
        return y, u, v
    return torch.cat((y, u, v), dim=1)
codes/compressai/transforms/transforms.py
0 → 100644
View file @
9e459ea3
from
.
import
functional
as
F_transforms
# Public transform classes exported by this module.
__all__ = [
    "RGB2YCbCr",
    "YCbCr2RGB",
    "YUV444To420",
    "YUV420To444",
]
class RGB2YCbCr:
    """Convert a RGB tensor to YCbCr.

    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, rgb):
        """Apply the conversion.

        Args:
            rgb (torch.Tensor): 3D or 4D floating point RGB tensor

        Returns:
            ycbcr (torch.Tensor): converted tensor
        """
        return F_transforms.rgb2ycbcr(rgb)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YCbCr2RGB:
    """Convert a YCbCr tensor to RGB.

    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, ycbcr):
        """Apply the conversion.

        Args:
            ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

        Returns:
            rgb (torch.Tensor): converted tensor
        """
        return F_transforms.ycbcr2rgb(ycbcr)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YUV444To420:
    """Convert a YUV 444 tensor to a 420 representation.

    Args:
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Example:
        >>> x = torch.rand(1, 3, 32, 32)
        >>> y, u, v = YUV444To420()(x)
        >>> y.size()  # 1, 1, 32, 32
        >>> u.size()  # 1, 1, 16, 16
    """

    def __init__(self, mode: str = "avg_pool"):
        self.mode = str(mode)

    def __call__(self, yuv):
        """Apply the downsampling.

        Args:
            yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                444 input to be downsampled. Takes either a (Nx3xHxW) tensor
                or a tuple of 3 (Nx1xHxW) tensors.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
        """
        return F_transforms.yuv_444_to_420(yuv, mode=self.mode)

    def __repr__(self):
        return f"{type(self).__name__}()"
class YUV420To444:
    """Convert a YUV 420 input to a 444 representation.

    Args:
        mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
            Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Example:
        >>> y = torch.rand(1, 1, 32, 32)
        >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
        >>> x = YUV420To444()((y, u, v))
        >>> x.size()  # 1, 3, 32, 32
    """

    def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
        self.mode = str(mode)
        self.return_tuple = bool(return_tuple)

    def __call__(self, yuv):
        """Apply the upsampling.

        Args:
            yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
                (Nx1xHxW) format

        Returns:
            (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                Converted 444
        """
        # Fix: `mode` was stored by __init__ but never forwarded, so a
        # non-default upsampling mode was silently ignored.
        return F_transforms.yuv_420_to_444(
            yuv, mode=self.mode, return_tuple=self.return_tuple
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
codes/compressai/utils/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/bench/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/bench/__main__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collect performance metrics of published traditional or end-to-end image
codecs.
"""
import
argparse
import
json
import
multiprocessing
as
mp
import
os
import
sys
from
collections
import
defaultdict
from
itertools
import
starmap
from
typing
import
List
from
.codecs
import
AV1
,
BPG
,
HM
,
JPEG
,
JPEG2000
,
TFCI
,
VTM
,
Codec
,
WebP
# from torchvision.datasets.folder
# from torchvision.datasets.folder
# File extensions recognized when scanning a dataset directory for images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)

# Codec classes selectable from the command line (one subparser per class).
codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]
# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
    """Invoke ``codec.run(*args)`` and pair the metrics with quality index `i`."""
    metrics = codec.run(*args)
    return i, metrics
def collect(codec: Codec, dataset: str, qualities: List[int], num_jobs: int = 1):
    """Run `codec` on every image of `dataset` at every quality value.

    Args:
        codec: codec instance exposing ``run(filepath, quality)``.
        dataset: directory scanned (non-recursively) for image files.
        qualities: quality values; results are indexed by their position.
        num_jobs: number of worker processes; 1 runs everything in-process.

    Returns:
        dict mapping metric name -> list of per-quality averages (one entry
        per element of `qualities`).
    """
    filepaths = [
        os.path.join(dataset, f)
        for f in os.listdir(dataset)
        if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
    ]

    if len(filepaths) == 0:
        print("No images found in the dataset directory")
        sys.exit(1)

    args = [
        (codec, i, f, q) for i, q in enumerate(qualities) for f in filepaths
    ]

    if num_jobs > 1:
        # Fix: the original created the Pool unconditionally (even for the
        # empty-dataset exit) and never closed it; the context manager
        # terminates the worker processes deterministically.
        with mp.Pool(num_jobs) as pool:
            rv = pool.starmap(func, args)
    else:
        rv = list(starmap(func, args))

    # Sum each metric per quality index...
    results = [defaultdict(float) for _ in range(len(qualities))]
    for i, metrics in rv:
        for k, v in metrics.items():
            results[i][k] += v

    # ...then average over the number of images.
    for i, _ in enumerate(results):
        for k, v in results[i].items():
            results[i][k] = v / len(filepaths)

    # list of dict -> dict of list
    out = defaultdict(list)
    for r in results:
        for k, v in r.items():
            out[k].append(v)
    return out
def setup_args():
    """Build the top-level parser plus the codec-selection subparsers."""
    parser = argparse.ArgumentParser(description="Collect codec metrics.")
    subparsers = parser.add_subparsers(dest="codec", help="Select codec")
    # A codec subcommand must always be given.
    subparsers.required = True
    return parser, subparsers
def setup_common_args(parser):
    """Register arguments shared by every codec subcommand on `parser`."""
    parser.add_argument("dataset", type=str)
    parser.add_argument(
        "-j",
        "--num-jobs",
        metavar="N",
        type=int,
        default=1,
        help="Number of parallel jobs (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        dest="qualities",
        metavar="Q",
        nargs="*",
        type=int,
        default=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90],
        help="quality parameter (default: %(default)s)",
    )
    # [3,5,10,15,17,20,22,25,27,30,32,35,37,40,42,45,47,50],
    # new added: basename of the JSON results file written by main()
    parser.add_argument(
        "--name",
        dest="name",
        type=str,
        default="ans",
        help="name for json",
    )
def main(argv):
    """Parse arguments, benchmark the selected codec, and dump JSON results."""
    import time

    start = time.time()
    parser, subparsers = setup_args()
    # One subcommand per codec class, each with the shared + codec options.
    for c in codecs:
        cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
        setup_common_args(cparser)
        c.setup_args(cparser)
    args = parser.parse_args(argv)

    codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
    codec = codec_cls(args)
    results = collect(codec, args.dataset, args.qualities, args.num_jobs)

    output = {
        "name": codec.name,
        "description": codec.description,
        "results": results,
    }
    print(json.dumps(output, indent=2))
    end = time.time()
    print('total time:', end - start)

    # NOTE(review): machine-specific hardcoded output directory — consider
    # making it configurable via an argument or environment variable.
    output_dir = '/home/felix/disk2/compressai_v2/codes/results/log'
    # Fix: create the directory if missing instead of crashing on open().
    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, args.name + '.json')
    with open(output_json_path, 'w') as f:
        json.dump(output, f, indent=2)


if __name__ == "__main__":
    main(sys.argv[1:])
codes/compressai/utils/bench/codecs.py
0 → 100644
View file @
9e459ea3
This diff is collapsed.
Click to expand it.
codes/compressai/utils/eval_model/__init__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
codes/compressai/utils/eval_model/__main__.py
0 → 100644
View file @
9e459ea3
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate an end-to-end compression model on an image dataset.
"""
import
argparse
import
json
import
math
import
os
import
sys
import
time
from
collections
import
defaultdict
from
typing
import
List
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
pytorch_msssim
import
ms_ssim
from
torchvision
import
transforms
import
compressai
from
compressai.zoo
import
models
as
pretrained_models
from
compressai.zoo.image
import
model_architectures
as
architectures
# from torchvision.datasets.folder
# from torchvision.datasets.folder
# File extensions recognized when scanning a dataset directory for images.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)


def collect_images(rootpath: str) -> List[str]:
    """Return the paths of all image files directly under `rootpath`."""
    image_files = []
    for entry in os.listdir(rootpath):
        ext = os.path.splitext(entry)[-1].lower()
        if ext in IMG_EXTENSIONS:
            image_files.append(os.path.join(rootpath, entry))
    return image_files
def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
    """Peak signal-to-noise ratio between `a` and `b` in dB.

    Assumes inputs in the [0, 1] range (peak value 1, so PSNR reduces to
    ``-10 * log10(mse)``).
    """
    mse = F.mse_loss(a, b).item()
    if mse == 0:
        # Fix: identical inputs previously raised a math domain error
        # (log10 of zero); infinite PSNR is the conventional result.
        return float("inf")
    return -10 * math.log10(mse)
def read_image(filepath: str) -> torch.Tensor:
    """Load an image file as a float RGB tensor in [0, 1]."""
    assert os.path.isfile(filepath)
    img = Image.open(filepath).convert("RGB")
    to_tensor = transforms.ToTensor()
    return to_tensor(img)
@torch.no_grad()
def inference(model, x, savedir="", idx=1):
    """Compress and decompress one image through `model`, reporting metrics.

    Args:
        model: model exposing ``compress``/``decompress``.
        x: (3xHxW) image tensor in [0, 1].
        savedir: if non-empty, save the reconstruction there as a PNG.
        idx: image index used in the saved filename.

    Returns:
        dict with ``psnr``, ``ms-ssim``, ``bpp``, ``encoding_time`` and
        ``decoding_time`` entries.
    """
    x = x.unsqueeze(0)  # add batch dimension
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2
    # Pad symmetrically up to the next multiple of p.
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )

    start = time.time()
    out_enc = model.compress(x_padded)
    enc_time = time.time() - start

    start = time.time()
    out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
    dec_time = time.time() - start

    # Negative padding crops the reconstruction back to the original size.
    out_dec["x_hat"] = F.pad(
        out_dec["x_hat"],
        (-padding_left, -padding_right, -padding_top, -padding_bottom),
    )

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels

    # Fix: compute the (expensive) quality metrics once; the original
    # recomputed both psnr and ms-ssim in the return dict.
    cur_psnr = psnr(x, out_dec["x_hat"])
    cur_ssim = ms_ssim(x, out_dec["x_hat"], data_range=1.0).item()

    if savedir != "":
        os.makedirs(savedir, exist_ok=True)
        cur_img = transforms.ToPILImage()(out_dec["x_hat"][0])
        cur_img.save(
            os.path.join(
                savedir,
                '{:02d}'.format(idx)
                + "_" + '{:.2f}'.format(cur_psnr)
                + "_" + '{:.3f}'.format(bpp)
                + "_" + '{:.3f}'.format(cur_ssim)
                + ".png",
            )
        )

    return {
        "psnr": cur_psnr,
        "ms-ssim": cur_ssim,
        "bpp": bpp,
        "encoding_time": enc_time,
        "decoding_time": dec_time,
    }
@torch.no_grad()
def inference_entropy_estimation(model, x):
    """Evaluate one image using the model's entropy estimates (no real coding)."""
    x = x.unsqueeze(0)  # add batch dimension

    start = time.time()
    out_net = model.forward(x)
    elapsed_time = time.time() - start

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    # Bits-per-pixel from the likelihoods: -log2(p) summed over all latents.
    bpp = sum(
        (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
        for likelihoods in out_net["likelihoods"].values()
    )

    return {
        "psnr": psnr(x, out_net["x_hat"]),
        "bpp": bpp.item(),
        # broad estimation: split the single forward-pass time evenly
        "encoding_time": elapsed_time / 2.0,
        "decoding_time": elapsed_time / 2.0,
    }
def load_pretrained(model: str, metric: str, quality: int) -> nn.Module:
    """Fetch a pretrained zoo model and switch it to eval mode."""
    net = pretrained_models[model](
        quality=quality, metric=metric, pretrained=True
    )
    return net.eval()
def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    """Build architecture `arch` from a checkpoint file, in eval mode."""
    # Fix: map_location="cpu" lets GPU-trained checkpoints load on CPU-only
    # hosts; the caller moves the model to CUDA afterwards when requested.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return architectures[arch].from_state_dict(state_dict).eval()
def eval_model(model, filepaths, entropy_estimation=False, half=False, savedir=""):
    """Evaluate `model` on every image in `filepaths`, averaging the metrics.

    Args:
        model: compression model (already on its target device).
        filepaths: image paths; evaluated in sorted order.
        entropy_estimation: use forward-pass likelihoods instead of actual
            entropy coding.
        half: convert model and inputs to fp16 (actual-coding path only).
        savedir: if non-empty, reconstructed images are saved there.

    Returns:
        dict mapping metric name -> average value over all images.
    """
    device = next(model.parameters()).device
    metrics = defaultdict(float)
    for idx, f in enumerate(sorted(filepaths)):
        x = read_image(f).to(device)
        if not entropy_estimation:
            print('evaluating index', idx)
            if half:
                model = model.half()
                x = x.half()
            rv = inference(model, x, savedir, idx)
        else:
            rv = inference_entropy_estimation(model, x)
        print('bpp', rv['bpp'])
        print('psnr', rv['psnr'])
        # Fix: the entropy-estimation path does not report "ms-ssim"; the
        # unconditional lookup raised KeyError there.
        if 'ms-ssim' in rv:
            print('ms-ssim', rv['ms-ssim'])
        print()
        for k, v in rv.items():
            metrics[k] += v
    for k, v in metrics.items():
        metrics[k] = v / len(filepaths)
    return metrics
def setup_args():
    """Build the argument parser: shared options plus a `pretrained` and a
    `checkpoint` subcommand for the two model sources."""
    parent_parser = argparse.ArgumentParser(
        add_help=False,
    )

    # Common options.
    parent_parser.add_argument("dataset", type=str, help="dataset path")
    parent_parser.add_argument(
        "-a",
        "--arch",
        type=str,
        choices=pretrained_models.keys(),
        help="model architecture",
        required=True,
    )
    parent_parser.add_argument(
        "-c",
        "--entropy-coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="entropy coder (default: %(default)s)",
    )
    parent_parser.add_argument(
        "--cuda",
        action="store_true",
        help="enable CUDA",
    )
    parent_parser.add_argument(
        "--half",
        action="store_true",
        help="convert model to half floating point (fp16)",
    )
    parent_parser.add_argument(
        "--entropy-estimation",
        action="store_true",
        help="use evaluated entropy estimation (no entropy coding)",
    )
    parent_parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="verbose mode",
    )
    # Directory where reconstructed images are saved ("" disables saving).
    parent_parser.add_argument(
        "-s",
        "--savedir",
        type=str,
        default="",
    )
    parent_parser.add_argument(
        "--gpu_id", type=int, default=0, help="GPU ID"
    )

    parser = argparse.ArgumentParser(
        description="Evaluate a model on an image dataset.", add_help=True
    )
    subparsers = parser.add_subparsers(help="model source", dest="source")
    #, required=True
    # )

    # Options for pretrained models
    pretrained_parser = subparsers.add_parser("pretrained", parents=[parent_parser])
    pretrained_parser.add_argument(
        "-m",
        "--metric",
        type=str,
        choices=["mse", "ms-ssim"],
        default="mse",
        help="metric trained against (default: %(default)s)",
    )
    pretrained_parser.add_argument(
        "-q",
        "--quality",
        dest="qualities",
        nargs="+",
        type=int,
        default=(1,),
    )

    checkpoint_parser = subparsers.add_parser("checkpoint", parents=[parent_parser])
    # checkpoint_parser.add_argument(
    #     "-p",
    #     "--path",
    #     dest="paths",
    #     type=str,
    #     nargs="*",
    #     required=True,
    #     help="checkpoint path",
    # )
    # The experiment name is resolved to a checkpoint path in main().
    checkpoint_parser.add_argument(
        "-exp", "--experiment", type=str, required=True, help="Experiment name"
    )
    return parser
def main(argv):
    """Entry point: evaluate one or more model runs and print JSON results."""
    args = setup_args().parse_args(argv)

    # Restrict visible GPUs before any CUDA initialization.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # NOTE(review): `device` is computed but never used below — the model is
    # moved with an explicit .to("cuda") instead; confirm intent.
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    filepaths = collect_images(args.dataset)
    if len(filepaths) == 0:
        print("No images found in directory.")
        sys.exit(1)

    compressai.set_entropy_coder(args.entropy_coder)

    if args.source == "pretrained":
        # One run per requested quality level.
        runs = sorted(args.qualities)
        opts = (args.arch, args.metric)
        load_func = load_pretrained
        log_fmt = "\rEvaluating {0} | {run:d}"
    elif args.source == "checkpoint":
        # runs = args.paths
        # Assumes ../experiments/<exp>/checkpoint_updated exists and contains
        # exactly one checkpoint file — TODO confirm; an empty dir raises here.
        checkpoint_updated_dir = os.path.join(
            '../experiments', args.experiment, 'checkpoint_updated'
        )
        checkpoint_updated = os.path.join(
            checkpoint_updated_dir, os.listdir(checkpoint_updated_dir)[0]
        )
        runs = [checkpoint_updated]
        opts = (args.arch,)
        load_func = load_checkpoint
        log_fmt = "\rEvaluating {run:s}"

    results = defaultdict(list)
    for run in runs:
        if args.verbose:
            sys.stderr.write(log_fmt.format(*opts, run=run))
            sys.stderr.flush()
        model = load_func(*opts, run)
        if args.cuda and torch.cuda.is_available():
            model = model.to("cuda")
        metrics = eval_model(
            model, filepaths, args.entropy_estimation, args.half, args.savedir
        )
        # Accumulate per-run metric values into parallel lists.
        for k, v in metrics.items():
            results[k].append(v)

    if args.verbose:
        sys.stderr.write("\n")
        sys.stderr.flush()

    description = (
        "entropy estimation" if args.entropy_estimation else args.entropy_coder
    )
    output = {
        "name": args.arch,
        "description": f"Inference ({description})",
        "results": results,
    }
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main(sys.argv[1:])
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment