"""Download the training/test data set from OpenImages."""
import fiftyone

if __name__ == '__main__':
    dataset_train = fiftyone.zoo.load_zoo_dataset(
        "open-images-v6",
        split="test",
        max_samples=3000,
        label_types=["classifications"],
        dataset_dir='openimages',
    )
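    # Optional sketch (not in the original script): if the raw images are
    # needed as a plain folder, e.g. to build the train/test layout that
    # ImageFolder expects later in this repo, FiftyOne can export them.
    # "openimages_images" is a hypothetical output directory.
    import fiftyone.types

    dataset_train.export(
        export_dir="openimages_images",
        dataset_type=fiftyone.types.ImageDirectory,
    )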
# Examples
## Notebooks
To run the Jupyter notebooks:
* `pip install -U ipython jupyter ipywidgets matplotlib`
* `jupyter notebook`
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import struct
import sys
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import compressai
from compressai.zoo import models
model_ids = {k: i for i, k in enumerate(models.keys())}
metric_ids = {
"mse": 0,
}
def inverse_dict(d):
    # We assume dict values are unique (keys always are)...
    assert len(set(d.values())) == len(d)
    return {v: k for k, v in d.items()}
def filesize(filepath: str) -> int:
if not Path(filepath).is_file():
raise ValueError(f'Invalid file "{filepath}".')
return Path(filepath).stat().st_size
def load_image(filepath: str) -> Image.Image:
return Image.open(filepath).convert("RGB")
def img2torch(img: Image.Image) -> torch.Tensor:
return ToTensor()(img).unsqueeze(0)
def torch2img(x: torch.Tensor) -> Image.Image:
return ToPILImage()(x.clamp_(0, 1).squeeze())
def write_uints(fd, values, fmt=">{:d}I"):
fd.write(struct.pack(fmt.format(len(values)), *values))
def write_uchars(fd, values, fmt=">{:d}B"):
fd.write(struct.pack(fmt.format(len(values)), *values))
def read_uints(fd, n, fmt=">{:d}I"):
sz = struct.calcsize("I")
return struct.unpack(fmt.format(n), fd.read(n * sz))
def read_uchars(fd, n, fmt=">{:d}B"):
sz = struct.calcsize("B")
return struct.unpack(fmt.format(n), fd.read(n * sz))
def write_bytes(fd, values, fmt=">{:d}s"):
if len(values) == 0:
return
fd.write(struct.pack(fmt.format(len(values)), values))
def read_bytes(fd, n, fmt=">{:d}s"):
sz = struct.calcsize("s")
return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
def get_header(model_name, metric, quality):
"""Format header information:
- 1 byte for model id
- 4 bits for metric
- 4 bits for quality param
"""
metric = metric_ids[metric]
    code = (metric << 4) | ((quality - 1) & 0x0F)
return model_ids[model_name], code
def parse_header(header):
"""Read header information from 2 bytes:
- 1 byte for model id
- 4 bits for metric
- 4 bits for quality param
"""
model_id, code = header
quality = (code & 0x0F) + 1
metric = code >> 4
return (
inverse_dict(model_ids)[model_id],
inverse_dict(metric_ids)[metric],
quality,
)
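# Sanity-check sketch (not in the original file): the two header bytes
# round-trip through get_header/parse_header for any registered model.
def _check_header_roundtrip():
    model_name = list(model_ids)[0]
    header = get_header(model_name, "mse", quality=3)
    assert parse_header(header) == (model_name, "mse", 3)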
def pad(x, p=2 ** 6):
h, w = x.size(2), x.size(3)
H = (h + p - 1) // p * p
W = (w + p - 1) // p * p
padding_left = (W - w) // 2
padding_right = W - w - padding_left
padding_top = (H - h) // 2
padding_bottom = H - h - padding_top
return F.pad(
x,
(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
def crop(x, size):
H, W = x.size(2), x.size(3)
h, w = size
padding_left = (W - w) // 2
padding_right = W - w - padding_left
padding_top = (H - h) // 2
padding_bottom = H - h - padding_top
return F.pad(
x,
(-padding_left, -padding_right, -padding_top, -padding_bottom),
mode="constant",
value=0,
)
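# pad() and crop() are exact inverses: pad() centers the image on a canvas
# whose sides are multiples of p, and crop() removes the same border via
# negative padding. A small illustrative sketch (not in the original file):
def _check_pad_crop_roundtrip():
    x = torch.rand(1, 3, 375, 500)
    y = pad(x, p=64)
    assert y.shape[-2:] == (384, 512)  # 375 -> 384, 500 -> 512
    assert torch.equal(crop(y, (375, 500)), x)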
def _encode(image, model, metric, quality, coder, output):
compressai.set_entropy_coder(coder)
enc_start = time.time()
img = load_image(image)
start = time.time()
net = models[model](quality=quality, metric=metric, pretrained=True).eval()
load_time = time.time() - start
x = img2torch(img)
h, w = x.size(2), x.size(3)
p = 64 # maximum 6 strides of 2
x = pad(x, p)
with torch.no_grad():
out = net.compress(x)
shape = out["shape"]
header = get_header(model, metric, quality)
with Path(output).open("wb") as f:
write_uchars(f, header)
# write original image size
write_uints(f, (h, w))
# write shape and number of encoded latents
write_uints(f, (shape[0], shape[1], len(out["strings"])))
for s in out["strings"]:
write_uints(f, (len(s[0]),))
write_bytes(f, s[0])
enc_time = time.time() - enc_start
size = filesize(output)
bpp = float(size) * 8 / (img.size[0] * img.size[1])
print(
f"{bpp:.3f} bpp |"
f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)"
)
def _decode(inputpath, coder, show, output=None):
compressai.set_entropy_coder(coder)
dec_start = time.time()
with Path(inputpath).open("rb") as f:
model, metric, quality = parse_header(read_uchars(f, 2))
original_size = read_uints(f, 2)
shape = read_uints(f, 2)
strings = []
n_strings = read_uints(f, 1)[0]
for _ in range(n_strings):
s = read_bytes(f, read_uints(f, 1)[0])
strings.append([s])
print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
start = time.time()
net = models[model](quality=quality, metric=metric, pretrained=True).eval()
load_time = time.time() - start
with torch.no_grad():
out = net.decompress(strings, shape)
x_hat = crop(out["x_hat"], original_size)
img = torch2img(x_hat)
dec_time = time.time() - dec_start
print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
if show:
show_image(img)
if output is not None:
img.save(output)
def show_image(img: Image.Image):
from matplotlib import pyplot as plt
fig, ax = plt.subplots()
ax.axis("off")
ax.title.set_text("Decoded image")
ax.imshow(img)
fig.tight_layout()
plt.show()
def encode(argv):
parser = argparse.ArgumentParser(description="Encode image to bit-stream")
parser.add_argument("image", type=str)
parser.add_argument(
"--model",
choices=models.keys(),
default=list(models.keys())[0],
help="NN model to use (default: %(default)s)",
)
parser.add_argument(
"-m",
"--metric",
choices=["mse"],
default="mse",
help="metric trained against (default: %(default)s",
)
parser.add_argument(
"-q",
"--quality",
choices=list(range(1, 9)),
type=int,
default=3,
help="Quality setting (default: %(default)s)",
)
parser.add_argument(
"-c",
"--coder",
choices=compressai.available_entropy_coders(),
default=compressai.available_entropy_coders()[0],
help="Entropy coder (default: %(default)s)",
)
parser.add_argument("-o", "--output", help="Output path")
args = parser.parse_args(argv)
if not args.output:
args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
_encode(args.image, args.model, args.metric, args.quality, args.coder, args.output)
def decode(argv):
parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
parser.add_argument("input", type=str)
parser.add_argument(
"-c",
"--coder",
choices=compressai.available_entropy_coders(),
default=compressai.available_entropy_coders()[0],
help="Entropy coder (default: %(default)s)",
)
parser.add_argument("--show", action="store_true")
parser.add_argument("-o", "--output", help="Output path")
args = parser.parse_args(argv)
_decode(args.input, args.coder, args.show, args.output)
def parse_args(argv):
parser = argparse.ArgumentParser(description="")
parser.add_argument("command", choices=["encode", "decode"])
args = parser.parse_args(argv)
return args
def main(argv):
args = parse_args(argv[1:2])
argv = argv[2:]
torch.set_num_threads(1) # just to be sure
if args.command == "encode":
encode(argv)
elif args.command == "decode":
decode(argv)
if __name__ == "__main__":
main(sys.argv)
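# Hypothetical round trip, assuming this file is saved as codec.py and a test
# image kodim01.png exists (argv[0] stands in for the program name):
#
#     main(["codec.py", "encode", "kodim01.png", "-q", "3", "-o", "out.bin"])
#     main(["codec.py", "decode", "out.bin", "-o", "out_rec.png"])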
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import math
import random
import shutil
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from compressai.datasets import ImageFolder
from compressai.zoo import models
from pytorch_msssim import ms_ssim
from typing import Tuple, Union
import numpy as np
import PIL
import PIL.Image as Image
from torchvision.transforms import ToPILImage
import logging
from utils import util
from torch.utils.tensorboard import SummaryWriter
from compressai.zoo.image import model_architectures as architectures
def torch2img(x: torch.Tensor) -> Image.Image:
return ToPILImage()(x.clamp_(0, 1).squeeze())
def compute_metrics(
    a: Union[np.ndarray, Image.Image],
    b: Union[np.ndarray, Image.Image],
max_val: float = 255.0,
) -> Tuple[float, float]:
"""Returns PSNR and MS-SSIM between images `a` and `b`. """
if isinstance(a, Image.Image):
a = np.asarray(a)
if isinstance(b, Image.Image):
b = np.asarray(b)
a = torch.from_numpy(a.copy()).float().unsqueeze(0)
if a.size(3) == 3:
a = a.permute(0, 3, 1, 2)
b = torch.from_numpy(b.copy()).float().unsqueeze(0)
if b.size(3) == 3:
b = b.permute(0, 3, 1, 2)
mse = torch.mean((a - b) ** 2).item()
p = 20 * np.log10(max_val) - 10 * np.log10(mse)
m = ms_ssim(a, b, data_range=max_val).item()
return p, m
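# Worked example (not in the original file): two 256x256 RGB images that
# differ by one gray level in a single channel have MSE = 1/3, so
# PSNR = 20*log10(255) - 10*log10(1/3), roughly 52.9 dB.
def _compute_metrics_example():
    a = Image.new("RGB", (256, 256), color=(128, 128, 128))
    b = Image.new("RGB", (256, 256), color=(129, 128, 128))
    p, m = compute_metrics(a, b)
    print(f"PSNR: {p:.2f} dB | MS-SSIM: {m:.4f}")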
class RateDistortionLoss(nn.Module):
"""Custom rate distortion loss with a Lagrangian parameter."""
def __init__(self, lmbda=1e-2, metrics='mse'):
super().__init__()
self.mse = nn.MSELoss()
self.lmbda = lmbda
self.metrics = metrics
def forward(self, output, target):
N, _, H, W = target.size()
out = {}
num_pixels = N * H * W
out["bpp_loss"] = sum(
(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
for likelihoods in output["likelihoods"].values()
)
if self.metrics == 'mse':
out["mse_loss"] = self.mse(output["x_hat"], target)
out["ms_ssim_loss"] = None
out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
elif self.metrics == 'ms-ssim':
out["mse_loss"] = None
out["ms_ssim_loss"] = 1 - ms_ssim(output["x_hat"], target, data_range=1.0)
out["loss"] = self.lmbda * out["ms_ssim_loss"] + out["bpp_loss"]
return out
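# For the MSE branch, the objective is loss = lmbda * 255^2 * MSE + bpp; the
# 255^2 factor rescales the [0, 1]-range MSE to 8-bit amplitude. A
# self-contained sketch with a fabricated model output (shapes are
# illustrative only, not from the original file):
def _rd_loss_example():
    criterion = RateDistortionLoss(lmbda=0.01, metrics='mse')
    target = torch.rand(2, 3, 64, 64)
    fake_out = {
        "x_hat": (target + 0.05 * torch.randn_like(target)).clamp(0, 1),
        "likelihoods": {"y": torch.rand(2, 192, 4, 4).clamp_min(1e-9)},
    }
    out = criterion(fake_out, target)
    print(out["loss"].item(), out["bpp_loss"].item(), out["mse_loss"].item())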
class AverageMeter:
"""Compute running average."""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class CustomDataParallel(nn.DataParallel):
"""Custom DataParallel to access the module methods."""
def __getattr__(self, key):
try:
return super().__getattr__(key)
except AttributeError:
return getattr(self.module, key)
def configure_optimizers(net, args):
"""Separate parameters for the main optimizer and the auxiliary optimizer.
Return two optimizers"""
parameters = [
p for n, p in net.named_parameters() if not n.endswith(".quantiles")
]
aux_parameters = [
p for n, p in net.named_parameters() if n.endswith(".quantiles")
]
# Make sure we don't have an intersection of parameters
params_dict = dict(net.named_parameters())
inter_params = set(parameters) & set(aux_parameters)
union_params = set(parameters) | set(aux_parameters)
assert len(inter_params) == 0
assert len(union_params) - len(params_dict.keys()) == 0
optimizer = optim.Adam(
(p for p in parameters if p.requires_grad),
lr=args.learning_rate,
)
aux_optimizer = optim.Adam(
(p for p in aux_parameters if p.requires_grad),
lr=args.aux_learning_rate,
)
return optimizer, aux_optimizer
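# Usage sketch (not in the original file): the ".quantiles" parameters of the
# entropy bottleneck feed the auxiliary optimizer only; everything else goes
# to the main optimizer. SimpleNamespace stands in for the parsed args.
def _configure_optimizers_example():
    from types import SimpleNamespace

    opts = SimpleNamespace(learning_rate=1e-4, aux_learning_rate=1e-3)
    net = models["bmshj2018-factorized"](quality=1)
    optimizer, aux_optimizer = configure_optimizers(net, opts)
    n_aux = sum(p.numel() for g in aux_optimizer.param_groups for p in g["params"])
    print(f"auxiliary parameters: {n_aux}")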
def train_one_epoch(
model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, logger_train, tb_logger, current_step
):
model.train()
device = next(model.parameters()).device
for i, d in enumerate(train_dataloader):
d = d.to(device)
optimizer.zero_grad()
aux_optimizer.zero_grad()
out_net = model(d)
out_criterion = criterion(out_net, d)
out_criterion["loss"].backward()
if clip_max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
optimizer.step()
aux_loss = model.aux_loss()
aux_loss.backward()
aux_optimizer.step()
current_step += 1
        if current_step % 100 == 0:
            tb_logger.add_scalar('[train]: loss', out_criterion["loss"].item(), current_step)
            tb_logger.add_scalar('[train]: bpp_loss', out_criterion["bpp_loss"].item(), current_step)
            if out_criterion["mse_loss"] is not None:
                tb_logger.add_scalar('[train]: mse_loss', out_criterion["mse_loss"].item(), current_step)
            if out_criterion["ms_ssim_loss"] is not None:
                tb_logger.add_scalar('[train]: ms_ssim_loss', out_criterion["ms_ssim_loss"].item(), current_step)
if i % 100 == 0:
if out_criterion["ms_ssim_loss"] is None:
logger_train.info(
f"Train epoch {epoch}: ["
f"{i*len(d):5d}/{len(train_dataloader.dataset)}"
f" ({100. * i / len(train_dataloader):.0f}%)] "
f'Loss: {out_criterion["loss"].item():.4f} | '
f'MSE loss: {out_criterion["mse_loss"].item():.4f} | '
f'Bpp loss: {out_criterion["bpp_loss"].item():.2f} | '
f"Aux loss: {aux_loss.item():.2f}"
)
else:
logger_train.info(
f"Train epoch {epoch}: ["
f"{i*len(d):5d}/{len(train_dataloader.dataset)}"
f" ({100. * i / len(train_dataloader):.0f}%)] "
f'Loss: {out_criterion["loss"].item():.4f} | '
f'MS-SSIM loss: {out_criterion["ms_ssim_loss"].item():.4f} | '
f'Bpp loss: {out_criterion["bpp_loss"].item():.2f} | '
f"Aux loss: {aux_loss.item():.2f}"
)
return current_step
def test_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger):
model.eval()
device = next(model.parameters()).device
loss = AverageMeter()
bpp_loss = AverageMeter()
mse_loss = AverageMeter()
ms_ssim_loss = AverageMeter()
aux_loss = AverageMeter()
psnr = AverageMeter()
ms_ssim = AverageMeter()
with torch.no_grad():
for i, d in enumerate(test_dataloader):
d = d.to(device)
out_net = model(d)
out_criterion = criterion(out_net, d)
aux_loss.update(model.aux_loss())
bpp_loss.update(out_criterion["bpp_loss"])
loss.update(out_criterion["loss"])
if out_criterion["mse_loss"] is not None:
mse_loss.update(out_criterion["mse_loss"])
if out_criterion["ms_ssim_loss"] is not None:
ms_ssim_loss.update(out_criterion["ms_ssim_loss"])
rec = torch2img(out_net['x_hat'])
img = torch2img(d)
p, m = compute_metrics(rec, img)
psnr.update(p)
ms_ssim.update(m)
if i % 20 == 1:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
rec.save(os.path.join(save_dir, '%03d_rec.png' % i))
img.save(os.path.join(save_dir, '%03d_gt.png' % i))
    tb_logger.add_scalar('[val]: loss', loss.avg, epoch + 1)
    tb_logger.add_scalar('[val]: bpp_loss', bpp_loss.avg, epoch + 1)
    tb_logger.add_scalar('[val]: psnr', psnr.avg, epoch + 1)
    tb_logger.add_scalar('[val]: ms-ssim', ms_ssim.avg, epoch + 1)
if out_criterion["mse_loss"] is not None:
logger_val.info(
f"Test epoch {epoch}: Average losses: "
f"Loss: {loss.avg:.4f} | "
f"MSE loss: {mse_loss.avg:.4f} | "
f"Bpp loss: {bpp_loss.avg:.2f} | "
f"Aux loss: {aux_loss.avg:.2f} | "
f"PSNR: {psnr.avg:.6f} | "
f"MS-SSIM: {ms_ssim.avg:.6f}"
)
        tb_logger.add_scalar('[val]: mse_loss', mse_loss.avg, epoch + 1)
if out_criterion["ms_ssim_loss"] is not None:
logger_val.info(
f"Test epoch {epoch}: Average losses: "
f"Loss: {loss.avg:.4f} | "
f"MS-SSIM loss: {ms_ssim_loss.avg:.4f} | "
f"Bpp loss: {bpp_loss.avg:.2f} | "
f"Aux loss: {aux_loss.avg:.2f} | "
f"PSNR: {psnr.avg:.6f} | "
f"MS-SSIM: {ms_ssim.avg:.6f}"
)
        tb_logger.add_scalar('[val]: ms_ssim_loss', ms_ssim_loss.avg, epoch + 1)
return loss.avg
def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        # Derive the destination from the directory directly; a string replace
        # on the basename can misfire if it also occurs in the directory path.
        dest_filename = os.path.join(os.path.dirname(filename), "checkpoint_best_loss.pth.tar")
        shutil.copyfile(filename, dest_filename)
def parse_args(argv):
parser = argparse.ArgumentParser(description="Example training script.")
parser.add_argument(
"-exp", "--experiment", type=str, required=True, help="Experiment name"
)
parser.add_argument(
"-m",
"--model",
default="bmshj2018-factorized",
choices=models.keys(),
help="Model architecture (default: %(default)s)",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="Training dataset"
)
parser.add_argument(
"-e",
"--epochs",
default=100,
type=int,
help="Number of epochs (default: %(default)s)",
)
parser.add_argument(
"-lr",
"--learning-rate",
default=1e-4,
type=float,
help="Learning rate (default: %(default)s)",
)
parser.add_argument(
"-q",
"--quality",
type=int,
default=1,
help="Quality (default: %(default)s)",
)
parser.add_argument(
"-n",
"--num-workers",
type=int,
default=30,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
"--lambda",
dest="lmbda",
type=float,
default=1e-2,
help="Bit-rate distortion parameter (default: %(default)s)",
)
parser.add_argument(
"--metrics",
type=str,
default="mse",
help="Optimized for (default: %(default)s)",
)
parser.add_argument(
"--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1,
help="Test batch size (default: %(default)s)",
)
    parser.add_argument(
        "--aux-learning-rate",
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
parser.add_argument(
"--patch-size",
type=int,
nargs=2,
default=(256, 256),
help="Size of the patches to be cropped (default: %(default)s)",
)
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
parser.add_argument("--cuda", action="store_true", help="Use cuda")
parser.add_argument("--save", action="store_true", help="Save model to disk")
parser.add_argument(
"--seed", type=float, help="Set random seed for reproducibility"
)
parser.add_argument(
"--clip_max_norm",
default=1.0,
type=float,
help="gradient clipping max norm (default: %(default)s",
)
parser.add_argument(
"-c",
"--checkpoint",
default=None,
type=str,
help="pretrained model path"
)
args = parser.parse_args(argv)
return args
def main(argv):
args = parse_args(argv)
if args.seed is not None:
torch.manual_seed(args.seed)
random.seed(args.seed)
if not os.path.exists(os.path.join('../experiments', args.experiment)):
os.makedirs(os.path.join('../experiments', args.experiment))
util.setup_logger('train', os.path.join('../experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO,
screen=True, tofile=True)
util.setup_logger('val', os.path.join('../experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO,
screen=True, tofile=True)
logger_train = logging.getLogger('train')
logger_val = logging.getLogger('val')
tb_logger = SummaryWriter(log_dir='../tb_logger/' + args.experiment)
if not os.path.exists(os.path.join('../experiments', args.experiment, 'checkpoints')):
os.makedirs(os.path.join('../experiments', args.experiment, 'checkpoints'))
train_transforms = transforms.Compose(
[transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
)
test_transforms = transforms.Compose(
[transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
)
train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
pin_memory=True,
)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.test_batch_size,
num_workers=args.num_workers,
shuffle=False,
pin_memory=True,
)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    if args.checkpoint is not None:
checkpoint = torch.load(args.checkpoint)
net = architectures[args.model].from_state_dict(checkpoint['state_dict'])
net.update()
else:
net = models[args.model](quality=args.quality, train=True)
net = net.to(device)
logger_train.info(args)
logger_train.info(net)
if args.cuda and torch.cuda.device_count() > 1:
net = CustomDataParallel(net)
optimizer, aux_optimizer = configure_optimizers(net, args)
# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1)
criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
    if args.checkpoint is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
# lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1)
lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count']
lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch']
# print(lr_scheduler.state_dict())
start_epoch = checkpoint['epoch']
best_loss = checkpoint['loss']
current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size)
checkpoint = None
else:
start_epoch = 0
best_loss = 1e10
current_step = 0
for epoch in range(start_epoch, args.epochs):
logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
current_step = train_one_epoch(
net,
criterion,
train_dataloader,
optimizer,
aux_optimizer,
epoch,
args.clip_max_norm,
logger_train,
tb_logger,
current_step
)
save_dir = os.path.join('../experiments', args.experiment, 'val_images', '%03d' % (epoch + 1))
loss = test_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger)
# lr_scheduler.step(loss)
lr_scheduler.step()
is_best = loss < best_loss
best_loss = min(loss, best_loss)
if args.save:
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": net.state_dict(),
"loss": loss,
"optimizer": optimizer.state_dict(),
"aux_optimizer": aux_optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
},
is_best,
os.path.join('../experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1))
)
if is_best:
logger_val.info('best checkpoint saved.')
if __name__ == "__main__":
main(sys.argv[1:])
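# Hypothetical invocation, assuming this file is saved as train.py and the
# dataset directory contains train/ and test/ subfolders (the layout that
# ImageFolder expects):
#
#     python train.py -exp demo -d /path/to/dataset --cuda --save
#
# TensorBoard scalars land in ../tb_logger/demo and checkpoints in
# ../experiments/demo/checkpoints/.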
[build-system]
requires = ["setuptools>=42", "wheel", "pybind11>=2.6.0",]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 88
target-version = ['py36', 'py37', 'py38']
include = '\.pyi?$'
exclude = '''
/(
\.eggs
| \.git
| \.mypy_cache
  | venv*
| _build
| build
| dist
)/
'''
[tool.isort]
multi_line_output = 3
lines_between_types = 1
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
known_third_party = "PIL,pytorch_msssim,torchvision,torch"
skip_gitignore = true
[tool.pylint.messages_control]
disable = "C0330,C0326,bad-continuation,C0103,C0304,W0221,R0902,R0914,C0114,W0511,C0116,W0201,C0415,R0401,R0201,R0913,R0903,R0801"
[tool.pylint.format]
max-line-length = "88"
[tool.pylint.typecheck]
ignored-modules = "torch"
extension-pkg-whitelist = "compressai.ans,compressai._cxx,range_coder"
[tool.pytest.ini_options]
markers = [
"pretrained: download and check pretrained models (slow, deselect with '-m \"not pretrained\"')",
"slow: all slow tests (pretrained models, train, etc...)",
]
alabaster==0.7.12
appdirs==1.4.4
appnope==0.1.2
argon2-cffi==20.1.0
astroid==2.4.2
async-generator==1.10
attrs==20.3.0
Babel==2.9.0
backcall==0.2.0
black==20.8b1
bleach==3.3.0
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
click==7.1.2
coverage==5.4
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
docutils==0.16
entrypoints==0.3
idna==2.10
imagesize==1.2.0
iniconfig==1.1.1
ipykernel==5.4.3
ipython==7.20.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
isort==5.7.0
jedi==0.18.0
Jinja2==2.11.3
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.11
jupyter-console==6.2.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
lazy-object-proxy==1.4.3
MarkupSafe==1.1.1
matplotlib==3.3.4
mccabe==0.6.1
mistune==0.8.4
mypy-extensions==0.4.3
nbclient==0.5.2
nbconvert==6.0.7
nbformat==5.1.2
nest-asyncio==1.5.1
notebook==6.2.0
numpy==1.20.1
packaging==20.9
pandocfilters==1.4.3
parso==0.8.1
pathspec==0.8.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.1.0
pluggy==0.13.1
prometheus-client==0.9.0
prompt-toolkit==3.0.16
ptyprocess==0.7.0
py==1.10.0
pycparser==2.20
Pygments==2.7.4
pylint==2.6.0
pyparsing==2.4.7
pyrsistent==0.17.3
pytest==6.2.2
pytest-cov==2.11.1
python-dateutil==2.8.1
pytz==2021.1
pyzmq==22.0.2
qtconsole==5.0.2
QtPy==1.9.0
regex==2020.11.13
requests==2.25.1
scipy==1.6.0
Send2Trash==1.5.0
six==1.15.0
snowballstemmer==2.1.0
Sphinx==3.4.3
sphinx-rtd-theme==0.5.1
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==1.0.3
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4
terminado==0.9.2
testpath==0.4.4
toml==0.10.2
torch==1.7.1
torchvision==0.8.2
tornado==6.1
traitlets==5.0.5
typed-ast==1.4.2
typing-extensions==3.7.4.3
urllib3==1.26.3
wcwidth==0.2.5
webencodings==0.5.1
widgetsnbextension==3.5.1
wrapt==1.12.1
import os
import json
import imagesize
from shutil import copyfile
def process_set(data_root_dir, index_list):
imgs = sorted(os.listdir(data_root_dir))
alls = [imgs[i] for i in index_list]
set_dict = {}
for i in range(len(alls)):
set_dict['scene_%05d' % i] = {}
set_dict['scene_%05d' % i]['img_path'] = os.path.join(data_root_dir, alls[i])
return set_dict
# change these three paths
data_root_dir = '/home/felix/disk2/data/flicker/flicker_2W_images'
train_root_dir = '/home/felix/disk2/data/flicker/train'
test_root_dir = '/home/felix/disk2/data/flicker/test'
if not os.path.exists(train_root_dir):
os.makedirs(train_root_dir)
if not os.path.exists(test_root_dir):
os.makedirs(test_root_dir)
# build the list of indices of images that are at least 256x256
index_list = []
count = 0
for img in sorted(os.listdir(data_root_dir)):
img_path = os.path.join(data_root_dir, img)
width, height = imagesize.get(img_path)
    if width >= 256 and height >= 256:
        index_list.append(count)
    count += 1
if count % 1000 == 0:
print(count, 'done!')
imgs = sorted(os.listdir(data_root_dir))
alls = [imgs[i] for i in index_list]
all_length = len(alls)
train_ratio = 0.99
train_length = round(len(alls) * train_ratio)
val_length = all_length - train_length
for img_name in alls[:train_length]:
copyfile(os.path.join(data_root_dir, img_name), os.path.join(train_root_dir, img_name))
for img_name in alls[train_length:]:
copyfile(os.path.join(data_root_dir, img_name), os.path.join(test_root_dir, img_name))
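# Quick sanity check on the split (not in the original script, and assuming
# the target directories started out empty): the copied counts should match
# the 99/1 partition computed above.
assert len(os.listdir(train_root_dir)) == train_length
assert len(os.listdir(test_root_dir)) == val_length
print(f"kept {all_length} images: {train_length} train / {val_length} test")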
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import subprocess
from pathlib import Path
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import find_packages, setup
cwd = Path(__file__).resolve().parent
package_name = "compressai"
version = "1.2.1"
git_hash = "unknown"
try:
git_hash = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode().strip()
)
except (FileNotFoundError, subprocess.CalledProcessError):
pass
def write_version_file():
path = cwd / package_name / "version.py"
with path.open("w") as f:
f.write(f'__version__ = "{version}"\n')
f.write(f'git_version = "{git_hash}"\n')
write_version_file()
def get_extensions():
ext_dirs = cwd / package_name / "cpp_exts"
ext_modules = []
# Add rANS module
rans_lib_dir = cwd / "third_party/ryg_rans"
rans_ext_dir = ext_dirs / "rans"
extra_compile_args = ["-std=c++17"]
if os.getenv("DEBUG_BUILD", None):
extra_compile_args += ["-O0", "-g", "-UNDEBUG"]
else:
extra_compile_args += ["-O3"]
ext_modules.append(
Pybind11Extension(
name=f"{package_name}.ans",
sources=[str(s) for s in rans_ext_dir.glob("*.cpp")],
language="c++",
include_dirs=[rans_lib_dir, rans_ext_dir],
extra_compile_args=extra_compile_args,
)
)
# Add ops
ops_ext_dir = ext_dirs / "ops"
ext_modules.append(
Pybind11Extension(
name=f"{package_name}._CXX",
sources=[str(s) for s in ops_ext_dir.glob("*.cpp")],
language="c++",
extra_compile_args=extra_compile_args,
)
)
return ext_modules
TEST_REQUIRES = ["pytest", "pytest-cov"]
DEV_REQUIRES = TEST_REQUIRES + [
"black",
"flake8",
"flake8-bugbear",
"flake8-comprehensions",
"isort",
"mypy",
]
def get_extra_requirements():
extras_require = {
"test": TEST_REQUIRES,
"dev": DEV_REQUIRES,
"doc": ["sphinx", "furo"],
"tutorials": ["jupyter", "ipywidgets"],
}
extras_require["all"] = {req for reqs in extras_require.values() for req in reqs}
return extras_require
setup(
name=package_name,
version=version,
description="A PyTorch library and evaluation platform for end-to-end compression research",
url="https://github.com/InterDigitalInc/CompressAI",
author="InterDigital AI Lab",
author_email="compressai@interdigital.com",
packages=find_packages(exclude=("tests",)),
zip_safe=False,
python_requires=">=3.6",
install_requires=[
"numpy",
"scipy",
"matplotlib",
"torch>=1.7.1",
"torchvision",
"pytorch-msssim",
],
extras_require=get_extra_requirements(),
license="BSD 3-Clause Clear License",
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
ext_modules=get_extensions(),
cmdclass={"build_ext": build_ext},
)
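# Build/install sketch (not part of setup.py): from a source checkout with the
# ryg_rans sources present under third_party/, an editable install is
# typically:
#
#     pip install -e .
#
# A C++17-capable compiler and pybind11 >= 2.6 are required (see the
# [build-system] section of pyproject.toml above).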
Learning rate: 0.0001
Train epoch 0: [0/1 (0%)] Loss: 73.254 | MSE loss: 0.106 | Bpp loss: 4.06 | Aux loss: 7917.84
Test epoch 0: Average losses: Loss: 78.085 | MSE loss: 0.114 | Bpp loss: 4.06 | Aux loss: 7917.78
Learning rate: 0.0001
Train epoch 1: [0/1 (0%)] Loss: 69.391 | MSE loss: 0.100 | Bpp loss: 4.06 | Aux loss: 7917.42
Test epoch 1: Average losses: Loss: 77.089 | MSE loss: 0.112 | Bpp loss: 4.06 | Aux loss: 7917.37
Learning rate: 0.0001
Train epoch 2: [0/1 (0%)] Loss: 81.454 | MSE loss: 0.119 | Bpp loss: 4.06 | Aux loss: 7917.01
Test epoch 2: Average losses: Loss: 76.104 | MSE loss: 0.111 | Bpp loss: 4.06 | Aux loss: 7916.95
Learning rate: 0.0001
Train epoch 3: [0/1 (0%)] Loss: 76.327 | MSE loss: 0.111 | Bpp loss: 4.06 | Aux loss: 7916.59
Test epoch 3: Average losses: Loss: 75.106 | MSE loss: 0.109 | Bpp loss: 4.05 | Aux loss: 7916.54
Learning rate: 0.0001
Train epoch 4: [0/1 (0%)] Loss: 68.382 | MSE loss: 0.099 | Bpp loss: 4.06 | Aux loss: 7916.18
Test epoch 4: Average losses: Loss: 72.342 | MSE loss: 0.105 | Bpp loss: 4.05 | Aux loss: 7916.12
Learning rate: 0.0001
Train epoch 5: [0/1 (0%)] Loss: 68.586 | MSE loss: 0.099 | Bpp loss: 4.05 | Aux loss: 7915.76
Test epoch 5: Average losses: Loss: 65.022 | MSE loss: 0.094 | Bpp loss: 4.05 | Aux loss: 7915.70
Learning rate: 0.0001
Train epoch 6: [0/1 (0%)] Loss: 61.092 | MSE loss: 0.088 | Bpp loss: 4.05 | Aux loss: 7915.34
Test epoch 6: Average losses: Loss: 55.618 | MSE loss: 0.079 | Bpp loss: 4.05 | Aux loss: 7915.28
Learning rate: 0.0001
Train epoch 7: [0/1 (0%)] Loss: 57.247 | MSE loss: 0.082 | Bpp loss: 4.05 | Aux loss: 7914.92
Test epoch 7: Average losses: Loss: 58.262 | MSE loss: 0.083 | Bpp loss: 4.05 | Aux loss: 7914.86
Learning rate: 0.0001
Train epoch 8: [0/1 (0%)] Loss: 58.874 | MSE loss: 0.084 | Bpp loss: 4.05 | Aux loss: 7914.51
Test epoch 8: Average losses: Loss: 60.872 | MSE loss: 0.087 | Bpp loss: 4.05 | Aux loss: 7914.45
Learning rate: 0.0001
Train epoch 9: [0/1 (0%)] Loss: 60.795 | MSE loss: 0.087 | Bpp loss: 4.05 | Aux loss: 7914.12
Test epoch 9: Average losses: Loss: 53.210 | MSE loss: 0.076 | Bpp loss: 4.05 | Aux loss: 7914.06
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import pytest
import torch
from compressai.zoo import models
archs = [
"bmshj2018-factorized",
"bmshj2018-hyperprior",
"mbt2018-mean",
"mbt2018",
]
class TestCompressDecompress:
@pytest.mark.parametrize("arch,N", itertools.product(archs, [1, 2, 3]))
def test_codec(self, arch: str, N: int):
x = torch.zeros(N, 3, 256, 256)
h, w = x.size()[-2:]
x[:, :, h // 4 : -h // 4, w // 4 : -w // 4].fill_(1)
model = models[arch]
net = model(quality=1, metric="mse", pretrained=True).eval()
with torch.no_grad():
rv = net.compress(x)
assert "shape" in rv
shape = rv["shape"]
ds = net.downsampling_factor
assert shape == torch.Size([x.size(2) // ds, x.size(3) // ds])
assert "strings" in rv
strings_list = rv["strings"]
# y_strings (+ optional z_strings)
assert len(strings_list) == 1 or len(strings_list) == 2
for strings in strings_list:
assert len(strings) == N
for string in strings:
assert isinstance(string, bytes)
assert len(string) > 0
with torch.no_grad():
rv = net.decompress(strings_list, shape)
assert "x_hat" in rv
x_hat = rv["x_hat"]
assert x_hat.size() == x.size()
mse = torch.mean((x - x_hat) ** 2)
psnr = -10 * torch.log10(mse).item()
assert 35 < psnr < 41
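# These codec tests download pretrained weights. The pyproject.toml above
# registers "pretrained" and "slow" markers for deselecting such tests, e.g.:
#
#     pytest -m "not pretrained"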
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import compressai
def test_get_entropy_coder():
assert compressai.get_entropy_coder() == "ans"
def test_available_entropy_coders():
rv = compressai.available_entropy_coders()
assert isinstance(rv, list)
assert "ans" in rv
def test_set_entropy_coder():
compressai.set_entropy_coder("ans")
with pytest.raises(ValueError):
compressai.set_entropy_coder("cabac")
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from PIL import Image
from torchvision import transforms
from compressai.datasets import ImageFolder
def save_fake_image(filepath, size=(512, 512)):
img = Image.new("RGB", size=size)
img.save(filepath)
class TestImageFolder:
def test_init_ok(self, tmpdir):
tmpdir.mkdir("train")
tmpdir.mkdir("test")
train_dataset = ImageFolder(tmpdir, split="train")
test_dataset = ImageFolder(tmpdir, split="test")
assert len(train_dataset) == 0
assert len(test_dataset) == 0
def test_count_ok(self, tmpdir):
tmpdir.mkdir("train")
(tmpdir / "train" / "img1.jpg").write("")
(tmpdir / "train" / "img2.jpg").write("")
(tmpdir / "train" / "img3.jpg").write("")
train_dataset = ImageFolder(tmpdir, split="train")
assert len(train_dataset) == 3
def test_invalid_dir(self, tmpdir):
with pytest.raises(RuntimeError):
ImageFolder(tmpdir)
def test_load(self, tmpdir):
tmpdir.mkdir("train")
save_fake_image((tmpdir / "train" / "img0.jpeg").strpath)
train_dataset = ImageFolder(tmpdir, split="train")
assert isinstance(train_dataset[0], Image.Image)
def test_load_transforms(self, tmpdir):
tmpdir.mkdir("train")
save_fake_image((tmpdir / "train" / "img0.jpeg").strpath)
transform = transforms.Compose(
[
transforms.CenterCrop((128, 128)),
transforms.ToTensor(),
]
)
train_dataset = ImageFolder(tmpdir, split="train", transform=transform)
assert isinstance(train_dataset[0], torch.Tensor)
assert train_dataset[0].size() == (3, 128, 128)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressai.entropy_models import (
EntropyBottleneck,
EntropyModel,
GaussianConditional,
)
from compressai.models.priors import FactorizedPrior
from compressai.zoo import bmshj2018_factorized, bmshj2018_hyperprior
@pytest.fixture
def entropy_model():
return EntropyModel()
class TestEntropyModel:
def test_quantize_invalid(self, entropy_model):
x = torch.rand(1, 3, 4, 4)
with pytest.raises(ValueError):
entropy_model.quantize(x, mode="toto")
def test_quantize_noise(self, entropy_model):
x = torch.rand(1, 3, 4, 4)
y = entropy_model.quantize(x, "noise")
assert y.shape == x.shape
assert ((y - x) <= 0.5).all()
assert ((y - x) >= -0.5).all()
assert (y != torch.round(x)).any()
def test__quantize(self, entropy_model):
x = torch.rand(1, 3, 4, 4)
s = torch.rand(1).item()
torch.manual_seed(s)
y0 = entropy_model.quantize(x, "noise")
torch.manual_seed(s)
with pytest.warns(UserWarning):
y1 = entropy_model._quantize(x, "noise")
assert (y0 == y1).all()
def test_quantize_symbols(self, entropy_model):
x = torch.rand(1, 3, 4, 4)
y = entropy_model.quantize(x, "symbols")
assert y.shape == x.shape
assert (y == torch.round(x).int()).all()
def test_quantize_dequantize(self, entropy_model):
x = torch.rand(1, 3, 4, 4)
means = torch.rand(1, 3, 4, 4)
y = entropy_model.quantize(x, "dequantize", means)
assert y.shape == x.shape
assert (y == torch.round(x - means) + means).all()
def test_dequantize(self, entropy_model):
x = torch.randint(-32, 32, (1, 3, 4, 4))
means = torch.rand(1, 3, 4, 4)
y = entropy_model.dequantize(x, means)
assert y.shape == x.shape
assert y.type() == means.type()
with pytest.warns(UserWarning):
yy = entropy_model._dequantize(x, means)
assert (yy == y).all()
def test_forward(self, entropy_model):
with pytest.raises(NotImplementedError):
entropy_model()
def test_invalid_coder(self):
with pytest.raises(ValueError):
entropy_model = EntropyModel(entropy_coder="huffman")
with pytest.raises(ValueError):
entropy_model = EntropyModel(entropy_coder=0xFF)
def test_invalid_inputs(self, entropy_model):
with pytest.raises(TypeError):
entropy_model.compress(torch.rand(1, 3))
with pytest.raises(ValueError):
entropy_model.compress(torch.rand(1, 3), torch.rand(2, 3))
with pytest.raises(ValueError):
entropy_model.compress(torch.rand(1, 3, 1, 1), torch.rand(2, 3))
def test_invalid_cdf(self, entropy_model):
x = torch.rand(1, 32, 16, 16)
indexes = torch.rand(1, 32, 16, 16)
with pytest.raises(ValueError):
entropy_model.compress(x, indexes)
def test_invalid_cdf_length(self, entropy_model):
x = torch.rand(1, 32, 16, 16)
indexes = torch.rand(1, 32, 16, 16)
entropy_model._quantized_cdf.resize_(32, 1)
with pytest.raises(ValueError):
entropy_model.compress(x, indexes)
entropy_model._cdf_length.resize_(32, 1)
with pytest.raises(ValueError):
entropy_model.compress(x, indexes)
def test_invalid_offsets(self, entropy_model):
x = torch.rand(1, 32, 16, 16)
indexes = torch.rand(1, 32, 16, 16)
entropy_model._quantized_cdf.resize_(32, 1)
entropy_model._cdf_length.resize_(32)
with pytest.raises(ValueError):
entropy_model.compress(x, indexes)
def test_invalid_decompress(self, entropy_model):
with pytest.raises(TypeError):
entropy_model.decompress(["ssss"])
with pytest.raises(ValueError):
entropy_model.decompress("sss", torch.rand(1, 3, 4, 4))
with pytest.raises(ValueError):
entropy_model.decompress(["sss"], torch.rand(1, 4, 4))
with pytest.raises(ValueError):
entropy_model.decompress(["sss"], torch.rand(2, 4, 4))
with pytest.raises(ValueError):
entropy_model.decompress(["sss"], torch.rand(1, 4, 4), torch.rand(2, 4, 4))
class TestEntropyBottleneck:
def test_forward_training(self):
entropy_bottleneck = EntropyBottleneck(128)
x = torch.rand(1, 128, 32, 32)
y, y_likelihoods = entropy_bottleneck(x)
assert isinstance(entropy_bottleneck, EntropyModel)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert ((y - x) <= 0.5).all()
assert ((y - x) >= -0.5).all()
assert (y != torch.round(x)).any()
def test_forward_inference(self):
entropy_bottleneck = EntropyBottleneck(128)
entropy_bottleneck.eval()
x = torch.rand(1, 128, 32, 32)
y, y_likelihoods = entropy_bottleneck(x)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert (y == torch.round(x)).all()
def test_loss(self):
entropy_bottleneck = EntropyBottleneck(128)
loss = entropy_bottleneck.loss()
assert len(loss.size()) == 0
assert loss.numel() == 1
def test_scripting(self):
entropy_bottleneck = EntropyBottleneck(128)
x = torch.rand(1, 128, 32, 32)
torch.manual_seed(32)
y0 = entropy_bottleneck(x)
m = torch.jit.script(entropy_bottleneck)
torch.manual_seed(32)
y1 = m(x)
assert torch.allclose(y0[0], y1[0])
assert torch.all(y1[1] == 0) # not yet supported
def test_update(self):
# get a pretrained model
net = bmshj2018_factorized(quality=1, pretrained=True).eval()
assert not net.update()
assert not net.update(force=False)
assert net.update(force=True)
def test_script(self):
eb = EntropyBottleneck(32)
eb = torch.jit.script(eb)
x = torch.rand(1, 32, 4, 4)
x_q, likelihoods = eb(x)
assert (likelihoods == torch.zeros_like(x_q)).all()
class TestGaussianConditional:
def test_invalid_scale_table(self):
with pytest.raises(ValueError):
GaussianConditional(1)
with pytest.raises(ValueError):
GaussianConditional([])
with pytest.raises(ValueError):
GaussianConditional(())
with pytest.raises(ValueError):
GaussianConditional(torch.rand(10))
with pytest.raises(ValueError):
GaussianConditional([2, 1])
with pytest.raises(ValueError):
GaussianConditional([0, 1, 2])
with pytest.raises(ValueError):
GaussianConditional([], scale_bound=None)
with pytest.raises(ValueError):
GaussianConditional([], scale_bound=-0.1)
def test_forward_training(self):
gaussian_conditional = GaussianConditional(None)
x = torch.rand(1, 128, 32, 32)
scales = torch.rand(1, 128, 32, 32)
y, y_likelihoods = gaussian_conditional(x, scales)
assert isinstance(gaussian_conditional, EntropyModel)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert ((y - x) <= 0.5).all()
assert ((y - x) >= -0.5).all()
assert (y != torch.round(x)).any()
def test_forward_inference(self):
gaussian_conditional = GaussianConditional(None)
gaussian_conditional.eval()
x = torch.rand(1, 128, 32, 32)
scales = torch.rand(1, 128, 32, 32)
y, y_likelihoods = gaussian_conditional(x, scales)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert (y == torch.round(x)).all()
def test_forward_training_mean(self):
gaussian_conditional = GaussianConditional(None)
x = torch.rand(1, 128, 32, 32)
scales = torch.rand(1, 128, 32, 32)
means = torch.rand(1, 128, 32, 32)
y, y_likelihoods = gaussian_conditional(x, scales, means)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert ((y - x) <= 0.5).all()
assert ((y - x) >= -0.5).all()
assert (y != torch.round(x)).any()
def test_forward_inference_mean(self):
gaussian_conditional = GaussianConditional(None)
gaussian_conditional.eval()
x = torch.rand(1, 128, 32, 32)
scales = torch.rand(1, 128, 32, 32)
means = torch.rand(1, 128, 32, 32)
y, y_likelihoods = gaussian_conditional(x, scales, means)
assert y.shape == x.shape
assert y_likelihoods.shape == x.shape
assert (y == torch.round(x - means) + means).all()
def test_scripting(self):
gaussian_conditional = GaussianConditional(None)
x = torch.rand(1, 128, 32, 32)
scales = torch.rand(1, 128, 32, 32)
means = torch.rand(1, 128, 32, 32)
torch.manual_seed(32)
y0 = gaussian_conditional(x, scales, means)
m = torch.jit.script(gaussian_conditional)
torch.manual_seed(32)
y1 = m(x, scales, means)
assert torch.allclose(y0[0], y1[0])
assert torch.allclose(y0[1], y1[1])
def test_update(self):
# get a pretrained model
net = bmshj2018_hyperprior(quality=1, pretrained=True).eval()
assert not net.update()
assert not net.update(force=False)
quantized_cdf = net.gaussian_conditional._quantized_cdf
offset = net.gaussian_conditional._offset
cdf_length = net.gaussian_conditional._cdf_length
assert net.update(force=True)
def approx(a, b):
return ((a - b).abs() <= 2).all()
assert approx(net.gaussian_conditional._cdf_length, cdf_length)
assert approx(net.gaussian_conditional._offset, offset)
assert approx(net.gaussian_conditional._quantized_cdf, quantized_cdf)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# From: https://stackoverflow.com/questions/2481511/mocking-importerror-in-python
try:
import builtins
except ImportError:
import __builtin__ as builtins
realimport = builtins.__import__
def monkeypatched_import(name, *args, **kwargs):
# raise ImportError
if name == "compressai.version":
raise ImportError
if name == "range_coder":
raise ImportError
return realimport(name, *args, **kwargs)
builtins.__import__ = monkeypatched_import
def test_import_errors():
# This should not crash
import compressai
def test_version():
builtins.__import__ = realimport
from compressai.version import __version__
assert 5 <= len(__version__) <= 7
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from compressai.layers import (
GDN,
GDN1,
AttentionBlock,
MaskedConv2d,
ResidualBlock,
ResidualBlockUpsample,
ResidualBlockWithStride,
)
class TestMaskedConv2d:
@staticmethod
def test_mask_type():
MaskedConv2d(1, 3, 3, mask_type="A")
MaskedConv2d(1, 3, 3, mask_type="B")
with pytest.raises(ValueError):
MaskedConv2d(1, 3, 3, mask_type="C")
@staticmethod
def test_mask_A():
conv = MaskedConv2d(1, 3, 5, mask_type="A")
assert (conv.mask[0] == conv.mask[1]).all()
assert (conv.mask[0] == conv.mask[2]).all()
_, _, h, w = conv.mask.size()
a = torch.ones_like(conv.mask)
a[:, :, h // 2, w // 2 :] = 0
a[:, :, h // 2 + 1 :] = 0
assert (conv.mask == a).all()
@staticmethod
def test_mask_B():
conv = MaskedConv2d(1, 3, 5, mask_type="B")
assert (conv.mask[0] == conv.mask[1]).all()
assert (conv.mask[0] == conv.mask[2]).all()
_, _, h, w = conv.mask.size()
b = torch.ones_like(conv.mask)
b[:, :, h // 2, w // 2 + 1 :] = 0
b[:, :, h // 2 + 1 :] = 0
assert (conv.mask == b).all()
@staticmethod
def test_mask_A_1d():
conv = MaskedConv2d(1, 3, (1, 5), mask_type="A")
assert (conv.mask[0] == conv.mask[1]).all()
assert (conv.mask[0] == conv.mask[2]).all()
_, _, h, w = conv.mask.size()
a = torch.ones_like(conv.mask)
a[:, :, h // 2, w // 2 :] = 0
a[:, :, h // 2 + 1 :] = 0
assert (conv.mask == a).all()
@staticmethod
def test_mask_B_1d():
conv = MaskedConv2d(3, 1, (5, 1), mask_type="B")
assert (conv.mask[:, 0] == conv.mask[:, 1]).all()
assert (conv.mask[:, 0] == conv.mask[:, 2]).all()
_, _, h, w = conv.mask.size()
b = torch.ones_like(conv.mask)
b[:, :, h // 2, w // 2 + 1 :] = 0
b[:, :, h // 2 + 1 :] = 0
assert (conv.mask == b).all()
@staticmethod
def test_mask_multiple():
cfgs = [
# (in, out, kernel_size)
(1, 3, 5),
(3, 1, 3),
(3, 3, 7),
]
for cfg in cfgs:
in_ch, out_ch, k = cfg
conv = MaskedConv2d(in_ch, out_ch, k, mask_type="A")
assert conv.mask[0].sum() != 0
assert (conv.mask - conv.mask[0]).sum() == 0
_, _, h, w = conv.mask.size()
a = torch.ones_like(conv.mask)
a[:, :, h // 2, w // 2 :] = 0
a[:, :, h // 2 + 1 :] = 0
assert (conv.mask == a).all()
class TestGDN:
def test_gdn(self):
g = GDN(32)
x = torch.rand(1, 32, 16, 16, requires_grad=True)
y = g(x)
y.backward(x)
assert y.shape == x.shape
assert x.grad is not None
assert x.grad.shape == x.shape
y_ref = x / torch.sqrt(1 + 0.1 * (x ** 2))
assert torch.allclose(y_ref, y)
def test_igdn(self):
g = GDN(32, inverse=True)
x = torch.rand(1, 32, 16, 16, requires_grad=True)
y = g(x)
y.backward(x)
assert y.shape == x.shape
assert x.grad is not None
assert x.grad.shape == x.shape
y_ref = x * torch.sqrt(1 + 0.1 * (x ** 2))
assert torch.allclose(y_ref, y)
def test_gdn1(self):
g = GDN1(32)
x = torch.rand(1, 32, 16, 16, requires_grad=True)
y = g(x)
y.backward(x)
assert y.shape == x.shape
assert x.grad is not None
assert x.grad.shape == x.shape
y_ref = x / (1 + 0.1 * torch.abs(x))
assert torch.allclose(y_ref, y)
def test_ResidualBlockWithStride():
layer = ResidualBlockWithStride(32, 64, stride=1)
layer(torch.rand(1, 32, 4, 4))
layer = ResidualBlockWithStride(32, 32, stride=1)
layer(torch.rand(1, 32, 4, 4))
layer = ResidualBlockWithStride(32, 32, stride=2)
layer(torch.rand(1, 32, 4, 4))
layer = ResidualBlockWithStride(32, 64, stride=2)
layer(torch.rand(1, 32, 4, 4))
def test_ResidualBlockUpsample():
layer = ResidualBlockUpsample(8, 16)
layer(torch.rand(1, 8, 4, 4))
def test_ResidualBlock():
layer = ResidualBlock(8, 8)
layer(torch.rand(1, 8, 4, 4))
layer = ResidualBlock(8, 16)
layer(torch.rand(1, 8, 4, 4))
def test_AttentionBlock():
layer = AttentionBlock(8)
layer(torch.rand(1, 8, 4, 4))