Commit 9e459ea3 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
"""
Find the closest codec quality parameter to reach a given metric (bpp, ms-ssim,
or psnr).
Example usages:
* :code:`python -m compressai.utils.find_close webp ~/picture.png 0.5 --metric bpp`
* :code:`python -m compressai.utils.find_close jpeg ~/picture.png 35 --metric psnr --save`
"""
# pylint: enable=line-too-long
import argparse
import sys
import os
from typing import Dict, List, Tuple
from PIL import Image
from compressai.utils.bench.codecs import AV1, BPG, HM, JPEG, JPEG2000, VTM, Codec, WebP
def get_codec_q_bounds(codec: Codec) -> Tuple[bool, int, int]:
    """Return ``(rev, lower, upper)`` exclusive quality bounds for *codec*.

    ``rev`` is True when a larger quality parameter means *lower* visual
    quality (QP-style codecs: BPG, HM, AV1, VTM).
    """
    if isinstance(codec, (JPEG, JPEG2000, WebP)):
        # Quality scale 0..100, higher is better.
        return False, -1, 101
    if isinstance(codec, (BPG, HM)):
        # QP scale 0..50, higher compresses more.
        return True, -1, 51
    if isinstance(codec, (AV1, VTM)):
        # QP scale 0..63, higher compresses more.
        return True, -1, 64
    assert False, codec
def find_closest(
    codec: Codec, img: str, target: float, metric: str = "psnr"
) -> Tuple[int, Dict[str, float], Image.Image]:
    """Binary-search the codec quality so ``metric`` approaches ``target``.

    Returns the last probed quality value, the metrics dict of the run
    closest to the target, and its reconstructed image.
    """
    rev, lo, hi = get_codec_q_bounds(codec)

    best_metrics: Dict[str, float] = {}
    best_rec = None
    while hi > lo + 1:
        quality = (hi + lo) // 2
        results, reconstruction = codec.run(img, quality, return_rec=True)
        value = results[metric]

        # Keep the run whose metric value is nearest to the target so far.
        if not best_metrics or abs(value - target) < abs(best_metrics[metric] - target):
            best_metrics = results
            best_rec = reconstruction

        if value == target:
            break
        # Move towards the target; the search direction flips for QP-style
        # codecs where a larger parameter means lower quality.
        if (value < target) != rev:
            lo = quality
        else:
            hi = quality

        sys.stderr.write(
            f"\rtarget {metric}: {target:.4f} | value: {value:.4f} | q: {quality}"
        )
        sys.stderr.flush()
    sys.stderr.write("\n")
    sys.stderr.flush()
    return quality, best_metrics, best_rec
# Codec classes exposed as CLI sub-commands (order defines the help listing).
codecs = [JPEG, WebP, JPEG2000, BPG, VTM, HM, AV1]
def setup_args():
    """Build the top-level argument parser and the per-codec subparser hook."""
    parser = argparse.ArgumentParser(
        description="Collect codec metrics and performances."
    )
    subparsers = parser.add_subparsers(dest="codec", help="Select codec")
    subparsers.required = True
    parser.add_argument("image", type=str, help="image filepath")
    parser.add_argument("target", type=float, help="target value to match")
    parser.add_argument(
        "-m", "--metric", type=str, choices=["bpp", "psnr", "ms-ssim"], default="bpp"
    )
    parser.add_argument(
        "--save", action="store_true", help="Save reconstructed image to disk"
    )
    parser.add_argument("--prefix", type=str, default=".")
    return parser, subparsers
