# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath("../compressai/"))
# -- Project information -----------------------------------------------------
project = "compressai"
copyright = "2020, InterDigital Communications, Inc."
author = "InterDigital Communications, Inc."
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
]
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
compressai.datasets
===================

.. currentmodule:: compressai.datasets

ImageFolder
-----------
.. autoclass:: ImageFolder
   :members:

compressai.entropy_models
=========================

.. currentmodule:: compressai.entropy_models

EntropyBottleneck
-----------------
.. autoclass:: EntropyBottleneck

GaussianConditional
-------------------
.. autoclass:: GaussianConditional
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Based on https://github.com/facebookresearch/ParlAI/tree/c06c40603f45918f58cb09122fa8c74dd4047057/docs/source
import importlib
import io

from pathlib import Path

import compressai.utils


def get_utils():
    rootdir = Path(compressai.utils.__file__).parent
    for d in rootdir.iterdir():
        if d.is_dir() and (d / "__main__.py").is_file():
            yield d


def main():
    fout = open("cli_usage.inc", "w")

    for p in get_utils():
        try:
            m = importlib.import_module(f"compressai.utils.{p.name}.__main__")
        except ImportError:
            continue

        if not hasattr(m, "setup_args"):
            continue

        # section title for the generated RST include
        fout.write(p.name)
        fout.write("\n")
        fout.write("-" * len(p.name))
        fout.write("\n")

        doc = m.__doc__
        if doc:
            fout.write(doc)
            fout.write("\n")

        fout.write(".. code-block:: text\n\n")

        # capture the argparse help text and indent it as a literal block
        capture = io.StringIO()
        parser = m.setup_args()
        if isinstance(parser, tuple):
            parser = parser[0]
        parser.prog = f"python -m compressai.utils.{p.name}"
        parser.print_help(capture)
        for line in capture.getvalue().split("\n"):
            fout.write(f"\t{line}\n")
        fout.write("\n\n")

    fout.close()


if __name__ == "__main__":
    main()
CompressAI
==========
CompressAI (*compress-ay*) is a PyTorch library and evaluation platform for
end-to-end compression research.
.. toctree::
   :hidden:

   self

.. toctree::
   :maxdepth: 1
   :caption: Tutorials

   tutorial_intro
   tutorial_installation
   tutorial_train

.. toctree::
   :maxdepth: 1
   :caption: Library API

   compressai
   ans
   datasets
   entropy_models
   layers
   models
   ops
   transforms

.. toctree::
   :maxdepth: 2
   :caption: Model Zoo

   zoo

.. toctree::
   :maxdepth: 2
   :caption: Utils

   cli_usage
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
compressai.layers
=================
.. currentmodule:: compressai.layers
MaskedConv2d
------------
.. autoclass:: MaskedConv2d
GDN
---
.. autoclass:: GDN
GDN1
----
.. autoclass:: GDN1
ResidualBlock
-------------
.. autoclass:: ResidualBlock
ResidualBlockWithStride
-----------------------
.. autoclass:: ResidualBlockWithStride
ResidualBlockUpsample
---------------------
.. autoclass:: ResidualBlockUpsample
AttentionBlock
--------------
.. autoclass:: AttentionBlock
compressai.models
=================
.. currentmodule:: compressai.models
CompressionModel
----------------
.. autoclass:: CompressionModel
   :members:

FactorizedPrior
---------------
.. autoclass:: FactorizedPrior
   :members:

ScaleHyperprior
---------------
.. autoclass:: ScaleHyperprior
   :members:

MeanScaleHyperprior
-------------------
.. autoclass:: MeanScaleHyperprior
   :members:

JointAutoregressiveHierarchicalPriors
-------------------------------------
.. autoclass:: JointAutoregressiveHierarchicalPriors
   :members:

Cheng2020Anchor
---------------
.. autoclass:: Cheng2020Anchor
   :members:

Cheng2020Attention
------------------
.. autoclass:: Cheng2020Attention
   :members:
compressai.ops
==============
.. currentmodule:: compressai.ops
ste_round
---------
.. autofunction:: ste_round
LowerBound
----------
.. autoclass:: LowerBound
NonNegativeParametrizer
-----------------------
.. autoclass:: NonNegativeParametrizer
compressai.transforms
=====================
.. currentmodule:: compressai.transforms
Transforms on Tensors
---------------------
.. autoclass:: RGB2YCbCr
.. autoclass:: YCbCr2RGB
.. autoclass:: YUV420To444
.. autoclass:: YUV444To420
Functional Transforms
---------------------
Functional transforms can be used to define custom transform classes.
.. automodule:: compressai.transforms.functional
   :members:
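
For example, a minimal custom transform built on a functional helper (a
sketch; it assumes :code:`rgb2ycbcr` is among the functions documented above):

.. code-block:: python

   import torch

   from compressai.transforms.functional import rgb2ycbcr  # assumed helper name

   class ToYCbCr:
       """Custom transform wrapping a functional helper."""

       def __call__(self, x: torch.Tensor) -> torch.Tensor:
           return rgb2ycbcr(x)

   x = torch.rand(1, 3, 64, 64)  # RGB tensor in [0, 1]
   y = ToYCbCr()(x)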
Installation
============
CompressAI only supports Python 3. We also recommend using a virtual
environment to isolate project packages from the base system installation.
Python package
~~~~~~~~~~~~~~
Requirements
------------
* python 3.6 or later (`python3-dev`, `python3-pip`, `python3-venv`)
* pip 19.0 or later
* a C++17 compiler (tested with `gcc` and `clang`)
* python packages: `numpy`, `scipy`, `torch`, `torchvision`
Virtual environment
-------------------
.. code-block:: bash

   python3 -m venv venv
   source ./venv/bin/activate
   pip install -U pip
Using pip
---------
1. Clone the CompressAI repository:

   .. code-block:: bash

      git clone https://github.com/InterDigitalInc/CompressAI compressai

2. Install CompressAI:

   .. code-block:: bash

      cd compressai
      pip install -e .
3. Custom installation

   You can also run one of the following commands:

   * :code:`pip install -e '.[dev]'`: install the packages required for development (testing, linting, docs)
   * :code:`pip install -e '.[tutorials]'`: install the packages required for the tutorials (notebooks)
   * :code:`pip install -e '.[all]'`: install all the optional packages
Build your own package
----------------------
You can also build your own pip package:
.. code-block:: bash

   git clone https://github.com/InterDigitalInc/CompressAI compressai
   cd compressai
   python3 setup.py bdist_wheel --dist-dir dist/
   pip install dist/compressai-*.whl
.. note::

   On macOS you might want to use :code:`CC=clang CXX=clang++ pip install ...`
   to compile with clang instead of gcc.
Docker
~~~~~~
We are planning to publish docker images in the future.
For now, a Makefile is provided to build docker images locally.
Run :code:`make help` in the source code directory to list the available options.
Introduction
============
Concept
~~~~~~~
CompressAI is built on top of PyTorch and provides:
* custom operations, layers and models for deep learning based data compression
* a partial port of the official `TensorFlow compression
<https://github.com/tensorflow/compression>`_ library
* pre-trained end-to-end compression models for learned image compression
* evaluation scripts to compare learned models against classical image/video
compression codecs
CompressAI aims to allow more researchers to contribute to the learned
image and video compression domain, by providing resources to research,
implement and evaluate machine learning based compression codecs.
Model Zoo
~~~~~~~~~
CompressAI includes some pre-trained models for compression tasks. See the Model
Zoo section for more documentation.
The list of available models, trained at different bit-rate distortion points
and with different metrics, is expected to grow in the future.
Training
========
In this tutorial we are going to implement a custom autoencoder architecture
using some of the modules and layers predefined in CompressAI.
For a complete runnable example, check out the :code:`train.py` script in the
:code:`examples/` folder of the CompressAI source tree.
Defining a custom model
-----------------------
Let's build a simple autoencoder with an
:mod:`~compressai.entropy_models.EntropyBottleneck` module, 3 convolutions for
the encoder, 3 transposed convolutions for the decoder, and
:mod:`~compressai.layers.GDN` activation functions:
.. code-block:: python

   import torch.nn as nn

   from compressai.entropy_models import EntropyBottleneck
   from compressai.layers import GDN

   class Network(nn.Module):
       def __init__(self, N=128):
           super().__init__()
           self.entropy_bottleneck = EntropyBottleneck(N)
           self.encode = nn.Sequential(
               nn.Conv2d(3, N, stride=2, kernel_size=5, padding=2),
               GDN(N),
               nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
               GDN(N),
               nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
           )

           self.decode = nn.Sequential(
               nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
               GDN(N, inverse=True),
               nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
               GDN(N, inverse=True),
               nn.ConvTranspose2d(N, 3, kernel_size=5, padding=2, output_padding=1, stride=2),
           )

       def forward(self, x):
           y = self.encode(x)
           y_hat, y_likelihoods = self.entropy_bottleneck(y)
           x_hat = self.decode(y_hat)
           return x_hat, y_likelihoods
The convolutions are strided to reduce the spatial dimensions of the tensor,
while increasing the number of channels (which helps to learn a better latent
representation). The bottleneck module is used to obtain a differentiable
entropy estimation of the latent tensor during training.
.. note::

   See the original paper: `"Variational image compression with a scale
   hyperprior" <https://arxiv.org/abs/1802.01436>`_, and the **tensorflow/compression**
   `documentation <https://tensorflow.github.io/compression/docs/entropy_bottleneck.html>`_
   for a detailed explanation of the EntropyBottleneck module.
Loss functions
--------------
1. Rate distortion loss

We are going to define a simple rate-distortion loss, which maximizes the
reconstruction quality (PSNR over RGB) while minimizing the length (in bits)
of the quantized latent tensor (:code:`y_hat`).
A scalar weight is used to balance the reconstruction quality against the
bit-rate (like the JPEG quality parameter, or the QP of HEVC):
.. math::

   \mathcal{L} = \mathcal{D} + \lambda \cdot \mathcal{R}
.. code-block:: python

   import math

   import torch
   import torch.nn as nn
   import torch.nn.functional as F

   x = torch.rand(1, 3, 64, 64)
   net = Network()
   x_hat, y_likelihoods = net(x)

   # bit-rate of the quantized latent
   N, _, H, W = x.size()
   num_pixels = N * H * W
   bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

   # mean squared error
   mse_loss = F.mse_loss(x, x_hat)

   # final loss term
   lmbda = 1e-2  # example rate-distortion trade-off weight
   loss = mse_loss + lmbda * bpp_loss
.. note::

   It's possible to train architectures that can handle multiple bit-rate
   distortion points, but that's outside the scope of this tutorial. See the
   paper `"Variable Rate Deep Image Compression With a Conditional Autoencoder"
   <http://openaccess.thecvf.com/content_ICCV_2019/papers/Choi_Variable_Rate_Deep_Image_Compression_With_a_Conditional_Autoencoder_ICCV_2019_paper.pdf>`_
   for a good example.
2. Auxiliary loss

The entropy bottleneck contains a learned density model whose parameters need
to be trained to fit the actual distribution of the latent elements. The
corresponding auxiliary loss is accessible through the
:code:`entropy_bottleneck` layer:
.. code-block:: python

   aux_loss = net.entropy_bottleneck.loss()
The auxiliary loss must be minimized during or after the training of the
network.
3. Optimizers

To train both the compression network and the entropy bottleneck density
estimation, we thus need two optimizers. To simplify the implementation,
CompressAI provides a :mod:`~compressai.models.CompressionModel` base class,
which includes an :mod:`~compressai.entropy_models.EntropyBottleneck` module
and some helper methods. Let's rewrite our network:
.. code-block:: python

   import torch.nn as nn

   from compressai.layers import GDN
   from compressai.models import CompressionModel
   from compressai.models.utils import conv, deconv

   class Network(CompressionModel):
       def __init__(self, N=128):
           super().__init__(entropy_bottleneck_channels=N)

           self.encode = nn.Sequential(
               conv(3, N),
               GDN(N),
               conv(N, N),
               GDN(N),
               conv(N, N),
           )

           self.decode = nn.Sequential(
               deconv(N, N),
               GDN(N, inverse=True),
               deconv(N, N),
               GDN(N, inverse=True),
               deconv(N, 3),
           )

       def forward(self, x):
           y = self.encode(x)
           y_hat, y_likelihoods = self.entropy_bottleneck(y)
           x_hat = self.decode(y_hat)
           return x_hat, y_likelihoods
Now, we can simply access the two sets of trainable parameters:
.. code-block:: python

   import torch.optim as optim

   parameters = set(p for n, p in net.named_parameters() if not n.endswith(".quantiles"))
   aux_parameters = set(p for n, p in net.named_parameters() if n.endswith(".quantiles"))
   optimizer = optim.Adam(parameters, lr=1e-4)
   aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)
And write a training loop:
.. code-block:: python

   x = torch.rand(1, 3, 64, 64)
   for i in range(10):
       optimizer.zero_grad()
       aux_optimizer.zero_grad()

       x_hat, y_likelihoods = net(x)
       # ...
       # compute loss as before
       # ...
       loss.backward()
       optimizer.step()

       aux_loss = net.aux_loss()
       aux_loss.backward()
       aux_optimizer.step()
Updating the model
------------------
Once a model has been trained, you need to run the :code:`update_model` script
to update the internal parameters of the entropy bottlenecks:
.. code-block:: bash

   python -m compressai.utils.update_model -n final-model --arch ARCH model_checkpoint.pth.tar
This will modify the buffers related to the learned cumulative distribution
functions (CDFs) required to perform the actual entropy coding.
You can run :code:`python -m compressai.utils.update_model --help` to get the
complete list of options.
Alternatively, you can call the :meth:`~compressai.models.CompressionModel.update`
method of a :mod:`~compressai.models.CompressionModel` or
:mod:`~compressai.entropy_models.EntropyBottleneck` instance at the end of your
training script, before saving the model checkpoint.
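
For example (a minimal sketch, assuming :code:`net` is a trained
:mod:`~compressai.models.CompressionModel` instance):

.. code-block:: python

   import torch

   net.update(force=True)  # recompute the quantized CDF buffers
   torch.save(net.state_dict(), "checkpoint_updated.pth.tar")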
Entropy coding
--------------
By default CompressAI uses a range Asymmetric Numeral Systems (ANS) entropy
coder. You can use :meth:`compressai.available_entropy_coders()` to get a list
of the implemented entropy coders and change the default entropy coder via
:meth:`compressai.set_entropy_coder()`.
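
For example:

.. code-block:: python

   import compressai

   print(compressai.available_entropy_coders())  # e.g. ['ans']
   compressai.set_entropy_coder(compressai.available_entropy_coders()[0])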
1. Compress an image tensor to a bit-stream:

   .. code-block:: python

      x = torch.rand(1, 3, 64, 64)
      y = net.encode(x)
      strings = net.entropy_bottleneck.compress(y)
2. Decompress a bit-stream to an image tensor:

   .. code-block:: python

      shape = y.size()[2:]
      y_hat = net.entropy_bottleneck.decompress(strings, shape)
      x_hat = net.decode(y_hat)
Image compression
=================
.. currentmodule:: compressai.zoo
This is the list of the pre-trained models for end-to-end image compression
available in CompressAI.
Currently, only models optimized w.r.t. the mean squared error (*mse*)
computed on the RGB channels are available. We expect to release models
fine-tuned with other metrics in the future.
Pass :code:`pretrained=True` to construct a model with pretrained weights.
Instantiating a pre-trained model will download its weights to a cache
directory. See the official `PyTorch documentation
<https://pytorch.org/docs/stable/model_zoo.html#torch.utils.model_zoo.load_url>`_
for details on the mechanics of loading models from URLs in PyTorch.
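
For example (a minimal sketch; the weights are downloaded on first use):

.. code-block:: python

   from compressai.zoo import bmshj2018_factorized

   net = bmshj2018_factorized(quality=4, pretrained=True)
   net.eval()  # evaluation mode, see the note on train/eval below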
The current pre-trained models expect input batches of RGB image tensors of
shape (N, 3, H, W). H and W are expected to be at least 64. The image data
must be in the [0, 1] range and *should not be normalized*. Based on the
number of strided convolutions and deconvolutions of the model you are using,
you might have to pad the input tensor's H and W dimensions to be a power of 2.
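
For example, a minimal padding sketch (the :code:`pad_input` helper below is
hypothetical, not part of CompressAI):

.. code-block:: python

   import torch.nn.functional as F

   def pad_input(x, p=64):
       # hypothetical helper: pad H and W up to the next multiple of p
       h, w = x.shape[2:]
       new_h = (h + p - 1) // p * p
       new_w = (w + p - 1) // p * p
       return F.pad(x, (0, new_w - w, 0, new_h - h), mode="replicate")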
Models may have different behaviors for their training or evaluation modes. For
example, the quantization operations may be performed differently. You can use
``model.train()`` or ``model.eval()`` to switch between modes. See the PyTorch
documentation for more information on
`train <https://pytorch.org/docs/stable/nn.html?highlight=train#torch.nn.Module.train>`_
and `eval <https://pytorch.org/docs/stable/nn.html?highlight=eval#torch.nn.Module.eval>`_.
.. contents:: Table of contents
   :local:
Training
~~~~~~~~
Unless specified otherwise, networks were trained for 4-5M steps on *256x256*
image patches randomly extracted and cropped from the `Vimeo90K
<http://toflow.csail.mit.edu/>`_ dataset [xue2019video]_.
Models were trained with a batch size of 16 or 32, and an initial learning rate
of 1e-4 for approximately 1-2M steps. The learning rate of the main optimizer is
then divided by 2 when the evaluation loss reaches a plateau (we use a patience
of 20 epochs). This can be implemented with the PyTorch `ReduceLROnPlateau
<https://pytorch.org/docs/stable/optim.html?highlight=reducelronplateau#torch.optim.lr_scheduler.ReduceLROnPlateau>`_
learning rate scheduler.
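
For example (a sketch; :code:`optimizer` is the main optimizer from the
training tutorial):

.. code-block:: python

   import torch.optim as optim

   lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
       optimizer, mode="min", factor=0.5, patience=20
   )
   # at the end of each epoch:
   # lr_scheduler.step(eval_loss)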
Training usually takes between one and two weeks to reach state-of-the-art
performance, depending on the model, the number of channels, and the GPU
architecture used.
The following loss functions and lambda values were used for training:
.. csv-table::
   :header: "Metric", "Loss function"
   :widths: 10, 50

   MSE, :math:`\mathcal{L} = \lambda \cdot 255^{2} \cdot \mathcal{D} + \mathcal{R}`
   MS-SSIM, :math:`\mathcal{L} = \lambda \cdot (1 - \mathcal{D}) + \mathcal{R}`
where :math:`\mathcal{D}` and :math:`\mathcal{R}` are respectively the mean
distortion and the mean estimated bit-rate.
.. csv-table::
   :header: "Quality", 1, 2, 3, 4, 5, 6, 7, 8
   :widths: 10, 5, 5, 5, 5, 5, 5, 5, 5

   MSE, 0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800
   MS-SSIM, 2.40, 4.58, 8.73, 16.64, 31.73, 60.50, 115.37, 220.00
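
For example, plugging the quality 4 MSE value into the rate-distortion loss
from the training tutorial (a sketch; :code:`mse_loss` and :code:`bpp_loss`
are computed as shown there):

.. code-block:: python

   lmbda = 0.0130  # MSE lambda for quality 4, from the table above
   loss = lmbda * 255 ** 2 * mse_loss + bpp_loss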
.. note::

   MS-SSIM optimized networks were fine-tuned from pre-trained MSE
   networks (with a learning rate of 1e-5 for both optimizers).
.. note::

   The number of channels for the convolutional layers and the entropy
   bottleneck depends on the architecture and the quality parameter (~targeted
   bit-rate). For low bit-rates (<0.5 bpp), the literature usually recommends
   192 channels for the entropy bottleneck, and 320 channels for higher
   bit-rates. The detailed list of configurations can be found in
   :obj:`compressai.zoo.image.cfgs`.
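
For example (assuming the dictionary is keyed by the zoo model names):

.. code-block:: python

   from compressai.zoo.image import cfgs

   print(cfgs["bmshj2018-factorized"])  # channel configuration per quality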
.. note::

   For the *cheng2020_\** architectures, we trained with the first 6
   quality parameters.
....
Models
~~~~~~
.. warning::

   All the models are currently implemented using floating point operations
   only. As a consequence, the operations are not reproducible and
   encoding/decoding on different devices is not supported. See the following
   paper, `"Integer Networks for Data Compression with Latent-Variable Models"
   <https://openreview.net/forum?id=S1zz2i0cY7>`_ by Ballé *et al.*, for
   solutions to implement cross-platform encoding and decoding.
bmshj2018_factorized
--------------------
Original paper: [bmshj2018]_
.. autofunction:: bmshj2018_factorized
bmshj2018_hyperprior
--------------------
Original paper: [bmshj2018]_
.. autofunction:: bmshj2018_hyperprior
mbt2018_mean
------------
Original paper: [mbt2018]_
.. autofunction:: mbt2018_mean
mbt2018
-------
Original paper: [mbt2018]_
.. autofunction:: mbt2018
cheng2020_anchor
----------------
Original paper: [cheng2020]_
.. autofunction:: cheng2020_anchor
cheng2020_attn
--------------
Original paper: [cheng2020]_
.. autofunction:: cheng2020_attn
.. warning:: Pre-trained weights are not yet available for this architecture.
....
.. rubric:: Citations
.. [bmshj2018]

   .. code-block:: bibtex

      @inproceedings{ballemshj18,
        author = {Johannes Ball{\'{e}} and
                  David Minnen and
                  Saurabh Singh and
                  Sung Jin Hwang and
                  Nick Johnston},
        title = {Variational image compression with a scale hyperprior},
        booktitle = {6th International Conference on Learning Representations, {ICLR} 2018,
                     Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings},
        publisher = {OpenReview.net},
        year = {2018},
      }

.. [mbt2018]

   .. code-block:: bibtex

      @inproceedings{minnenbt18,
        author = {David Minnen and
                  Johannes Ball{\'{e}} and
                  George Toderici},
        editor = {Samy Bengio and
                  Hanna M. Wallach and
                  Hugo Larochelle and
                  Kristen Grauman and
                  Nicol{\`{o}} Cesa{-}Bianchi and
                  Roman Garnett},
        title = {Joint Autoregressive and Hierarchical Priors for Learned Image Compression},
        booktitle = {Advances in Neural Information Processing Systems 31: Annual Conference
                     on Neural Information Processing Systems 2018, NeurIPS 2018, 3-8 December
                     2018, Montr{\'{e}}al, Canada},
        pages = {10794--10803},
        year = {2018},
      }

.. [xue2019video]

   .. code-block:: bibtex

      @article{xue2019video,
        title = {Video Enhancement with Task-Oriented Flow},
        author = {Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and
                  Freeman, William T},
        journal = {International Journal of Computer Vision (IJCV)},
        volume = {127},
        number = {8},
        pages = {1106--1125},
        year = {2019},
        publisher = {Springer}
      }

.. [cheng2020]

   .. code-block:: bibtex

      @inproceedings{cheng2020image,
        title = {Learned Image Compression with Discretized Gaussian Mixture
                 Likelihoods and Attention Modules},
        author = {Cheng, Zhengxue and Sun, Heming and Takeuchi, Masaru and Katto,
                  Jiro},
        booktitle = {Proceedings of the IEEE Conference on Computer Vision and
                     Pattern Recognition (CVPR)},
        year = {2020}
      }
....
Performances
~~~~~~~~~~~~
.. note::

   See the `CompressAI paper <https://arxiv.org/abs/2011.03029>`_ on
   arXiv for more comparisons and evaluations.
all models
----------
.. image:: media/images/compressai.png
.. image:: media/images/compressai-clic2020-mobile.png
.. image:: media/images/compressai-clic2020-pro.png
bmshj2018 factorized
--------------------
.. image:: media/images/bmshj2018-factorized-mse.png
bmshj2018 hyperprior
--------------------
.. image:: media/images/bmshj2018-hyperprior-mse.png
mbt2018 mean
------------
.. image:: media/images/mbt2018-mean-mse.png
mbt2018
-------
.. image:: media/images/mbt2018-mse.png