Commit 9e459ea3 authored by jerrrrry

Initial commit
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn as nn
from compressai.models.priors import (
SCALES_LEVELS,
SCALES_MAX,
SCALES_MIN,
CompressionModel,
FactorizedPrior,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
get_scale_table,
)
from compressai.models.utils import (
_update_registered_buffer,
find_named_module,
update_registered_buffers,
)
class TestCompressionModel:
def test_parameters(self):
model = CompressionModel(32)
assert len(list(model.parameters())) == 15
with pytest.raises(NotImplementedError):
model(torch.rand(1))
def test_init(self):
class Model(CompressionModel):
def __init__(self):
super().__init__(3)
self.conv = nn.Conv2d(3, 3, 3)
self.deconv = nn.ConvTranspose2d(3, 3, 3)
self.original_conv = self.conv.weight
self.original_deconv = self.deconv.weight
self._initialize_weights()
model = Model()
nn.init.kaiming_normal_(model.original_conv)
nn.init.kaiming_normal_(model.original_deconv)
assert torch.allclose(model.original_conv, model.conv.weight)
assert torch.allclose(model.original_deconv, model.deconv.weight)
assert model.conv.bias.abs().sum() == 0
assert model.deconv.bias.abs().sum() == 0
class TestModels:
def test_factorized_prior(self):
model = FactorizedPrior(128, 192)
x = torch.rand(1, 3, 64, 64)
out = model(x)
assert "x_hat" in out
assert "likelihoods" in out
assert "y" in out["likelihoods"]
assert out["x_hat"].shape == x.shape
y_likelihoods_shape = out["likelihoods"]["y"].shape
assert y_likelihoods_shape[0] == x.shape[0]
assert y_likelihoods_shape[1] == 192
assert y_likelihoods_shape[2] == x.shape[2] / 2 ** 4
assert y_likelihoods_shape[3] == x.shape[3] / 2 ** 4
def test_scale_hyperprior(self, tmpdir):
model = ScaleHyperprior(128, 192)
x = torch.rand(1, 3, 64, 64)
out = model(x)
assert "x_hat" in out
assert "likelihoods" in out
assert "y" in out["likelihoods"]
assert "z" in out["likelihoods"]
assert out["x_hat"].shape == x.shape
y_likelihoods_shape = out["likelihoods"]["y"].shape
assert y_likelihoods_shape[0] == x.shape[0]
assert y_likelihoods_shape[1] == 192
assert y_likelihoods_shape[2] == x.shape[2] / 2 ** 4
assert y_likelihoods_shape[3] == x.shape[3] / 2 ** 4
z_likelihoods_shape = out["likelihoods"]["z"].shape
assert z_likelihoods_shape[0] == x.shape[0]
assert z_likelihoods_shape[1] == 128
assert z_likelihoods_shape[2] == x.shape[2] / 2 ** 6
assert z_likelihoods_shape[3] == x.shape[3] / 2 ** 6
for sz in [(128, 128), (128, 192), (192, 128)]:
model = ScaleHyperprior(*sz)
filepath = tmpdir.join("model.pth.rar").strpath
torch.save(model.state_dict(), filepath)
loaded = ScaleHyperprior.from_state_dict(torch.load(filepath))
assert model.N == loaded.N and model.M == loaded.M
def test_mean_scale_hyperprior(self):
model = MeanScaleHyperprior(128, 192)
x = torch.rand(1, 3, 64, 64)
out = model(x)
assert "x_hat" in out
assert "likelihoods" in out
assert "y" in out["likelihoods"]
assert "z" in out["likelihoods"]
assert out["x_hat"].shape == x.shape
y_likelihoods_shape = out["likelihoods"]["y"].shape
assert y_likelihoods_shape[0] == x.shape[0]
assert y_likelihoods_shape[1] == 192
assert y_likelihoods_shape[2] == x.shape[2] / 2 ** 4
assert y_likelihoods_shape[3] == x.shape[3] / 2 ** 4
z_likelihoods_shape = out["likelihoods"]["z"].shape
assert z_likelihoods_shape[0] == x.shape[0]
assert z_likelihoods_shape[1] == 128
assert z_likelihoods_shape[2] == x.shape[2] / 2 ** 6
assert z_likelihoods_shape[3] == x.shape[3] / 2 ** 6
def test_jarhp(self, tmpdir):
model = JointAutoregressiveHierarchicalPriors(128, 192)
x = torch.rand(1, 3, 64, 64)
out = model(x)
assert "x_hat" in out
assert "likelihoods" in out
assert "y" in out["likelihoods"]
assert "z" in out["likelihoods"]
assert out["x_hat"].shape == x.shape
y_likelihoods_shape = out["likelihoods"]["y"].shape
assert y_likelihoods_shape[0] == x.shape[0]
assert y_likelihoods_shape[1] == 192
assert y_likelihoods_shape[2] == x.shape[2] / 2 ** 4
assert y_likelihoods_shape[3] == x.shape[3] / 2 ** 4
z_likelihoods_shape = out["likelihoods"]["z"].shape
assert z_likelihoods_shape[0] == x.shape[0]
assert z_likelihoods_shape[1] == 128
assert z_likelihoods_shape[2] == x.shape[2] / 2 ** 6
assert z_likelihoods_shape[3] == x.shape[3] / 2 ** 6
for sz in [(128, 128), (128, 192), (192, 128)]:
model = JointAutoregressiveHierarchicalPriors(*sz)
filepath = tmpdir.join("model.pth.rar").strpath
torch.save(model.state_dict(), filepath)
loaded = JointAutoregressiveHierarchicalPriors.from_state_dict(
torch.load(filepath)
)
assert model.N == loaded.N and model.M == loaded.M
def test_scale_table_default():
table = get_scale_table()
assert SCALES_MIN == 0.11
assert SCALES_MAX == 256
assert SCALES_LEVELS == 64
assert table[0] == SCALES_MIN
assert table[-1] == SCALES_MAX
assert len(table.size()) == 1
assert table.size(0) == SCALES_LEVELS
def test_scale_table_custom():
table = get_scale_table(0.02, 1337, 32)
assert table[0].item() == pytest.approx(0.02)
assert table[-1].item() == pytest.approx(1337)
assert len(table.size()) == 1
assert table.size(0) == 32
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 3, 1)
self.conv2 = nn.Conv2d(3, 3, 1)
def test_find_named_module():
assert find_named_module(Foo(), "conv3") is None
foo = Foo()
found = find_named_module(foo, "conv1")
assert found == foo.conv1
def test_update_registered_buffers():
foo = Foo()
with pytest.raises(ValueError):
update_registered_buffers(foo, "conv1", ["qweight"], {})
def test_update_registered_buffer():
foo = Foo()
# non-registered buffer
state_dict = foo.state_dict()
state_dict["conv1.wweight"] = torch.rand(3)
with pytest.raises(RuntimeError):
_update_registered_buffer(
foo.conv1, "wweight", "conv1.wweight", state_dict, policy="resize"
)
with pytest.raises(RuntimeError):
_update_registered_buffer(
foo.conv1, "wweight", "conv1.wweight", state_dict, policy="resize_if_empty"
)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from compressai.ops import LowerBound, NonNegativeParametrizer, ste_round
class TestSTERound:
def test_ste_round_ok(self):
x = torch.rand(16)
assert (ste_round(x) == torch.round(x)).all()
def test_ste_round_grads(self):
x = torch.rand(24, requires_grad=True)
y = ste_round(x)
y.backward(x)
assert x.grad is not None
assert (x.grad == x).all()
class TestLowerBound:
def test_lower_bound_ok(self):
x = torch.rand(16)
bound = torch.rand(1)
lower_bound = LowerBound(bound)
assert (lower_bound(x) == torch.max(x, bound)).all()
def test_lower_bound_script(self):
x = torch.rand(16)
bound = torch.rand(1)
lower_bound = LowerBound(bound)
scripted = torch.jit.script(lower_bound)
assert (scripted(x) == torch.max(x, bound)).all()
def test_lower_bound_grads(self):
x = torch.rand(16, requires_grad=True)
bound = torch.rand(1)
lower_bound = LowerBound(bound)
y = lower_bound(x)
y.backward(x)
assert x.grad is not None
assert (x.grad == ((x >= bound) * x)).all()
class TestNonNegativeParametrizer:
def test_non_negative(self):
parametrizer = NonNegativeParametrizer()
x = torch.rand(1, 8, 8, 8) * 2 - 1 # [0, 1] -> [-1, 1]
x_reparam = parametrizer(x)
assert x_reparam.shape == x.shape
assert x_reparam.min() >= 0
def test_non_negative_init(self):
parametrizer = NonNegativeParametrizer()
x = torch.rand(1, 8, 8, 8) * 2 - 1
x_init = parametrizer.init(x)
assert x_init.shape == x.shape
assert torch.allclose(x_init, torch.sqrt(torch.max(x, x - x)), atol=2 ** -18)
def test_non_negative_min(self):
for _ in range(10):
minimum = torch.rand(1)
parametrizer = NonNegativeParametrizer(minimum.item())
x = torch.rand(1, 8, 8, 8) * 2 - 1
x_reparam = parametrizer(x)
assert x_reparam.shape == x.shape
assert torch.allclose(x_reparam.min(), minimum)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressai.layers import GDN, GDN1, MaskedConv2d
class TestScripting:
def test_gdn(self):
g = GDN(128)
x = torch.rand(1, 128, 1, 1)
y0 = g(x)
m = torch.jit.script(g)
y1 = m(x)
assert torch.allclose(y0, y1)
def test_gdn1(self):
g = GDN1(128)
x = torch.rand(1, 128, 1, 1)
y0 = g(x)
m = torch.jit.script(g)
y1 = m(x)
assert torch.allclose(y0, y1)
def test_masked_conv_A(self):
conv = MaskedConv2d(3, 3, 3, padding=1)
with pytest.raises(RuntimeError):
m = torch.jit.script(conv)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import io
import re
from contextlib import redirect_stdout
from pathlib import Path
import pytest
@pytest.mark.slow
def test_train_example():
cwd = Path(__file__).resolve().parent
rootdir = cwd.parent
spec = importlib.util.spec_from_file_location(
"examples.train", rootdir / "examples/train.py"
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
argv = [
"-d",
str(rootdir / "tests/assets/fakedata/imagefolder"),
"-e",
"10",
"--batch-size",
"1",
"--patch-size",
"48",
"128",
"--seed",
"3.14",
]
f = io.StringIO()
with redirect_stdout(f):
module.main(argv)
log = f.getvalue()
logpath = cwd / "expected" / "train_log_3.14.txt"
if not logpath.is_file():
with logpath.open("w") as f:
f.write(log)
with logpath.open("r") as f:
expected = f.read()
test_values = [m[0] for m in re.findall(r"(?P<number>([0-9]*[.])?[0-9]+)", log)]
expected_values = [
m[0] for m in re.findall(r"(?P<number>([0-9]*[.])?[0-9]+)", expected)
]
assert len(test_values) == len(expected_values)
for a, b in zip(test_values, expected_values):
try:
assert int(a) == int(b)
except ValueError:
assert float(a) == pytest.approx(float(b))
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressai.transforms import RGB2YCbCr, YCbCr2RGB, YUV420To444, YUV444To420
from compressai.transforms.functional import (
rgb2ycbcr,
ycbcr2rgb,
yuv_420_to_444,
yuv_444_to_420,
)
@pytest.mark.parametrize("func", (rgb2ycbcr, ycbcr2rgb))
def test_invalid_input(func):
with pytest.raises(ValueError):
func(torch.rand(1, 3).numpy())
with pytest.raises(ValueError):
func(torch.rand(1, 3))
with pytest.raises(ValueError):
func(torch.rand(1, 4, 4, 4))
with pytest.raises(ValueError):
func(torch.rand(1, 3, 4, 4).int())
@pytest.mark.parametrize("func", (rgb2ycbcr, ycbcr2rgb))
def test_ok(func):
x = torch.rand(1, 3, 32, 32)
rv = func(x)
assert rv.size() == x.size()
assert rv.type() == x.type()
x = torch.rand(3, 64, 64)
rv = func(x)
assert rv.size() == x.size()
assert rv.type() == x.type()
def test_round_trip():
x = torch.rand(1, 3, 32, 32)
rv = ycbcr2rgb(rgb2ycbcr(x))
assert torch.allclose(x, rv, atol=1e-5)
rv = rgb2ycbcr(ycbcr2rgb(x))
assert torch.allclose(x, rv, atol=1e-5)
def test_444_to_420():
x = torch.rand(1, 3, 32, 32)
y, u, v = yuv_444_to_420(x)
assert u.size(0) == v.size(0) == y.size(0) == x.size(0)
assert u.size(1) == v.size(1) == y.size(1) == 1
assert y.size(2) == x.size(2) and y.size(3) == x.size(3)
assert u.size(2) == v.size(2) == (y.size(2) // 2)
assert u.size(3) == v.size(3) == (y.size(2) // 2)
assert (x[:, [0]] == y).all()
with pytest.raises(ValueError):
y, u, v = yuv_444_to_420(x, mode="toto")
y, u, v = yuv_444_to_420(x.chunk(3, 1))
def test_420_to_444():
y = torch.rand(1, 1, 32, 32)
u = torch.rand(1, 1, 16, 16)
v = torch.rand(1, 1, 16, 16)
with pytest.raises(ValueError):
yuv_420_to_444((y, u))
with pytest.raises(ValueError):
yuv_420_to_444((y, u, v), mode="bilateral")
rv = yuv_420_to_444((y, u, v))
assert isinstance(rv, torch.Tensor)
assert (rv[:, [0]] == y).all()
rv = yuv_420_to_444((y, u, v), return_tuple=True)
assert all(isinstance(c, torch.Tensor) for c in rv)
assert (rv[0] == y).all()
assert rv[0].size() == rv[1].size() == rv[2].size()
def test_transforms():
x = torch.rand(1, 3, 32, 32)
rv = RGB2YCbCr()(x)
assert rv.size() == x.size()
repr(RGB2YCbCr())
rv = YCbCr2RGB()(x)
assert rv.size() == x.size()
repr(YCbCr2RGB())
rv = YUV444To420()(x)
assert len(rv) == 3
repr(YUV444To420())
rv = YUV420To444()(rv)
assert rv.size() == x.size()
repr(YUV420To444())
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
from contextlib import redirect_stderr, redirect_stdout
from pathlib import Path
import pytest
import torch
from compressai.models.priors import FactorizedPrior
update_model_module = importlib.import_module("compressai.utils.update_model.__main__")
def run_update_model(*args):
fout, ferr = io.StringIO(), io.StringIO()
with redirect_stderr(ferr):
with redirect_stdout(fout):
update_model_module.main(map(str, args))
return fout.getvalue(), ferr.getvalue()
def test_missing_filepath():
with pytest.raises(SystemExit):
run_update_model()
def test_invalid_filepath(tmpdir):
# directory
with pytest.raises(RuntimeError):
run_update_model(tmpdir)
# empty/invalid file
p = tmpdir.join("hello.txt")
p.write("")
with pytest.raises((EOFError, TypeError)):
run_update_model(p)
def test_valid(tmpdir):
p = tmpdir.join("model.pth.tar").strpath
net = FactorizedPrior(32, 64)
torch.save(net.state_dict(), p)
stdout, stderr = run_update_model(
p, "--architecture", "factorized-prior", "--dir", tmpdir
)
assert len(stdout) == 0
assert len(stderr) == 0
files = list(Path(tmpdir).glob("*.pth.tar"))
assert len(files) == 1
cdf_len = net.state_dict()["entropy_bottleneck._cdf_length"]
new_cdf_len = torch.load(files[0])["entropy_bottleneck._cdf_length"]
assert cdf_len.size(0) != new_cdf_len.size(0)
def test_valid_name(tmpdir):
p = tmpdir.join("model.pth.tar").strpath
net = FactorizedPrior(32, 64)
torch.save(net.state_dict(), p)
stdout, stderr = run_update_model(
p, "--architecture", "factorized-prior", "--dir", tmpdir, "--name", "yolo"
)
assert len(stdout) == 0
assert len(stderr) == 0
files = sorted(list(Path(tmpdir).glob("*.pth.tar")))
assert len(files) == 2
assert files[0].name == "model.pth.tar"
assert files[1].name[:5] == "yolo-"
def test_valid_no_update(tmpdir):
p = tmpdir.join("model.pth.tar").strpath
net = FactorizedPrior(32, 64)
torch.save(net.state_dict(), p)
stdout, stderr = run_update_model(
p, "--architecture", "factorized-prior", "--dir", tmpdir, "--no-update"
)
assert len(stdout) == 0
assert len(stderr) == 0
files = list(Path(tmpdir).glob("*.pth.tar"))
assert len(files) == 1
cdf_len = net.state_dict()["entropy_bottleneck._cdf_length"]
new_cdf_len = torch.load(files[0])["entropy_bottleneck._cdf_length"]
assert cdf_len.size(0) == new_cdf_len.size(0)
def test_invalid_model(tmpdir):
p = tmpdir.join("model.pth.tar").strpath
net = FactorizedPrior(32, 64)
torch.save(net.state_dict(), p)
with pytest.raises(SystemExit):
run_update_model(p, "--architecture", "foobar")
def test_load(tmpdir):
p = tmpdir.join("model.pth.tar").strpath
net = FactorizedPrior(32, 64)
for k in ["network", "state_dict"]:
torch.save({k: net.state_dict()}, p)
stdout, stderr = run_update_model(
p, "--architecture", "factorized-prior", "--dir", tmpdir
)
assert len(stdout) == 0
assert len(stderr) == 0
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressai.models.waseda import Cheng2020Anchor
from compressai.zoo import cheng2020_anchor
def test_cheng2020_anchor():
net = cheng2020_anchor(quality=1, pretrained=True)
Cheng2020Anchor.from_state_dict(net.state_dict())
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from compressai.models import (
Cheng2020Anchor,
Cheng2020Attention,
FactorizedPrior,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
)
from compressai.zoo import (
bmshj2018_factorized,
bmshj2018_hyperprior,
cheng2020_anchor,
cheng2020_attn,
mbt2018,
mbt2018_mean,
)
from compressai.zoo.image import _load_model
class TestLoadModel:
def test_invalid(self):
with pytest.raises(ValueError):
_load_model("yolo", "mse", 1)
with pytest.raises(ValueError):
_load_model("mbt2018", "mse", 0)
class TestBmshj2018Factorized:
def test_params(self):
for i in range(1, 6):
net = bmshj2018_factorized(i, metric="mse")
assert isinstance(net, FactorizedPrior)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(6, 9):
net = bmshj2018_factorized(i, metric="mse")
assert isinstance(net, FactorizedPrior)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
def test_invalid_params(self):
with pytest.raises(ValueError):
bmshj2018_factorized(-1)
with pytest.raises(ValueError):
bmshj2018_factorized(10)
with pytest.raises(ValueError):
bmshj2018_factorized(10, metric="ssim")
with pytest.raises(ValueError):
bmshj2018_factorized(1, metric="ssim")
@pytest.mark.slow
@pytest.mark.pretrained
@pytest.mark.parametrize(
"metric", [("mse",), ("ms-ssim",)]
) # bypass weird pytest bug
def test_pretrained(self, metric):
metric = metric[0]
for i in range(1, 6):
net = bmshj2018_factorized(i, metric=metric, pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(6, 9):
net = bmshj2018_factorized(i, metric=metric, pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
class TestBmshj2018Hyperprior:
def test_params(self):
for i in range(1, 6):
net = bmshj2018_hyperprior(i, metric="mse")
assert isinstance(net, ScaleHyperprior)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(6, 9):
net = bmshj2018_hyperprior(i, metric="mse")
assert isinstance(net, ScaleHyperprior)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
def test_invalid_params(self):
with pytest.raises(ValueError):
bmshj2018_hyperprior(-1)
with pytest.raises(ValueError):
bmshj2018_hyperprior(10)
with pytest.raises(ValueError):
bmshj2018_hyperprior(10, metric="ssim")
with pytest.raises(ValueError):
bmshj2018_hyperprior(1, metric="ssim")
@pytest.mark.slow
@pytest.mark.pretrained
def test_pretrained(self):
# test we can load the correct models from the urls
for i in range(1, 6):
net = bmshj2018_factorized(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(6, 9):
net = bmshj2018_factorized(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
class TestMbt2018Mean:
def test_parameters(self):
for i in range(1, 5):
net = mbt2018_mean(i, metric="mse")
assert isinstance(net, MeanScaleHyperprior)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(5, 9):
net = mbt2018_mean(i, metric="mse")
assert isinstance(net, MeanScaleHyperprior)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
def test_invalid_params(self):
with pytest.raises(ValueError):
mbt2018_mean(-1)
with pytest.raises(ValueError):
mbt2018_mean(10)
with pytest.raises(ValueError):
mbt2018_mean(10, metric="ssim")
with pytest.raises(ValueError):
mbt2018_mean(1, metric="ssim")
@pytest.mark.slow
@pytest.mark.pretrained
def test_pretrained(self):
# test we can load the correct models from the urls
for i in range(1, 5):
net = mbt2018_mean(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 128
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(5, 9):
net = mbt2018_mean(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
class TestMbt2018:
def test_ok(self):
for i in range(1, 5):
net = mbt2018(i, metric="mse")
assert isinstance(net, JointAutoregressiveHierarchicalPriors)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(5, 9):
net = mbt2018(i, metric="mse")
assert isinstance(net, JointAutoregressiveHierarchicalPriors)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
def test_invalid_params(self):
with pytest.raises(ValueError):
mbt2018(-1)
with pytest.raises(ValueError):
mbt2018(10)
with pytest.raises(ValueError):
mbt2018(10, metric="ssim")
with pytest.raises(ValueError):
mbt2018(1, metric="ssim")
@pytest.mark.slow
@pytest.mark.pretrained
def test_pretrained(self):
# test we can load the correct models from the urls
for i in range(1, 5):
net = mbt2018(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 192
for i in range(5, 9):
net = mbt2018(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.weight"].size(0) == 192
assert net.state_dict()["g_a.6.weight"].size(0) == 320
class TestCheng2020:
@pytest.mark.parametrize(
"func,cls",
(
(cheng2020_anchor, Cheng2020Anchor),
(cheng2020_attn, Cheng2020Attention),
),
)
def test_anchor_ok(self, func, cls):
for i in range(1, 4):
net = func(i, metric="mse")
assert isinstance(net, cls)
assert net.state_dict()["g_a.0.conv1.weight"].size(0) == 128
for i in range(4, 7):
net = func(i, metric="mse")
assert isinstance(net, cls)
assert net.state_dict()["g_a.0.conv1.weight"].size(0) == 192
@pytest.mark.slow
@pytest.mark.pretrained
def test_pretrained(self):
for i in range(1, 4):
net = cheng2020_anchor(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.conv1.weight"].size(0) == 128
for i in range(4, 7):
net = cheng2020_anchor(i, metric="mse", pretrained=True)
assert net.state_dict()["g_a.0.conv1.weight"].size(0) == 192
To the extent possible under law, Fabian Giesen has waived all
copyright and related or neighboring rights to ryg_rans, as
per the terms of the CC0 license:
https://creativecommons.org/publicdomain/zero/1.0
This work is published from the United States.
This is a public-domain implementation of several rANS variants. rANS is an
entropy coder from the ANS family, as described in Jarek Duda's paper
"Asymmetric numeral systems" (http://arxiv.org/abs/1311.2540).
- "rans_byte.h" has a byte-aligned rANS encoder/decoder and some comments on
how to use it. This implementation should work on all 32-bit architectures.
"main.cpp" is an example program that shows how to use it.
- "rans64.h" is a 64-bit version that emits entire 32-bit words at a time. It
is (usually) a good deal faster than rans_byte on 64-bit architectures, and
also makes for a very precise arithmetic coder (i.e. it gets quite close
to entropy). The trade-off is that this version will be slower on 32-bit
machines, and the output bitstream is not endian-neutral. "main64.cpp" is
the corresponding example.
- "rans_word_sse41.h" has a SIMD decoder (SSE 4.1 to be precise) that does IO
in units of 16-bit words. It has less precision than either rans_byte or
rans64 (meaning that it doesn't get as close to entropy) and requires
at least 4 independent streams of data to be useful; however, it is also a
good deal faster. "main_simd.cpp" shows how to use it.
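For orientation, here is a minimal sketch of the encode-backwards / decode-forwards
flow with the byte-aligned rans_byte.h API described above; the two-symbol alphabet
and scale_bits = 4 are made up for the sketch, and "main.cpp" remains the full example.

#include "rans_byte.h"

static void toy_roundtrip(void)
{
    uint32_t starts[2] = { 0, 12 }, freqs[2] = { 12, 4 };  // freqs sum to 1 << 4
    uint8_t msg[3] = { 0, 1, 0 };
    uint8_t buf[64], *ptr = buf + sizeof(buf);              // encoder writes backwards

    RansState enc;
    RansEncInit(&enc);
    for (int i = 2; i >= 0; i--)                             // encode in *reverse* symbol order
        RansEncPut(&enc, &ptr, starts[msg[i]], freqs[msg[i]], 4);
    RansEncFlush(&enc, &ptr);                                // ptr now points at the stream start

    RansState dec;
    RansDecInit(&dec, &ptr);
    for (int i = 0; i < 3; i++) {                            // decode forwards
        uint32_t cf = RansDecGet(&dec, 4);
        uint32_t s = (cf < 12) ? 0 : 1;                      // map cumulative freq -> symbol; equals msg[i]
        RansDecAdvance(&dec, &ptr, starts[s], freqs[s], 4);
    }
}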
See my blog http://fgiesen.wordpress.com/ for some notes on the design.
I've also written a paper on interleaving output streams from multiple entropy
coders:
http://arxiv.org/abs/1402.3392
this documents the underlying design for "rans_word_sse41", and also shows how
the same approach generalizes to e.g. GPU implementations, provided there are
enough independent contexts coded at the same time to fill up a warp/wavefront
or whatever your favorite GPU's terminology for its native SIMD width is.
Finally, there's also "main_alias.cpp", which shows how to combine rANS with
the alias method to get O(1) symbol lookup with table size proportional to the
number of symbols. I presented an overview of the underlying idea here:
http://fgiesen.wordpress.com/2014/02/18/rans-with-static-probability-distributions/
Results on my machine (Sandy Bridge i7-2600K) with rans_byte in 64-bit mode:
----
rANS encode:
12896496 clocks, 16.8 clocks/symbol (192.8MiB/s)
12486912 clocks, 16.2 clocks/symbol (199.2MiB/s)
12511975 clocks, 16.3 clocks/symbol (198.8MiB/s)
12660765 clocks, 16.5 clocks/symbol (196.4MiB/s)
12550285 clocks, 16.3 clocks/symbol (198.2MiB/s)
rANS: 435113 bytes
17023550 clocks, 22.1 clocks/symbol (146.1MiB/s)
18081509 clocks, 23.5 clocks/symbol (137.5MiB/s)
16901632 clocks, 22.0 clocks/symbol (147.1MiB/s)
17166188 clocks, 22.3 clocks/symbol (144.9MiB/s)
17235859 clocks, 22.4 clocks/symbol (144.3MiB/s)
decode ok!
interleaved rANS encode:
9618004 clocks, 12.5 clocks/symbol (258.6MiB/s)
9488277 clocks, 12.3 clocks/symbol (262.1MiB/s)
9460194 clocks, 12.3 clocks/symbol (262.9MiB/s)
9582025 clocks, 12.5 clocks/symbol (259.5MiB/s)
9332017 clocks, 12.1 clocks/symbol (266.5MiB/s)
interleaved rANS: 435117 bytes
10687601 clocks, 13.9 clocks/symbol (232.7MB/s)
10637918 clocks, 13.8 clocks/symbol (233.8MB/s)
10909652 clocks, 14.2 clocks/symbol (227.9MB/s)
10947637 clocks, 14.2 clocks/symbol (227.2MB/s)
10529464 clocks, 13.7 clocks/symbol (236.2MB/s)
decode ok!
----
And here's rans64 in 64-bit mode:
----
rANS encode:
10256075 clocks, 13.3 clocks/symbol (242.3MiB/s)
10620132 clocks, 13.8 clocks/symbol (234.1MiB/s)
10043080 clocks, 13.1 clocks/symbol (247.6MiB/s)
9878205 clocks, 12.8 clocks/symbol (251.8MiB/s)
10122645 clocks, 13.2 clocks/symbol (245.7MiB/s)
rANS: 435116 bytes
14244155 clocks, 18.5 clocks/symbol (174.6MiB/s)
15072524 clocks, 19.6 clocks/symbol (165.0MiB/s)
14787604 clocks, 19.2 clocks/symbol (168.2MiB/s)
14736556 clocks, 19.2 clocks/symbol (168.8MiB/s)
14686129 clocks, 19.1 clocks/symbol (169.3MiB/s)
decode ok!
interleaved rANS encode:
7691159 clocks, 10.0 clocks/symbol (323.3MiB/s)
7182692 clocks, 9.3 clocks/symbol (346.2MiB/s)
7060804 clocks, 9.2 clocks/symbol (352.2MiB/s)
6949201 clocks, 9.0 clocks/symbol (357.9MiB/s)
6876415 clocks, 8.9 clocks/symbol (361.6MiB/s)
interleaved rANS: 435120 bytes
8133574 clocks, 10.6 clocks/symbol (305.7MB/s)
8631618 clocks, 11.2 clocks/symbol (288.1MB/s)
8643790 clocks, 11.2 clocks/symbol (287.7MB/s)
8449364 clocks, 11.0 clocks/symbol (294.3MB/s)
8331444 clocks, 10.8 clocks/symbol (298.5MB/s)
decode ok!
----
Finally, here's the rans_word_sse41 decoder on an 8-way interleaved stream:
----
SIMD rANS: 435626 bytes
4597641 clocks, 6.0 clocks/symbol (540.8MB/s)
4514356 clocks, 5.9 clocks/symbol (550.8MB/s)
4780918 clocks, 6.2 clocks/symbol (520.1MB/s)
4532913 clocks, 5.9 clocks/symbol (548.5MB/s)
4554527 clocks, 5.9 clocks/symbol (545.9MB/s)
decode ok!
----
There's also an experimental 16-way interleaved AVX2 version that hits
faster rates still, developed by my colleague Won Chun; I will post it
soon.
Note that this is running "book1" which is a relatively short test, and
the measurement setup is not great, so take the results with a grain
of salt.
-Fabian "ryg" Giesen, Feb 2014.
// 64-bit rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
//
// This uses 64-bit states (63-bit actually) which allows renormalizing
// by writing out a whole 32 bits at a time (b=2^32) while still
// retaining good precision and allowing for high probability resolution.
//
// The only caveat is that this version requires 64-bit arithmetic; in
// particular, the encoder approximation in the bottom half requires a
// fast way to obtain the top 64 bits of an unsigned 64*64 bit product.
//
// In short, as written, this code works on 64-bit targets only!
#ifndef RANS64_HEADER
#define RANS64_HEADER
#include <stdint.h>
#ifdef assert
#define Rans64Assert assert
#else
#define Rans64Assert(x)
#endif
// --------------------------------------------------------------------------
// This code needs support for 64-bit long multiplies with 128-bit result
// (or more precisely, the top 64 bits of a 128-bit result). This is not
// really portable functionality, so we need some compiler-specific hacks
// here.
#if defined(_MSC_VER)
#include <intrin.h>
static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
{
return __umulh(a, b);
}
#elif defined(__GNUC__)
static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
{
return (uint64_t) (((unsigned __int128)a * b) >> 64);
}
#else
#error Unknown/unsupported compiler!
#endif
// --------------------------------------------------------------------------
// L ('l' in the paper) is the lower bound of our normalization interval.
// Between this and our 32-bit-aligned emission, we use 63 (not 64!) bits.
// This is done intentionally because exact reciprocals for 63-bit uints
// fit in 64-bit uints: this permits some optimizations during encoding.
#define RANS64_L (1ull << 31) // lower bound of our normalization interval
// State for a rANS encoder. Yep, that's all there is to it.
typedef uint64_t Rans64State;
// Initialize a rANS encoder.
static inline void Rans64EncInit(Rans64State* r)
{
*r = RANS64_L;
}
// Encodes a single symbol with range start "start" and frequency "freq".
// All frequencies are assumed to sum to "1 << scale_bits", and the
// resulting bytes get written to ptr (which is updated).
//
// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
// beginning to end! Likewise, the output bytestream is written *backwards*:
// ptr starts pointing at the end of the output buffer and keeps decrementing.
static inline void Rans64EncPut(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
Rans64Assert(freq != 0);
// renormalize (never needs to loop)
uint64_t x = *r;
uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * freq; // this turns into a shift.
if (x >= x_max) {
*pptr -= 1;
**pptr = (uint32_t) x;
x >>= 32;
Rans64Assert(x < x_max);
}
// x = C(s,x)
*r = ((x / freq) << scale_bits) + (x % freq) + start;
}
// Flushes the rANS encoder.
static inline void Rans64EncFlush(Rans64State* r, uint32_t** pptr)
{
uint64_t x = *r;
*pptr -= 2;
(*pptr)[0] = (uint32_t) (x >> 0);
(*pptr)[1] = (uint32_t) (x >> 32);
}
// Initializes a rANS decoder.
// Unlike the encoder, the decoder works forwards as you'd expect.
static inline void Rans64DecInit(Rans64State* r, uint32_t** pptr)
{
uint64_t x;
x = (uint64_t) ((*pptr)[0]) << 0;
x |= (uint64_t) ((*pptr)[1]) << 32;
*pptr += 2;
*r = x;
}
// Returns the current cumulative frequency (map it to a symbol yourself!)
static inline uint32_t Rans64DecGet(Rans64State* r, uint32_t scale_bits)
{
return *r & ((1u << scale_bits) - 1);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
// and the required bytes get read from ptr (which is updated).
static inline void Rans64DecAdvance(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
uint64_t mask = (1ull << scale_bits) - 1;
// s, x = D(x)
uint64_t x = *r;
x = freq * (x >> scale_bits) + (x & mask) - start;
// renormalize
if (x < RANS64_L) {
x = (x << 32) | **pptr;
*pptr += 1;
Rans64Assert(x >= RANS64_L);
}
*r = x;
}
// --------------------------------------------------------------------------
// That's all you need for a full encoder; below here are some utility
// functions with extra convenience or optimizations.
// Encoder symbol description
// This (admittedly odd) selection of parameters was chosen to make
// RansEncPutSymbol as cheap as possible.
typedef struct {
uint64_t rcp_freq; // Fixed-point reciprocal frequency
uint32_t freq; // Symbol frequency
uint32_t bias; // Bias
uint32_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq
uint32_t rcp_shift; // Reciprocal shift
} Rans64EncSymbol;
// Decoder symbols are straightforward.
typedef struct {
uint32_t start; // Start of range.
uint32_t freq; // Symbol frequency.
} Rans64DecSymbol;
// Initializes an encoder symbol to start "start" and frequency "freq"
static inline void Rans64EncSymbolInit(Rans64EncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
Rans64Assert(scale_bits <= 31);
Rans64Assert(start <= (1u << scale_bits));
Rans64Assert(freq <= (1u << scale_bits) - start);
// Say M := 1 << scale_bits.
//
// The original encoder does:
// x_new = (x/freq)*M + start + (x%freq)
//
// The fast encoder does (schematically):
// q = mul_hi(x, rcp_freq) >> rcp_shift (division)
// r = x - q*freq (remainder)
// x_new = q*M + bias + r (new x)
// plugging in r into x_new yields:
// x_new = bias + x + q*(M - freq)
// =: bias + x + q*cmpl_freq (*)
//
// and we can just precompute cmpl_freq. Now we just need to
// set up our parameters such that the original encoder and
// the fast encoder agree.
s->freq = freq;
s->cmpl_freq = ((1 << scale_bits) - freq);
if (freq < 2) {
// freq=0 symbols are never valid to encode, so it doesn't matter what
// we set our values to.
//
// freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
// our fixed-point reciprocal approximation can only multiply by values
// smaller than 1.
//
// So we use the "next best thing": rcp_freq=~0, rcp_shift=0.
// This gives:
// q = mul_hi(x, rcp_freq) >> rcp_shift
// = mul_hi(x, (1<<64) - 1)) >> 0
// = floor(x - x/(2^64))
// = x - 1 if 1 <= x < 2^64
// and we know that x>0 (x=0 is never in a valid normalization interval).
//
// So we now need to choose the other parameters such that
// x_new = x*M + start
// plug it in:
// x*M + start (desired result)
// = bias + x + q*cmpl_freq (*)
// = bias + x + (x - 1)*(M - 1) (plug in q=x-1, cmpl_freq)
// = bias + 1 + (x - 1)*M
// = x*M + (bias + 1 - M)
//
// so we have start = bias + 1 - M, or equivalently
// bias = start + M - 1.
s->rcp_freq = ~0ull;
s->rcp_shift = 0;
s->bias = start + (1 << scale_bits) - 1;
} else {
// Alverson, "Integer Division using reciprocals"
// shift=ceil(log2(freq))
uint32_t shift = 0;
uint64_t x0, x1, t0, t1;
while (freq > (1u << shift))
shift++;
// long divide ((uint128) (1 << (shift + 63)) + freq-1) / freq
// by splitting it into two 64:64 bit divides (this works because
// the dividend has a simple form.)
x0 = freq - 1;
x1 = 1ull << (shift + 31);
t1 = x1 / freq;
x0 += (x1 % freq) << 32;
t0 = x0 / freq;
s->rcp_freq = t0 + (t1 << 32);
s->rcp_shift = shift - 1;
// With these values, 'q' is the correct quotient, so we
// have bias=start.
s->bias = start;
}
}
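// Worked example (added for illustration, not part of the original header):
// for freq = 3 the loop above gives shift = 2, so the target reciprocal is
// ((1 << 65) + freq - 1) / freq = ceil(2^65 / 3). The split computes
//   x1 = 2^33,           t1 = x1 / 3 = 2863311530  (remainder 2)
//   x0 = 2 + (2 << 32),  t0 = x0 / 3 = 2863311531
// so rcp_freq = t0 + (t1 << 32) = 12297829382473034411, which is exactly
// ceil(2^65 / 3), and rcp_shift = shift - 1 = 1.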
// Initialize a decoder symbol to start "start" and frequency "freq"
static inline void Rans64DecSymbolInit(Rans64DecSymbol* s, uint32_t start, uint32_t freq)
{
Rans64Assert(start <= (1 << 31));
Rans64Assert(freq <= (1 << 31) - start);
s->start = start;
s->freq = freq;
}
// Encodes a given symbol. This is faster than straight RansEnc since we can do
// multiplications instead of a divide.
//
// See RansEncSymbolInit for a description of how this works.
static inline void Rans64EncPutSymbol(Rans64State* r, uint32_t** pptr, Rans64EncSymbol const* sym, uint32_t scale_bits)
{
Rans64Assert(sym->freq != 0); // can't encode symbol with freq=0
// renormalize
uint64_t x = *r;
uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * sym->freq; // turns into a shift
if (x >= x_max) {
*pptr -= 1;
**pptr = (uint32_t) x;
x >>= 32;
}
// x = C(s,x)
uint64_t q = Rans64MulHi(x, sym->rcp_freq) >> sym->rcp_shift;
*r = x + sym->bias + q * sym->cmpl_freq;
}
// Equivalent to RansDecAdvance that takes a symbol.
static inline void Rans64DecAdvanceSymbol(Rans64State* r, uint32_t** pptr, Rans64DecSymbol const* sym, uint32_t scale_bits)
{
Rans64DecAdvance(r, pptr, sym->start, sym->freq, scale_bits);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
// No renormalization or output happens.
static inline void Rans64DecAdvanceStep(Rans64State* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
uint64_t mask = (1u << scale_bits) - 1;
// s, x = D(x)
uint64_t x = *r;
*r = freq * (x >> scale_bits) + (x & mask) - start;
}
// Equivalent to RansDecAdvanceStep that takes a symbol.
static inline void Rans64DecAdvanceSymbolStep(Rans64State* r, Rans64DecSymbol const* sym, uint32_t scale_bits)
{
Rans64DecAdvanceStep(r, sym->start, sym->freq, scale_bits);
}
// Renormalize.
static inline void Rans64DecRenorm(Rans64State* r, uint32_t** pptr)
{
// renormalize
uint64_t x = *r;
if (x < RANS64_L) {
x = (x << 32) | **pptr;
*pptr += 1;
Rans64Assert(x >= RANS64_L);
}
*r = x;
}
#endif // RANS64_HEADER
// Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
//
// Not intended to be "industrial strength"; just meant to illustrate the general
// idea.
#ifndef RANS_BYTE_HEADER
#define RANS_BYTE_HEADER
#include <stdint.h>
#ifdef assert
#define RansAssert assert
#else
#define RansAssert(x)
#endif
// READ ME FIRST:
//
// This is designed like a typical arithmetic coder API, but there's three
// twists you absolutely should be aware of before you start hacking:
//
// 1. You need to encode data in *reverse* - last symbol first. rANS works
// like a stack: last in, first out.
// 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
// it a pointer to the *end* of your buffer (exclusive), and it will
// slowly move towards the beginning as more bytes are emitted.
// 3. Unlike basically any other entropy coder implementation you might
// have used, you can interleave data from multiple independent rANS
// encoders into the same bytestream without any extra signaling;
// you can also just write some bytes by yourself in the middle if
// you want to. This is in addition to the usual arithmetic encoder
// property of being able to switch models on the fly. Writing raw
// bytes can be useful when you have some data that you know is
// incompressible, and is cheaper than going through the rANS encode
// function. Using multiple rANS coders on the same byte stream wastes
// a few bytes compared to using just one, but execution of two
// independent encoders can happen in parallel on superscalar and
// Out-of-Order CPUs, so this can be *much* faster in tight decoding
// loops.
//
// This is why all the rANS functions take the write pointer as an
// argument instead of just storing it in some context struct.
// --------------------------------------------------------------------------
// L ('l' in the paper) is the lower bound of our normalization interval.
// Between this and our byte-aligned emission, we use 31 (not 32!) bits.
// This is done intentionally because exact reciprocals for 31-bit uints
// fit in 32-bit uints: this permits some optimizations during encoding.
#define RANS_BYTE_L (1u << 23) // lower bound of our normalization interval
// State for a rANS encoder. Yep, that's all there is to it.
typedef uint32_t RansState;
// Initialize a rANS encoder.
static inline void RansEncInit(RansState* r)
{
*r = RANS_BYTE_L;
}
// Renormalize the encoder. Internal function.
static inline RansState RansEncRenorm(RansState x, uint8_t** pptr, uint32_t freq, uint32_t scale_bits)
{
uint32_t x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; // this turns into a shift.
if (x >= x_max) {
uint8_t* ptr = *pptr;
do {
*--ptr = (uint8_t) (x & 0xff);
x >>= 8;
} while (x >= x_max);
*pptr = ptr;
}
return x;
}
// Encodes a single symbol with range start "start" and frequency "freq".
// All frequencies are assumed to sum to "1 << scale_bits", and the
// resulting bytes get written to ptr (which is updated).
//
// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
// beginning to end! Likewise, the output bytestream is written *backwards*:
// ptr starts pointing at the end of the output buffer and keeps decrementing.
static inline void RansEncPut(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
// renormalize
RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
// x = C(s,x)
*r = ((x / freq) << scale_bits) + (x % freq) + start;
}
// Flushes the rANS encoder.
static inline void RansEncFlush(RansState* r, uint8_t** pptr)
{
uint32_t x = *r;
uint8_t* ptr = *pptr;
ptr -= 4;
ptr[0] = (uint8_t) (x >> 0);
ptr[1] = (uint8_t) (x >> 8);
ptr[2] = (uint8_t) (x >> 16);
ptr[3] = (uint8_t) (x >> 24);
*pptr = ptr;
}
// Initializes a rANS decoder.
// Unlike the encoder, the decoder works forwards as you'd expect.
static inline void RansDecInit(RansState* r, uint8_t** pptr)
{
uint32_t x;
uint8_t* ptr = *pptr;
x = ptr[0] << 0;
x |= ptr[1] << 8;
x |= ptr[2] << 16;
x |= ptr[3] << 24;
ptr += 4;
*pptr = ptr;
*r = x;
}
// Returns the current cumulative frequency (map it to a symbol yourself!)
static inline uint32_t RansDecGet(RansState* r, uint32_t scale_bits)
{
return *r & ((1u << scale_bits) - 1);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
// and the required bytes get read from ptr (which is updated).
static inline void RansDecAdvance(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
uint32_t mask = (1u << scale_bits) - 1;
// s, x = D(x)
uint32_t x = *r;
x = freq * (x >> scale_bits) + (x & mask) - start;
// renormalize
if (x < RANS_BYTE_L) {
uint8_t* ptr = *pptr;
do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L);
*pptr = ptr;
}
*r = x;
}
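// Example (added for illustration, not part of the original header): a sketch of
// the two-way interleaving mentioned in "READ ME FIRST" above. Both states share
// one output pointer; symbols are still encoded in reverse, alternating states,
// and decoded forwards with the same alternation. The toy two-symbol alphabet
// with scale_bits = 4 is made up for this sketch.
static inline void RansInterleavedSketch(void)
{
    uint32_t starts[2] = { 0, 12 }, freqs[2] = { 12, 4 };
    uint8_t msg[4] = { 0, 1, 1, 0 };
    uint8_t buf[64], *ptr = buf + sizeof(buf);

    RansState enc0, enc1;
    RansEncInit(&enc0);
    RansEncInit(&enc1);
    for (int i = 3; i >= 0; i--) {                    // reverse order, alternate states
        RansState* r = (i & 1) ? &enc1 : &enc0;
        RansEncPut(r, &ptr, starts[msg[i]], freqs[msg[i]], 4);
    }
    RansEncFlush(&enc1, &ptr);                        // flush in reverse of decode-init order
    RansEncFlush(&enc0, &ptr);

    RansState dec0, dec1;
    RansDecInit(&dec0, &ptr);
    RansDecInit(&dec1, &ptr);
    for (int i = 0; i < 4; i++) {                     // forwards, same alternation
        RansState* r = (i & 1) ? &dec1 : &dec0;
        uint32_t s = (RansDecGet(r, 4) < 12) ? 0 : 1; // equals msg[i]
        RansDecAdvance(r, &ptr, starts[s], freqs[s], 4);
    }
}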
// --------------------------------------------------------------------------
// That's all you need for a full encoder; below here are some utility
// functions with extra convenience or optimizations.
// Encoder symbol description
// This (admittedly odd) selection of parameters was chosen to make
// RansEncPutSymbol as cheap as possible.
typedef struct {
uint32_t x_max; // (Exclusive) upper bound of pre-normalization interval
uint32_t rcp_freq; // Fixed-point reciprocal frequency
uint32_t bias; // Bias
uint16_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq
uint16_t rcp_shift; // Reciprocal shift
} RansEncSymbol;
// Decoder symbols are straightforward.
typedef struct {
uint16_t start; // Start of range.
uint16_t freq; // Symbol frequency.
} RansDecSymbol;
// Initializes an encoder symbol to start "start" and frequency "freq"
static inline void RansEncSymbolInit(RansEncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
RansAssert(scale_bits <= 16);
RansAssert(start <= (1u << scale_bits));
RansAssert(freq <= (1u << scale_bits) - start);
// Say M := 1 << scale_bits.
//
// The original encoder does:
// x_new = (x/freq)*M + start + (x%freq)
//
// The fast encoder does (schematically):
// q = mul_hi(x, rcp_freq) >> rcp_shift (division)
// r = x - q*freq (remainder)
// x_new = q*M + bias + r (new x)
// plugging in r into x_new yields:
// x_new = bias + x + q*(M - freq)
// =: bias + x + q*cmpl_freq (*)
//
// and we can just precompute cmpl_freq. Now we just need to
// set up our parameters such that the original encoder and
// the fast encoder agree.
s->x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq;
s->cmpl_freq = (uint16_t) ((1 << scale_bits) - freq);
if (freq < 2) {
// freq=0 symbols are never valid to encode, so it doesn't matter what
// we set our values to.
//
// freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
// our fixed-point reciprocal approximation can only multiply by values
// smaller than 1.
//
// So we use the "next best thing": rcp_freq=0xffffffff, rcp_shift=0.
// This gives:
// q = mul_hi(x, rcp_freq) >> rcp_shift
// = mul_hi(x, (1<<32) - 1)) >> 0
// = floor(x - x/(2^32))
// = x - 1 if 1 <= x < 2^32
// and we know that x>0 (x=0 is never in a valid normalization interval).
//
// So we now need to choose the other parameters such that
// x_new = x*M + start
// plug it in:
// x*M + start (desired result)
// = bias + x + q*cmpl_freq (*)
// = bias + x + (x - 1)*(M - 1) (plug in q=x-1, cmpl_freq)
// = bias + 1 + (x - 1)*M
// = x*M + (bias + 1 - M)
//
// so we have start = bias + 1 - M, or equivalently
// bias = start + M - 1.
s->rcp_freq = ~0u;
s->rcp_shift = 0;
s->bias = start + (1 << scale_bits) - 1;
} else {
// Alverson, "Integer Division using reciprocals"
// shift=ceil(log2(freq))
uint32_t shift = 0;
while (freq > (1u << shift))
shift++;
s->rcp_freq = (uint32_t) (((1ull << (shift + 31)) + freq-1) / freq);
s->rcp_shift = shift - 1;
// With these values, 'q' is the correct quotient, so we
// have bias=start.
s->bias = start;
}
}
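// Worked example (added for illustration, not part of the original header):
// with scale_bits = 4 (M = 16), start = 5, freq = 3 the code above yields
// shift = 2, rcp_freq = (2^33 + 2) / 3 = 0xAAAAAAAB, rcp_shift = 1,
// cmpl_freq = 13, bias = 5. For x = 100 the fast path in RansEncPutSymbol gives
//   q = mul_hi(100, 0xAAAAAAAB) >> 1 = 66 >> 1 = 33 = x / freq,
//   x_new = 100 + 5 + 33 * 13 = 534,
// matching the direct form (x / freq) * M + start + (x % freq) = 528 + 5 + 1.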
// Initialize a decoder symbol to start "start" and frequency "freq"
static inline void RansDecSymbolInit(RansDecSymbol* s, uint32_t start, uint32_t freq)
{
RansAssert(start <= (1 << 16));
RansAssert(freq <= (1 << 16) - start);
s->start = (uint16_t) start;
s->freq = (uint16_t) freq;
}
// Encodes a given symbol. This is faster than straight RansEnc since we can do
// multiplications instead of a divide.
//
// See RansEncSymbolInit for a description of how this works.
static inline void RansEncPutSymbol(RansState* r, uint8_t** pptr, RansEncSymbol const* sym)
{
RansAssert(sym->x_max != 0); // can't encode symbol with freq=0
// renormalize
uint32_t x = *r;
uint32_t x_max = sym->x_max;
if (x >= x_max) {
uint8_t* ptr = *pptr;
do {
*--ptr = (uint8_t) (x & 0xff);
x >>= 8;
} while (x >= x_max);
*pptr = ptr;
}
// x = C(s,x)
// NOTE: written this way so we get a 32-bit "multiply high" when
// available. If you're on a 64-bit platform with cheap multiplies
// (e.g. x64), just bake the +32 into rcp_shift.
uint32_t q = (uint32_t) (((uint64_t)x * sym->rcp_freq) >> 32) >> sym->rcp_shift;
*r = x + sym->bias + q * sym->cmpl_freq;
}
// Equivalent to RansDecAdvance that takes a symbol.
static inline void RansDecAdvanceSymbol(RansState* r, uint8_t** pptr, RansDecSymbol const* sym, uint32_t scale_bits)
{
RansDecAdvance(r, pptr, sym->start, sym->freq, scale_bits);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
// No renormalization or output happens.
static inline void RansDecAdvanceStep(RansState* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
uint32_t mask = (1u << scale_bits) - 1;
// s, x = D(x)
uint32_t x = *r;
*r = freq * (x >> scale_bits) + (x & mask) - start;
}
// Equivalent to RansDecAdvanceStep that takes a symbol.
static inline void RansDecAdvanceSymbolStep(RansState* r, RansDecSymbol const* sym, uint32_t scale_bits)
{
RansDecAdvanceStep(r, sym->start, sym->freq, scale_bits);
}
// Renormalize.
static inline void RansDecRenorm(RansState* r, uint8_t** pptr)
{
// renormalize
uint32_t x = *r;
if (x < RANS_BYTE_L) {
uint8_t* ptr = *pptr;
do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L);
*pptr = ptr;
}
*r = x;
}
#endif // RANS_BYTE_HEADER
// Word-aligned SSE 4.1 rANS encoder/decoder - public domain - Fabian 'ryg' Giesen
//
// This implementation has a regular rANS encoder and a 4-way interleaved SIMD
// decoder. Like rans_byte.h, it's intended to illustrate the idea, not to
// be used as a drop-in arithmetic coder.
#ifndef RANS_WORD_SSE41_HEADER
#define RANS_WORD_SSE41_HEADER
#include <stdint.h>
#include <smmintrin.h>
// READ ME FIRST:
//
// The intention in this version is to demonstrate a design where the decoder
// is made as fast as possible, even when it makes the encoder slightly slower
// or hurts compression a bit. (The code in rans_byte.h, with the 31-bit
// arithmetic to allow for faster division by constants, is a more "balanced"
// approach).
//
// This version is intended to be used with relatively low-resolution
// probability distributions (scale_bits=12 or less). In these regions, the
// "fully unrolled" table-based approach shown here (suggested by "enotuss"
// on my blog) is optimal; for larger scale_bits, other approaches are more
// favorable. It also only assumes an 8-bit symbol alphabet for simplicity.
//
// Unlike rans_byte.h, this file needs to be compiled as C++.
// --------------------------------------------------------------------------
// This coder uses L=1<<16 and B=1<<16 (16-bit word based renormalization).
// Since we still continue to use 32-bit words, this means we require
// scale_bits <= 16; on the plus side, renormalization never needs to
// iterate.
#define RANS_WORD_L (1u << 16)
#define RANS_WORD_SCALE_BITS 12
#define RANS_WORD_M (1u << RANS_WORD_SCALE_BITS)
#define RANS_WORD_NSYMS 256
typedef uint32_t RansWordEnc;
typedef uint32_t RansWordDec;
typedef union {
__m128i simd;
uint32_t lane[4];
} RansSimdDec;
union RansWordSlot {
uint32_t u32;
struct {
uint16_t freq;
uint16_t bias;
};
};
struct RansWordTables {
RansWordSlot slots[RANS_WORD_M];
uint8_t slot2sym[RANS_WORD_M];
};
// Initialize slots for a symbol in the table
static inline void RansWordTablesInitSymbol(RansWordTables* tab, uint8_t sym, uint32_t start, uint32_t freq)
{
for (uint32_t i=0; i < freq; i++) {
uint32_t slot = start + i;
tab->slot2sym[slot] = sym;
tab->slots[slot].freq = (uint16_t)freq;
tab->slots[slot].bias = (uint16_t)i;
}
}
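// Example (added for illustration, not part of the original header): for a
// two-symbol alphabet whose frequencies sum to RANS_WORD_M = 4096, e.g.
//   RansWordTablesInitSymbol(&tab, 0, 0,    3072);
//   RansWordTablesInitSymbol(&tab, 1, 3072, 1024);
// every one of the 4096 slots then stores its symbol, freq and bias, so
// RansWordDecSym below is a single table lookup per decoded symbol.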
// Initialize a rANS encoder
static inline RansWordEnc RansWordEncInit()
{
return RANS_WORD_L;
}
// Encodes a single symbol with range start "start" and frequency "freq".
static inline void RansWordEncPut(RansWordEnc* r, uint16_t** pptr, uint32_t start, uint32_t freq)
{
// renormalize
uint32_t x = *r;
if (x >= ((RANS_WORD_L >> RANS_WORD_SCALE_BITS) << 16) * freq) {
*pptr -= 1;
**pptr = (uint16_t) (x & 0xffff);
x >>= 16;
}
// x = C(s,x)
*r = ((x / freq) << RANS_WORD_SCALE_BITS) + (x % freq) + start;
}
// Flushes the rANS encoder
static inline void RansWordEncFlush(RansWordEnc* r, uint16_t** pptr)
{
uint32_t x = *r;
uint16_t* ptr = *pptr;
ptr -= 2;
ptr[0] = (uint16_t) (x >> 0);
ptr[1] = (uint16_t) (x >> 16);
*pptr = ptr;
}
// Initializes a rANS decoder.
static inline void RansWordDecInit(RansWordDec* r, uint16_t** pptr)
{
uint32_t x;
uint16_t* ptr = *pptr;
x = ptr[0] << 0;
x |= ptr[1] << 16;
ptr += 2;
*pptr = ptr;
*r = x;
}
// Decodes a symbol using the given tables.
static inline uint8_t RansWordDecSym(RansWordDec* r, RansWordTables const* tab)
{
uint32_t x = *r;
uint32_t slot = x & (RANS_WORD_M - 1);
// s, x = D(x)
*r = tab->slots[slot].freq * (x >> RANS_WORD_SCALE_BITS) + tab->slots[slot].bias;
return tab->slot2sym[slot];
}
// Renormalize after decoding a symbol.
static inline void RansWordDecRenorm(RansWordDec* r, uint16_t** pptr)
{
uint32_t x = *r;
if (x < RANS_WORD_L) {
*r = (x << 16) | **pptr;
*pptr += 1;
}
}
// Initializes a SIMD rANS decoder.
static inline void RansSimdDecInit(RansSimdDec* r, uint16_t** pptr)
{
r->simd = _mm_loadu_si128((const __m128i*)*pptr);
*pptr += 2*4;
}
// Decodes four symbols in parallel using the given tables.
static inline uint32_t RansSimdDecSym(RansSimdDec* r, RansWordTables const* tab)
{
__m128i freq_bias_lo, freq_bias_hi, freq_bias;
__m128i freq, bias;
__m128i xscaled;
__m128i x = r->simd;
__m128i slots = _mm_and_si128(x, _mm_set1_epi32(RANS_WORD_M - 1));
uint32_t i0 = (uint32_t) _mm_cvtsi128_si32(slots);
uint32_t i1 = (uint32_t) _mm_extract_epi32(slots, 1);
uint32_t i2 = (uint32_t) _mm_extract_epi32(slots, 2);
uint32_t i3 = (uint32_t) _mm_extract_epi32(slots, 3);
// symbol
uint32_t s = tab->slot2sym[i0] | (tab->slot2sym[i1] << 8) | (tab->slot2sym[i2] << 16) | (tab->slot2sym[i3] << 24);
// gather freq_bias
freq_bias_lo = _mm_cvtsi32_si128(tab->slots[i0].u32);
freq_bias_lo = _mm_insert_epi32(freq_bias_lo, tab->slots[i1].u32, 1);
freq_bias_hi = _mm_cvtsi32_si128(tab->slots[i2].u32);
freq_bias_hi = _mm_insert_epi32(freq_bias_hi, tab->slots[i3].u32, 1);
freq_bias = _mm_unpacklo_epi64(freq_bias_lo, freq_bias_hi);
// s, x = D(x)
xscaled = _mm_srli_epi32(x, RANS_WORD_SCALE_BITS);
freq = _mm_and_si128(freq_bias, _mm_set1_epi32(0xffff));
bias = _mm_srli_epi32(freq_bias, 16);
r->simd = _mm_add_epi32(_mm_mullo_epi32(xscaled, freq), bias);
return s;
}
// Renormalize after decoding a symbol.
static inline void RansSimdDecRenorm(RansSimdDec* r, uint16_t** pptr)
{
static ALIGNSPEC(int8_t const, shuffles[16][16], 16) = {
#define _ -1 // for readability
{ _,_,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0000
{ 0,1,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0001
{ _,_,_,_, 0,1,_,_, _,_,_,_, _,_,_,_ }, // 0010
{ 0,1,_,_, 2,3,_,_, _,_,_,_, _,_,_,_ }, // 0011
{ _,_,_,_, _,_,_,_, 0,1,_,_, _,_,_,_ }, // 0100
{ 0,1,_,_, _,_,_,_, 2,3,_,_, _,_,_,_ }, // 0101
{ _,_,_,_, 0,1,_,_, 2,3,_,_, _,_,_,_ }, // 0110
{ 0,1,_,_, 2,3,_,_, 4,5,_,_, _,_,_,_ }, // 0111
{ _,_,_,_, _,_,_,_, _,_,_,_, 0,1,_,_ }, // 1000
{ 0,1,_,_, _,_,_,_, _,_,_,_, 2,3,_,_ }, // 1001
{ _,_,_,_, 0,1,_,_, _,_,_,_, 2,3,_,_ }, // 1010
{ 0,1,_,_, 2,3,_,_, _,_,_,_, 4,5,_,_ }, // 1011
{ _,_,_,_, _,_,_,_, 0,1,_,_, 2,3,_,_ }, // 1100
{ 0,1,_,_, _,_,_,_, 2,3,_,_, 4,5,_,_ }, // 1101
{ _,_,_,_, 0,1,_,_, 2,3,_,_, 4,5,_,_ }, // 1110
{ 0,1,_,_, 2,3,_,_, 4,5,_,_, 6,7,_,_ }, // 1111
#undef _
};
static uint8_t const numbits[16] = {
0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4
};
__m128i x = r->simd;
// NOTE: SSE2+ only offer a signed 32-bit integer compare, while we
// need unsigned. So we subtract 0x80000000 before the compare,
// which converts unsigned integers to signed integers in an
// order-preserving manner.
__m128i x_biased = _mm_xor_si128(x, _mm_set1_epi32((int) 0x80000000));
__m128i greater = _mm_cmpgt_epi32(_mm_set1_epi32(RANS_WORD_L - 0x80000000), x_biased);
unsigned int mask = _mm_movemask_ps(_mm_castsi128_ps(greater));
// NOTE: this will read slightly past the end of the input buffer.
// In practice, either pad the input buffer by 8 bytes at the end,
// or switch to the non-SIMD version once you get close to the end.
__m128i memvals = _mm_loadl_epi64((const __m128i*)*pptr);
__m128i xshifted = _mm_slli_epi32(x, 16);
__m128i shufmask = _mm_load_si128((const __m128i*)shuffles[mask]);
__m128i newx = _mm_or_si128(xshifted, _mm_shuffle_epi8(memvals, shufmask));
r->simd = _mm_blendv_epi8(x, newx, greater);
*pptr += numbits[mask];
}
#endif // RANS_WORD_SSE41_HEADER
import os
import logging
from datetime import datetime
def get_timestamp():
return datetime.now().strftime('%y%m%d-%H%M%S')
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
'''set up logger'''
lg = logging.getLogger(logger_name)
formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
datefmt='%y-%m-%d %H:%M:%S')
lg.setLevel(level)
if tofile:
log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
fh = logging.FileHandler(log_file, mode='w')
fh.setFormatter(formatter)
lg.addHandler(fh)
if screen:
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)
{
"name": "Proposed Method [MSE]",
"description": "Inference (ans)",
"results": {
"psnr": [
29.59623646173013,
30.620305225841804,
31.368097891084776,
33.500982263908604,
35.1986092737886,
36.75209027231476,
37.659769404047275,
38.98167890128869
],
"ms-ssim": [
0.9271597891319089,
0.942064642906189,
0.9509546567754048,
0.9705119947107826,
0.9790188103187375,
0.9849583288518394,
0.9879685073364072,
0.9913198279171456
],
"bpp": [
0.07203977324678053,
0.09755356726631553,
0.11748815950460842,
0.20864437863241664,
0.31829253667507795,
0.45920202148814876,
0.564088358437684,
0.7759934857000356
],
"encoding_time": [
64.97474208110717,
57.54451480144408
],
"decoding_time": [
115.90137354339042,
102.1610572919613
]
}
}
{
"name": "Proposed Method [MSE]",
"description": "Inference (ans)",
"results": {
"psnr": [
27.89430274785873,
28.957334258416935,
29.71510731761415,
31.919470844862207,
33.94583516657302,
35.83985444690141,
36.90782873010037,
38.47072210973406
],
"ms-ssim": [
0.910035602748394,
0.9292794143160185,
0.9408208007613817,
0.9658057888348898,
0.9779217839241028,
0.9856954514980316,
0.9888884996374449,
0.9923855985204378
],
"bpp": [
0.0983615451388889,
0.13387722439236113,
0.16317070855034724,
0.2925754123263889,
0.45141940646701384,
0.6507873535156249,
0.7909918891059027,
1.061119927300347
]
}
}
{
"name": "Proposed Method [MSE]",
"description": "Inference (ans)",
"results": {
"psnr": [
29.739129827686046,
30.79101072271687,
31.549300117049253,
33.654606374551406,
35.304123313098216,
36.84500056980779,
37.71330502780009,
39.05021961699284
],
"ms-ssim": [
0.9410666453838349,
0.9522642749547958,
0.9591444969177246,
0.9740614533424378,
0.9810326504707336,
0.9860490924119949,
0.9884526658058167,
0.9914442735910416
],
"bpp": [
0.08284577777777775,
0.10827288888888892,
0.12657400000000002,
0.21016333333333337,
0.30768,
0.43872444444444464,
0.5358042222222221,
0.7412455555555556
],
"encoding_time": [
47.26153783559799
],
"decoding_time": [
94.44327362060547
]
}
}