def main(argv: List[str]):
    """CLI entry point: find the codec quality closest to the target metric.

    When ``--save`` is given, the best reconstruction is written under
    ``--prefix`` with a name encoding codec, PSNR, bpp, MS-SSIM and quality.
    """
    parser, subparsers = setup_args()
    for c in codecs:
        cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
        c.setup_args(cparser)
    args = parser.parse_args(argv)

    codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
    codec = codec_cls(args)
    quality, metrics, rec = find_closest(codec, args.image, args.target, args.metric)
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    if args.save:
        # Only create the output directory when actually saving; exist_ok
        # avoids a race between the exists() check and makedirs().
        os.makedirs(args.prefix, exist_ok=True)
        # Bug fix: the original concatenation ('_' + f"_q{...}") produced a
        # double underscore before the quality suffix.
        filename = (
            f"output_{codec_cls.__name__.lower()}"
            f"_{metrics['psnr']:.2f}_{metrics['bpp']:.3f}"
            f"_{metrics['ms-ssim']:.3f}_q{quality}.png"
        )
        rec.save(os.path.join(args.prefix, filename))


if __name__ == "__main__":
    main(sys.argv[1:])
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# - Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple plotting utility to display Rate-Distortion curves (RD) comparison
between codecs.
"""
import argparse
import json
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Available plotting backends; plotly is optional and only registered when
# it can be imported.
_backends = ["matplotlib"]
try:
    import plotly.graph_objs as go
    import plotly.offline

    _backends.append("plotly")
except ImportError:
    pass
def parse_json_file(filepath, metric):
    """Load one RD-results JSON file into a scatter dict.

    Returns a dict with ``name``, ``xs`` (bpp values) and ``ys`` (values of
    *metric*).  MS-SSIM values are converted to a dB-style scale.
    """
    path = Path(filepath)
    default_name = path.name.split(".")[0]

    with path.open("r") as fobj:
        try:
            data = json.load(fobj)
        except json.decoder.JSONDecodeError as err:
            print(f'Error reading file "{path}"')
            raise err

    if "results" not in data or "bpp" not in data["results"]:
        raise ValueError(f'Invalid file "{path}"')

    results = data["results"]
    if metric not in results:
        raise ValueError(
            f'Error: metric "{metric}" not available.'
            f' Available metrics: {", ".join(results.keys())}'
        )

    if metric == "ms-ssim":
        # Convert to dB for a more readable plot scale.
        results[metric] = -10 * np.log10(1 - np.array(results[metric]))

    return {
        "name": data.get("name", default_name),
        "xs": results["bpp"],
        "ys": results[metric],
    }
def matplotlib_plt(
    scatters, title, ylabel, output_file, limits=None, show=False, figsize=None
):
    """Render the RD scatters with matplotlib; optionally show and/or save."""
    if figsize is None:
        figsize = (9, 6)
    fig, axis = plt.subplots(figsize=figsize)

    for scatter in scatters:
        axis.plot(scatter["xs"], scatter["ys"], ".-", label=scatter["name"])

    axis.set_xlabel("Bit-rate [bpp]")
    axis.set_ylabel(ylabel)
    axis.grid()
    if limits is not None:
        axis.axis(limits)
    axis.legend(loc="lower right")
    if title:
        axis.title.set_text(title)

    if show:
        plt.show()
    if output_file:
        fig.savefig(output_file, dpi=300)
def plotly_plt(
    scatters, title, ylabel, output_file, limits=None, show=False, figsize=None
):
    """Render the RD scatters with plotly into an interactive HTML file.

    Args:
        scatters: list of dicts with "xs", "ys" and "name" keys.
        title: plot title.
        ylabel: y-axis label.
        output_file: destination HTML path (defaults to "plot.html").
        limits: optional (xmin, xmax, ymin, ymax) axis limits.
        show: open the plot in a browser when True.
        figsize: unused; kept for signature parity with matplotlib_plt.
    """
    del figsize  # matplotlib-only option
    traces = [go.Scatter(x=sc["xs"], y=sc["ys"], name=sc["name"]) for sc in scatters]

    # Bug fix: the original indexed `limits` unconditionally, so the
    # documented default limits=None crashed with a TypeError. Only set
    # explicit axis ranges when limits are provided.
    xaxis = {"title": "Bit-rate [bpp]"}
    yaxis = {"title": ylabel}
    if limits is not None:
        xaxis["range"] = [limits[0], limits[1]]
        yaxis["range"] = [limits[2], limits[3]]

    plotly.offline.plot(
        {
            "data": traces,
            "layout": go.Layout(
                title=title,
                legend={
                    "font": {
                        "size": 14,
                    },
                },
                xaxis=xaxis,
                yaxis=yaxis,
            ),
        },
        auto_open=show,
        filename=output_file or "plot.html",
    )
def setup_args():
    """Build the argument parser for the RD-curve plotting tool."""
    p = argparse.ArgumentParser(description="")
    p.add_argument(
        "-f",
        "--results-file",
        metavar="",
        default="",
        type=str,
        nargs="*",
        required=True,
    )
    p.add_argument(
        "-m",
        "--metric",
        metavar="",
        type=str,
        default="psnr",
        help="Metric (default: %(default)s)",
    )
    p.add_argument("-t", "--title", metavar="", type=str, help="Plot title")
    p.add_argument("-o", "--output", metavar="", type=str, help="Output file name")
    # Geometry options: relative figure size and fixed axis limits.
    p.add_argument(
        "--figsize",
        metavar="",
        type=float,
        nargs=2,
        default=(9, 6),
        help="Figure relative size (width, height), default: %(default)s",
    )
    p.add_argument(
        "--axes",
        metavar="",
        type=float,
        nargs=4,
        default=(0, 2, 28, 43),
        help="Axes limit (xmin, xmax, ymin, ymax), default: %(default)s",
    )
    p.add_argument(
        "--backend",
        type=str,
        metavar="",
        default=_backends[0],
        choices=_backends,
        help="Change plot backend (default: %(default)s)",
    )
    p.add_argument("--show", action="store_true", help="Open plot figure")
    return p
def main(argv):
    """Parse arguments, read every results file and plot the RD curves."""
    args = setup_args().parse_args(argv)

    scatters = [parse_json_file(path, args.metric) for path in args.results_file]

    ylabel = "PSNR [dB]" if args.metric == "psnr" else args.metric

    # Dispatch to the selected plotting backend.
    backends = {
        "matplotlib": matplotlib_plt,
        "plotly": plotly_plt,
    }
    backends[args.backend](
        scatters,
        args.title,
        ylabel,
        args.output,
        limits=args.axes,
        figsize=args.figsize,
        show=args.show,
    )


if __name__ == "__main__":
    main(sys.argv[1:])
\ No newline at end of file
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Update the CDFs parameters of a trained model.
To be called on a model checkpoint after training. This will update the internal
CDFs related buffers required for entropy coding.
"""
import argparse
import hashlib
import sys
import os
from pathlib import Path
from typing import Dict
import torch
from compressai.models.priors import (
FactorizedPrior,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
)
from compressai.models.ours import (
InvCompress,
)
def sha256_file(filepath: Path, len_hash_prefix: int = 8) -> str:
    """Return the first *len_hash_prefix* hex chars of the file's SHA-256."""
    # Stream in 8 KiB chunks so large checkpoints never sit fully in memory
    # (same approach as the pytorch hub helper).
    hasher = hashlib.sha256()
    with filepath.open("rb") as stream:
        for chunk in iter(lambda: stream.read(8192), b""):
            hasher.update(chunk)
    return hasher.hexdigest()[:len_hash_prefix]
def load_checkpoint(filepath: Path) -> Dict[str, torch.Tensor]:
    """Load a checkpoint on CPU and extract its state dict.

    Accepts raw state dicts as well as wrappers storing the weights under a
    "network" or "state_dict" key.
    """
    checkpoint = torch.load(filepath, map_location="cpu")
    for key in ("network", "state_dict"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint
# CLI help text for the checkpoint exporter.
description = """
Export a trained model to a new checkpoint with an updated CDFs parameters and a
hash prefix, so that it can be loaded later via `load_state_dict_from_url`.
""".strip()

# Architecture name -> model class; used for the `--architecture` choices.
models = {
    "factorized-prior": FactorizedPrior,
    "jarhp": JointAutoregressiveHierarchicalPriors,
    "mean-scale-hyperprior": MeanScaleHyperprior,
    "scale-hyperprior": ScaleHyperprior,
    "invcompress": InvCompress,
}
def setup_args():
    """Build the argument parser for the checkpoint update/export tool."""
    parser = argparse.ArgumentParser(description=description)
    # The explicit checkpoint path argument was replaced by the
    # experiment-based lookup performed in main().
    # parser.add_argument(
    #     "filepath", type=str, help="Path to the checkpoint model to be exported."
    # )
    parser.add_argument(
        "-exp", "--experiment", type=str, required=True, help="Experiment name"
    )
    parser.add_argument("-d", "--dir", type=str, help="Exported model directory.")
    parser.add_argument(
        "--no-update",
        action="store_true",
        default=False,
        help="Do not update the model CDFs parameters.",
    )
    parser.add_argument(
        "-a",
        "--architecture",
        default="scale-hyperprior",
        choices=models.keys(),
        help="Set model architecture (default: %(default)s).",
    )
    parser.add_argument("--epoch", type=int, default=-1, help="Epoch")
    return parser
def main(argv):
    """Update a checkpoint's entropy-coder CDF buffers and re-export it.

    The exported file is renamed with a SHA-256 prefix so it can later be
    served through `load_state_dict_from_url`-style URLs.
    """
    args = setup_args().parse_args(argv)

    if args.epoch != -1:
        checkpoint_name = "checkpoint_%03d.pth.tar" % args.epoch
    else:
        checkpoint_name = "checkpoint_best_loss.pth.tar"
    filepath = Path(
        os.path.join("../experiments", args.experiment, "checkpoints", checkpoint_name)
    ).resolve()
    if not filepath.is_file():
        raise RuntimeError(f'"{filepath}" is not a valid file.')

    state_dict = load_checkpoint(filepath)

    model_cls = models[args.architecture]
    net = model_cls.from_state_dict(state_dict)

    if not args.no_update:
        # Refresh the internal CDF buffers required for entropy coding.
        net.update(force=True)
    state_dict = net.state_dict()

    # Strip all suffixes (e.g. ".pth.tar") to recover the base name.
    filename = filepath
    while filename.suffixes:
        filename = Path(filename.stem)
    ext = "".join(filepath.suffixes)

    output_dir = os.path.join("../experiments", args.experiment, "checkpoint_updated")
    os.makedirs(output_dir, exist_ok=True)
    # Remove stale exports so only the latest checkpoint remains.
    for old in os.listdir(output_dir):
        os.remove(os.path.join(output_dir, old))

    # Bug fix: the original wrote a hard-coded placeholder name and never
    # used the computed base name `filename`.
    out_path = Path(f"{output_dir}/{filename}{ext}")
    torch.save(state_dict, out_path)
    hash_prefix = sha256_file(out_path)
    out_path.rename(f"{output_dir}/{filename}-{hash_prefix}{ext}")


if __name__ == "__main__":
    main(sys.argv[1:])
# Package version and the git revision it was built from.
__version__ = "1.1.1"
git_version = "16100c4cb697036e36118f18e9078b00a676fb75"
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .image import *
# Public model name -> builder function (builders are provided by .image).
models = {
    "bmshj2018-factorized": bmshj2018_factorized,
    "bmshj2018-hyperprior": bmshj2018_hyperprior,
    "mbt2018-mean": mbt2018_mean,
    "mbt2018": mbt2018,
    "cheng2020-anchor": cheng2020_anchor,
    "cheng2020-attn": cheng2020_attn,
    "invcompress": invcompress,
}
\ No newline at end of file
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.hub import load_state_dict_from_url
from compressai.models import (
InvCompress,
Cheng2020Anchor,
Cheng2020Attention,
FactorizedPrior,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
)
from .pretrained import load_pretrained
# Public API of the image model zoo.
__all__ = [
    "bmshj2018_factorized",
    "bmshj2018_hyperprior",
    "mbt2018",
    "mbt2018_mean",
    "cheng2020_anchor",
    "cheng2020_attn",
    "invcompress",
]
# Architecture name -> model class for the pretrained zoo.
model_architectures = {
    "bmshj2018-factorized": FactorizedPrior,
    "bmshj2018-hyperprior": ScaleHyperprior,
    "mbt2018-mean": MeanScaleHyperprior,
    "mbt2018": JointAutoregressiveHierarchicalPriors,
    "cheng2020-anchor": Cheng2020Anchor,
    "cheng2020-attn": Cheng2020Attention,
    "invcompress": InvCompress,
}

# Base URL of the hosted CompressAI checkpoints.
root_url = "https://compressai.s3.amazonaws.com/models/v1"
# architecture -> training metric -> quality level -> checkpoint URL.
model_urls = {
    "bmshj2018-factorized": {
        "mse": {
            1: f"{root_url}/bmshj2018-factorized-prior-1-446d5c7f.pth.tar",
            2: f"{root_url}/bmshj2018-factorized-prior-2-87279a02.pth.tar",
            3: f"{root_url}/bmshj2018-factorized-prior-3-5c6f152b.pth.tar",
            4: f"{root_url}/bmshj2018-factorized-prior-4-1ed4405a.pth.tar",
            5: f"{root_url}/bmshj2018-factorized-prior-5-866ba797.pth.tar",
            6: f"{root_url}/bmshj2018-factorized-prior-6-9b02ea3a.pth.tar",
            7: f"{root_url}/bmshj2018-factorized-prior-7-6dfd6734.pth.tar",
            8: f"{root_url}/bmshj2018-factorized-prior-8-5232faa3.pth.tar",
        },
        "ms-ssim": {
            1: f"{root_url}/bmshj2018-factorized-ms-ssim-1-9781d705.pth.tar",
            2: f"{root_url}/bmshj2018-factorized-ms-ssim-2-4a584386.pth.tar",
            3: f"{root_url}/bmshj2018-factorized-ms-ssim-3-5352f123.pth.tar",
            4: f"{root_url}/bmshj2018-factorized-ms-ssim-4-4f91b847.pth.tar",
            5: f"{root_url}/bmshj2018-factorized-ms-ssim-5-b3a88897.pth.tar",
            6: f"{root_url}/bmshj2018-factorized-ms-ssim-6-ee028763.pth.tar",
            7: f"{root_url}/bmshj2018-factorized-ms-ssim-7-8c265a29.pth.tar",
            8: f"{root_url}/bmshj2018-factorized-ms-ssim-8-8811bd14.pth.tar",
        },
    },
    "bmshj2018-hyperprior": {
        "mse": {
            1: f"{root_url}/bmshj2018-hyperprior-1-7eb97409.pth.tar",
            2: f"{root_url}/bmshj2018-hyperprior-2-93677231.pth.tar",
            3: f"{root_url}/bmshj2018-hyperprior-3-6d87be32.pth.tar",
            4: f"{root_url}/bmshj2018-hyperprior-4-de1b779c.pth.tar",
            5: f"{root_url}/bmshj2018-hyperprior-5-f8b614e1.pth.tar",
            6: f"{root_url}/bmshj2018-hyperprior-6-1ab9c41e.pth.tar",
            7: f"{root_url}/bmshj2018-hyperprior-7-3804dcbd.pth.tar",
            8: f"{root_url}/bmshj2018-hyperprior-8-a583f0cf.pth.tar",
        },
        "ms-ssim": {
            1: f"{root_url}/bmshj2018-hyperprior-ms-ssim-1-5cf249be.pth.tar",
            2: f"{root_url}/bmshj2018-hyperprior-ms-ssim-2-1ff60d1f.pth.tar",
            3: f"{root_url}/bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar",
            4: f"{root_url}/bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar",
            5: f"{root_url}/bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar",
            6: f"{root_url}/bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar",
            7: f"{root_url}/bmshj2018-hyperprior-ms-ssim-7-8747d3bc.pth.tar",
            8: f"{root_url}/bmshj2018-hyperprior-ms-ssim-8-cc15b5f3.pth.tar",
        },
    },
    "mbt2018-mean": {
        "mse": {
            1: f"{root_url}/mbt2018-mean-1-e522738d.pth.tar",
            2: f"{root_url}/mbt2018-mean-2-e54a039d.pth.tar",
            3: f"{root_url}/mbt2018-mean-3-723404a8.pth.tar",
            4: f"{root_url}/mbt2018-mean-4-6dba02a3.pth.tar",
            5: f"{root_url}/mbt2018-mean-5-d504e8eb.pth.tar",
            6: f"{root_url}/mbt2018-mean-6-a19628ab.pth.tar",
            7: f"{root_url}/mbt2018-mean-7-d5d441d1.pth.tar",
            8: f"{root_url}/mbt2018-mean-8-8089ae3e.pth.tar",
        },
    },
    "mbt2018": {
        "mse": {
            1: f"{root_url}/mbt2018-1-3f36cd77.pth.tar",
            2: f"{root_url}/mbt2018-2-43b70cdd.pth.tar",
            3: f"{root_url}/mbt2018-3-22901978.pth.tar",
            4: f"{root_url}/mbt2018-4-456e2af9.pth.tar",
            5: f"{root_url}/mbt2018-5-b4a046dd.pth.tar",
            6: f"{root_url}/mbt2018-6-7052e5ea.pth.tar",
            7: f"{root_url}/mbt2018-7-8ba2bf82.pth.tar",
            8: f"{root_url}/mbt2018-8-dd0097aa.pth.tar",
        },
    },
    "cheng2020-anchor": {
        "mse": {
            1: f"{root_url}/cheng2020-anchor-1-dad2ebff.pth.tar",
            2: f"{root_url}/cheng2020-anchor-2-a29008eb.pth.tar",
            3: f"{root_url}/cheng2020-anchor-3-e49be189.pth.tar",
            4: f"{root_url}/cheng2020-anchor-4-98b0b468.pth.tar",
            5: f"{root_url}/cheng2020-anchor-5-23852949.pth.tar",
            6: f"{root_url}/cheng2020-anchor-6-4c052b1a.pth.tar",
        },
    },
    # No pretrained attention-model checkpoints published yet.
    "cheng2020-attn": {
        "mse": {},
    },
}
# Architecture name -> quality level -> constructor arguments (channel sizes).
cfgs = {
    "bmshj2018-factorized": {
        1: (128, 192),
        2: (128, 192),
        3: (128, 192),
        4: (128, 192),
        5: (128, 192),
        6: (192, 320),
        7: (192, 320),
        8: (192, 320),
    },
    "bmshj2018-hyperprior": {
        1: (128, 192),
        2: (128, 192),
        3: (128, 192),
        4: (128, 192),
        5: (128, 192),
        6: (192, 320),
        7: (192, 320),
        8: (192, 320),
    },
    "mbt2018-mean": {
        1: (128, 192),
        2: (128, 192),
        3: (128, 192),
        4: (128, 192),
        5: (192, 320),
        6: (192, 320),
        7: (192, 320),
        8: (192, 320),
    },
    "mbt2018": {
        1: (192, 192),
        2: (192, 192),
        3: (192, 192),
        4: (192, 192),
        5: (192, 320),
        6: (192, 320),
        7: (192, 320),
        8: (192, 320),
    },
    "cheng2020-anchor": {
        1: (128,),
        2: (128,),
        3: (128,),
        4: (192,),
        5: (192,),
        6: (192,),
    },
    "cheng2020-attn": {
        1: (128,),
        2: (128,),
        3: (128,),
        4: (192,),
        5: (192,),
        6: (192,),
        7: (256,),
        8: (384,),
    },
    "invcompress": {
        1: (128,),
        2: (128,),
        3: (128,),
        # Bug fix: the key `3` was duplicated and quality level 4 was
        # missing entirely, so invcompress(quality=4) raised
        # "Invalid quality value".
        4: (128,),
        5: (192,),
        6: (192,),
        7: (192,),
        8: (192,),
    },
}
def _load_model(
    architecture, metric, quality, pretrained=False, progress=True, **kwargs
):
    """Instantiate a zoo model, optionally initialised from pretrained weights."""
    if architecture not in model_architectures:
        raise ValueError(f'Invalid architecture name "{architecture}"')
    if quality not in cfgs[architecture]:
        raise ValueError(f'Invalid quality value "{quality}"')

    if not pretrained:
        # Fresh model with the channel configuration for this quality level.
        return model_architectures[architecture](
            *cfgs[architecture][quality], **kwargs
        )

    urls_for_metric = model_urls.get(architecture, {}).get(metric, {})
    if quality not in urls_for_metric:
        raise RuntimeError("Pre-trained model not yet available")

    state_dict = load_state_dict_from_url(urls_for_metric[quality], progress=progress)
    state_dict = load_pretrained(state_dict)
    return model_architectures[architecture].from_state_dict(state_dict)
def bmshj2018_factorized(
    quality, metric="mse", pretrained=False, progress=True, **kwargs
):
    r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if metric not in ("mse", "ms-ssim"):
        raise ValueError(f'Invalid metric "{metric}"')
    if not 1 <= quality <= 8:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    return _load_model(
        "bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs
    )
def bmshj2018_hyperprior(
    quality, metric="mse", pretrained=False, progress=True, **kwargs
):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if metric not in ("mse", "ms-ssim"):
        raise ValueError(f'Invalid metric "{metric}"')
    if not 1 <= quality <= 8:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    return _load_model(
        "bmshj2018-hyperprior", metric, quality, pretrained, progress, **kwargs
    )
def mbt2018_mean(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if metric not in ("mse",):
        raise ValueError(f'Invalid metric "{metric}"')
    if not 1 <= quality <= 8:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    return _load_model("mbt2018-mean", metric, quality, pretrained, progress, **kwargs)
def mbt2018(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Joint Autoregressive Hierarchical Priors model from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if metric not in ("mse",):
        raise ValueError(f'Invalid metric "{metric}"')
    if not 1 <= quality <= 8:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    return _load_model("mbt2018", metric, quality, pretrained, progress, **kwargs)
def cheng2020_anchor(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        quality (int): Quality levels (1: lowest, highest: 6)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if metric not in ("mse",):
        raise ValueError(f'Invalid metric "{metric}"')
    if not 1 <= quality <= 6:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 6)')
    return _load_model(
        "cheng2020-anchor", metric, quality, pretrained, progress, **kwargs
    )
def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Doc fix: the docstring previously claimed a 1-6 range, but both the
    # validation below and the cfgs table accept quality levels 1-8.
    if metric not in ("mse",):
        raise ValueError(f'Invalid metric "{metric}"')
    if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    return _load_model(
        "cheng2020-attn", metric, quality, pretrained, progress, **kwargs
    )
def invcompress(
    quality, metric="mse", pretrained=False, progress=True, **kwargs
):
    r"""Our InvCompress model

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr

    Raises:
        ValueError: on an unknown metric, an out-of-range quality level, or
            a request for (not yet available) pretrained weights.
    """
    if metric not in ("mse", "ms-ssim"):
        raise ValueError(f'Invalid metric "{metric}"')
    if quality < 1 or quality > 8:
        # Bug fix: the message previously claimed a (1, 13) range while the
        # check enforces 1-8.
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')
    if pretrained:
        # Bug fix: the original referenced the undefined name `pretrain`,
        # so this branch raised NameError instead of the intended ValueError.
        raise ValueError(f'Invalid pretrained "{pretrained}", not yet supported')
    return _load_model(
        "invcompress", metric, quality, pretrained, progress, **kwargs
    )
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def rename_key(key):
    """Rename a legacy state_dict key to the current naming scheme."""
    # Deal with modules trained with DataParallel
    if key.startswith("module."):
        key = key[7:]

    # ResidualBlockWithStride: 'downsample' -> 'skip'
    if ".downsample." in key:
        return key.replace("downsample", "skip")

    # EntropyBottleneck: nn.ParameterList to nn.Parameters
    if key.startswith("entropy_bottleneck."):
        # Generalization: use the full trailing index instead of only the
        # last character, so indices >= 10 are preserved (the original
        # `key[-1]` truncated "._biases.12" to "_bias2").
        index = key.rsplit(".", 1)[-1]
        if key.startswith("entropy_bottleneck._biases."):
            return f"entropy_bottleneck._bias{index}"
        if key.startswith("entropy_bottleneck._matrices."):
            return f"entropy_bottleneck._matrix{index}"
        if key.startswith("entropy_bottleneck._factors."):
            return f"entropy_bottleneck._factor{index}"
    return key
def load_pretrained(state_dict):
    """Return a copy of *state_dict* with every key converted via rename_key."""
    converted = {}
    for key, value in state_dict.items():
        converted[rename_key(key)] = value
    return converted
# Build stage: compile the compressai wheel from the source tarball inside
# the -devel image (which ships the build toolchain).
ARG PYTORCH_IMAGE
FROM ${PYTORCH_IMAGE}-devel as builder
WORKDIR /tmp/compressai
COPY compressai.tar.gz .
RUN tar xzf compressai.tar.gz && \
    python3 setup.py bdist_wheel

# Runtime stage: install the prebuilt wheel into the slimmer runtime image
# and smoke-test the import.
FROM ${PYTORCH_IMAGE}-runtime
LABEL maintainer="compressai@interdigital.com"
WORKDIR /tmp
COPY --from=builder /tmp/compressai/dist/compressai-*.whl .
RUN pip install compressai-*.whl && \
    python3 -c 'import compressai'

# Install jupyter?
ARG WITH_JUPYTER=0
RUN if [ "$WITH_JUPYTER" = "1" ]; then \
    pip3 install jupyter ipywidgets && \
    jupyter nbextension enable --py widgetsnbextension \
    ; fi

WORKDIR /workspace
CMD ["bash"]
# CPU-only variant: install CPU builds of torch/torchvision on the base image.
ARG BASE_IMAGE
FROM ${BASE_IMAGE} as base
RUN pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# Build stage: compile the compressai wheel from the source tarball.
FROM base as builder
WORKDIR /tmp/compressai
COPY compressai.tar.gz .
RUN tar xzf compressai.tar.gz && \
    python3 setup.py sdist bdist_wheel

# Runtime stage: install the prebuilt wheel and smoke-test the import.
FROM base
LABEL maintainer="compressai@interdigital.com"
WORKDIR /tmp
COPY --from=builder /tmp/compressai/dist/compressai-*.whl .
RUN pip install compressai-*.whl && \
    python3 -c 'import compressai'

# Install jupyter?
ARG WITH_JUPYTER=0
RUN if [ "$WITH_JUPYTER" = "1" ]; then \
    pip3 install jupyter ipywidgets && \
    jupyter nbextension enable --py widgetsnbextension \
    ; fi

WORKDIR /workspace
CMD ["bash"]
sphinx==3.0.3
sphinx_rtd_theme
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = ./source/
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# Note: the CLI help include files are regenerated before every build.
%: Makefile
	cd "${SOURCEDIR}"; python generate_cli_help.py
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Fall back to the sphinx-build on PATH when SPHINXBUILD is not set.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

if "%1" == "" goto help

REM Probe that sphinx-build is runnable; errorlevel 9009 means "not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
Sphinx==3.0.3
sphinx-rtd-theme==0.4.3
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==1.0.3
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4
compressai.ans
==============
Range Asymmetric Numeral System (rANS) bindings. rANS can be used as a
replacement for a traditional range coder.
Based on the original C++ implementation from Fabian "ryg" Giesen
`(github link) <https://github.com/rygorous/ryg_rans>`_.
.. currentmodule:: compressai.ans
RansEncoder
-----------
.. autoclass:: RansEncoder
RansDecoder
-----------
.. autoclass:: RansDecoder
Command line usage
==================
.. include:: cli_usage.inc
compressai
==========
.. automodule:: compressai
:members:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment