"vscode:/vscode.git/clone" did not exist on "31d9f9ea77d7bda61484ef9a29d8453f88c6e28d"
Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
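
The diff below shows the result of running the new formatter over the test suite. As a rough illustration of what ufmt (usort + black) does -- this snippet is not taken from the diff, and `common_utils` simply stands in for torchvision's local test helpers -- usort splits imports into standard-library, third-party, and local groups and sorts each group, while black normalizes quoting and line wrapping:

# Before: hand-ordered imports, single quotes, manual wrapping
import torch
import random
from common_utils import set_rng_seed
import pytest

def configure(config):
    config.addinivalue_line(
        'markers', 'needs_cuda: mark for tests that rely on a CUDA device'
    )

# After ufmt (usort + black): grouped and sorted imports, normalized quoting and wrapping
import random

import pytest
import torch

from common_utils import set_rng_seed


def configure(config):
    config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")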
import random

import numpy as np
import pytest
import torch

from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG


def pytest_configure(config):
    # register an additional marker (see pytest_collection_modifyitems)
    config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
    config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")


def pytest_collection_modifyitems(items):
@@ -34,7 +31,7 @@ def pytest_collection_modifyitems(items):
        # @pytest.mark.parametrize('device', cpu_and_gpu())
        # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
        # and the ones with device == 'cpu' won't have the mark.
        needs_cuda = item.get_closest_marker("needs_cuda") is not None

        if needs_cuda and not torch.cuda.is_available():
            # In general, we skip cuda tests on machines without a GPU
@@ -59,7 +56,7 @@ def pytest_collection_modifyitems(items):
            # to run the CPU-only tests.
            item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG))

        if item.get_closest_marker("dont_collect") is not None:
            # currently, this is only used for some tests we're sure we dont want to run on fbcode
            continue
...
@@ -18,7 +18,6 @@ import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import get_tmp_dir, disable_console_output
@@ -419,7 +418,7 @@ class DatasetTestCase(unittest.TestCase):
        defaults.append(
            {
                kwarg: default
                for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
                if not kwarg.startswith("_")
            }
        )
@@ -640,7 +639,7 @@ class VideoDatasetTestCase(DatasetTestCase):
    def _set_default_frames_per_clip(self, inject_fake_data):
        argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
        args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
        frames_per_clip_last = args_without_default[-1] == "frames_per_clip"

        @functools.wraps(inject_fake_data)
...

import argparse
import os
from timeit import default_timer as timer

import torch
import torch.utils.data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.model_zoo import tqdm

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument("--data", metavar="PATH", required=True, help="path to dataset")
parser.add_argument(
    "--nThreads", "-j", default=2, type=int, metavar="N", help="number of data loading threads (default: 2)"
)
parser.add_argument(
    "--batchSize", "-b", default=256, type=int, metavar="N", help="mini-batch size (1 = pure stochastic) Default: 256"
)
parser.add_argument("--accimage", action="store_true", help="use accimage")

if __name__ == "__main__":
    args = parser.parse_args()
    if args.accimage:
        torchvision.set_image_backend("accimage")
    print("Using {}".format(torchvision.get_image_backend()))

    # Data loading code
    transform = transforms.Compose(
        [
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")
    train = datasets.ImageFolder(traindir, transform)
    val = datasets.ImageFolder(valdir, transform)
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads
    )
    train_iter = iter(train_loader)

    start_time = timer()
@@ -51,9 +54,12 @@ if __name__ == "__main__":
            pbar.update(1)
            batch = next(train_iter)
    end_time = timer()
    print(
        "Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
        " {image:.2f} ms/image {rate:.0f} images/sec".format(
            dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
            batch=(end_time - start_time) / float(batch_count) * 1.0e3,
            image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e3,
            rate=(batch_count * args.batchSize) / (end_time - start_time),
        )
    )

import random
from functools import partial
from itertools import chain

import pytest
import torch
import torchvision
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
@@ -64,16 +61,21 @@ class TestModule(torch.nn.Module):
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1, "pretrained": False}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
@@ -81,41 +83,36 @@ class TestFxFeatureExtraction:
        Apply leaf modules
        """
        tracer_kwargs = {}
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
@@ -123,19 +120,16 @@ class TestFxFeatureExtraction:
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError
@@ -144,32 +138,25 @@ class TestFxFeatureExtraction:
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)
@@ -181,14 +168,14 @@ class TestFxFeatureExtraction:
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()
@@ -197,7 +184,7 @@ class TestFxFeatureExtraction:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.mean()
@@ -211,54 +198,54 @@ class TestFxFeatureExtraction:
        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
@@ -279,15 +266,16 @@ class TestFxFeatureExtraction:
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule is not in the list of nodes
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]
        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].mean().backward()

import os
import sys
import unittest

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import models, transforms

try:
    from torchvision import _C_tests
@@ -21,12 +21,13 @@ def process_model(model, tensor, func, name):
    py_output = model.forward(tensor)
    cpp_output = func("model.pt", tensor)
    assert torch.allclose(py_output, cpp_output), "Output mismatch of " + name + " models"


def read_image1():
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    image = Image.open(image_path)
    image = image.resize((224, 224))
    x = F.to_tensor(image)
@@ -34,8 +35,9 @@ def read_image1():


def read_image2():
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    image = Image.open(image_path)
    image = image.resize((299, 299))
    x = F.to_tensor(image)
@@ -46,107 +48,110 @@ def read_image2():
@unittest.skipIf(
    sys.platform == "darwin" or True,
    "C++ models are broken on OS X at the moment, and there's a BC breakage on main; "
    "see https://github.com/pytorch/vision/issues/1191",
)
class Tester(unittest.TestCase):
    pretrained = False
    image = read_image1()

    def test_alexnet(self):
        process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet")

    def test_vgg11(self):
        process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11")

    def test_vgg13(self):
        process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13")

    def test_vgg16(self):
        process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16")

    def test_vgg19(self):
        process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19")

    def test_vgg11_bn(self):
        process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN")

    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN")

    def test_vgg16_bn(self):
        process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN")

    def test_vgg19_bn(self):
        process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN")

    def test_resnet18(self):
        process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18")

    def test_resnet34(self):
        process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34")

    def test_resnet50(self):
        process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50")

    def test_resnet101(self):
        process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101")

    def test_resnet152(self):
        process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152")

    def test_resnext50_32x4d(self):
        process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")

    def test_resnext101_32x8d(self):
        process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d")

    def test_wide_resnet50_2(self):
        process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, "WideResNet50_2")

    def test_wide_resnet101_2(self):
        process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")

    def test_squeezenet1_0(self):
        process_model(
            models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0"
        )

    def test_squeezenet1_1(self):
        process_model(
            models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1"
        )

    def test_densenet121(self):
        process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121")

    def test_densenet169(self):
        process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169")

    def test_densenet201(self):
        process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201")

    def test_densenet161(self):
        process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161")

    def test_mobilenet_v2(self):
        process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet")

    def test_googlenet(self):
        process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet")

    def test_mnasnet0_5(self):
        process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")

    def test_mnasnet0_75(self):
        process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")

    def test_mnasnet1_0(self):
        process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")

    def test_mnasnet1_3(self):
        process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")

    def test_inception_v3(self):
        self.image = read_image2()
        process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")


if __name__ == "__main__":
    unittest.main()

@@ -2,10 +2,10 @@ import bz2
import contextlib
import io
import itertools
import json
import os
import pathlib
import pickle
import random
import shutil
import string
@@ -13,9 +13,9 @@ import unittest
import xml.etree.ElementTree as ET
import zipfile

import datasets_utils
import numpy as np
import PIL
import pytest
import torch
import torch.nn.functional as F
@@ -24,8 +24,7 @@ from torchvision import datasets
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.STL10
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled"))

    @staticmethod
    def _make_binary_file(num_elements, root, name):
@@ -206,11 +205,11 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.WIDERFace
    FEATURE_TYPES = (PIL.Image.Image, (dict, type(None)))  # test split returns None as target
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))

    def inject_fake_data(self, tmpdir, config):
        widerface_dir = pathlib.Path(tmpdir) / "widerface"
        annotations_dir = widerface_dir / "wider_face_split"
        os.makedirs(annotations_dir)

        split_to_idx = split_to_num_examples = {
@@ -220,21 +219,21 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
        }

        # We need to create all folders regardless of the split in config
        for split in ("train", "val", "test"):
            split_idx = split_to_idx[split]
            num_examples = split_to_num_examples[split]

            datasets_utils.create_image_folder(
                root=tmpdir,
                name=widerface_dir / f"WIDER_{split}" / "images" / "0--Parade",
                file_name_fn=lambda image_idx: f"0_Parade_marchingband_1_{split_idx + image_idx}.jpg",
                num_examples=num_examples,
            )

            annotation_file_name = {
                "train": annotations_dir / "wider_face_train_bbx_gt.txt",
                "val": annotations_dir / "wider_face_val_bbx_gt.txt",
                "test": annotations_dir / "wider_face_test_filelist.txt",
            }[split]

            annotation_content = {
@@ -267,9 +266,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
        "color",
    )
    ADDITIONAL_CONFIGS = (
        *datasets_utils.combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES),
        *datasets_utils.combinations_grid(
            mode=("coarse",),
            split=("train", "train_extra", "val"),
@@ -324,6 +321,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
        gt_dir = tmpdir / f"gt{mode}"
        for split in mode_to_splits[mode]:
            for city in cities:

                def make_image(name, size=10):
                    datasets_utils.create_image_folder(
                        root=gt_dir / split,
@@ -332,6 +330,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
                        size=size,
                        num_examples=1,
                    )

                make_image(f"{city}_000000_000000_gt{mode}_instanceIds.png")
                make_image(f"{city}_000000_000000_gt{mode}_labelIds.png")
                make_image(f"{city}_000000_000000_gt{mode}_color.png", size=(4, 10, 10))
@@ -341,7 +340,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
            json.dump(polygon_target, outfile)

        # Create leftImg8bit folder
        for split in ["test", "train_extra", "train", "val"]:
            for city in cities:
                datasets_utils.create_image_folder(
                    root=tmpdir / "leftImg8bit" / split,
@@ -350,13 +349,13 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
                    num_examples=1,
                )

        info = {"num_examples": len(cities)}
        if config["target_type"] == "polygon":
            info["expected_polygon_target"] = polygon_target
        return info

    def test_combined_targets(self):
        target_types = ["semantic", "polygon", "color"]

        with self.create_dataset(target_type=target_types) as (dataset, _):
            output = dataset[0]
@@ -370,32 +369,32 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
            assert isinstance(output[1][2], PIL.Image.Image)  # color

    def test_feature_types_target_color(self):
        with self.create_dataset(target_type="color") as (dataset, _):
            color_img, color_target = dataset[0]
            assert isinstance(color_img, PIL.Image.Image)
            assert np.array(color_target).shape[2] == 4

    def test_feature_types_target_polygon(self):
        with self.create_dataset(target_type="polygon") as (dataset, info):
            polygon_img, polygon_target = dataset[0]
            assert isinstance(polygon_img, PIL.Image.Image)
            (polygon_target, info["expected_polygon_target"])


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.ImageNet
    REQUIRED_PACKAGES = ("scipy",)
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))

    def inject_fake_data(self, tmpdir, config):
        tmpdir = pathlib.Path(tmpdir)

        wnid = "n01234567"
        if config["split"] == "train":
            num_examples = 3
            datasets_utils.create_image_folder(
                root=tmpdir,
                name=tmpdir / "train" / wnid / wnid,
                file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG",
                num_examples=num_examples,
            )
@@ -403,13 +402,13 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
            num_examples = 1
            datasets_utils.create_image_folder(
                root=tmpdir,
                name=tmpdir / "val" / wnid,
                file_name_fn=lambda image_ifx: "ILSVRC2012_val_0000000{image_idx}.JPEG",
                num_examples=num_examples,
            )

        wnid_to_classes = {wnid: [1]}
        torch.save((wnid_to_classes, None), tmpdir / "meta.bin")

        return num_examples
@@ -596,7 +595,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
            assert tuple(dataset.attr_names) == info["attr_names"]

    def test_images_names_split(self):
        with self.create_dataset(split="all") as (dataset, _):
            all_imgs_names = set(dataset.filename)

        merged_imgs_names = set()
@@ -888,10 +887,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
        return num_images

    @contextlib.contextmanager
    def create_dataset(self, *args, **kwargs):
        with super().create_dataset(*args, **kwargs) as output:
            yield output
            # Currently datasets.LSUN caches the keys in the current directory rather than in the root directory. Thus,
@@ -951,14 +947,12 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
    DATASET_CLASS = datasets.Kinetics
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"), num_classes=("400", "600", "700"))

    def inject_fake_data(self, tmpdir, config):
        classes = ("Abseiling", "Zumba")
        num_videos_per_class = 2
        tmpdir = pathlib.Path(tmpdir) / config["split"]
        digits = string.ascii_letters + string.digits + "-_"
        for cls in classes:
            datasets_utils.create_video_folder(
@@ -1582,7 +1576,7 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
            # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the
            # DEFAULT_CONFIG.
            with self.create_dataset(
                config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
            ) as (dataset, info):
                assert len(dataset) == info["num_examples"]
@@ -1668,7 +1662,7 @@ class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
            file = f"{split}_32x32.mat"
            images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8)
            targets = np.zeros((num_examples,), dtype=np.uint8)
            sio.savemat(os.path.join(tmpdir, file), {"X": images, "y": targets})
        return num_examples
@@ -1703,8 +1697,7 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    # (file, idx)
    _FILE_LIST_CONTENT = (
        ("Places365_val_00000001.png", 0),
        *((f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT),
    )

    @staticmethod
@@ -1744,8 +1737,8 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase):
        return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)]

    def inject_fake_data(self, tmpdir, config):
        self._make_devkit_archive(tmpdir, config["split"])
        return len(self._make_images_archive(tmpdir, config["split"], config["small"]))

    def test_classes(self):
        classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
@@ -1759,7 +1752,7 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase):
    def test_images_download_preexisting(self):
        with pytest.raises(RuntimeError):
            with self.create_dataset({"download": True}):
                pass
@@ -1805,22 +1798,17 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
    DATASET_CLASS = datasets.LFWPeople
    FEATURE_TYPES = (PIL.Image.Image, int)
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled")
    )
    _IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
    _file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}

    def inject_fake_data(self, tmpdir, config):
        tmpdir = pathlib.Path(tmpdir) / "lfw-py"
        os.makedirs(tmpdir, exist_ok=True)
        return dict(
            num_examples=self._create_images_dir(tmpdir, self._IMAGES_DIR[config["image_set"]], config["split"]),
            split=config["split"],
        )

    def _create_images_dir(self, root, idir, split):
...
import contextlib
import itertools
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest
from torchvision import datasets
from torchvision.datasets.utils import (
    download_url,
...
import contextlib import contextlib
import sys
import os import os
import torch import sys
import pytest
import pytest
import torch
from common_utils import get_list_of_videos, assert_equal
from torchvision import get_video_backend
from torchvision import io from torchvision import io
from torchvision.datasets.samplers import ( from torchvision.datasets.samplers import (
DistributedSampler, DistributedSampler,
...@@ -11,9 +13,6 @@ from torchvision.datasets.samplers import ( ...@@ -11,9 +13,6 @@ from torchvision.datasets.samplers import (
UniformClipSampler, UniformClipSampler,
) )
from torchvision.datasets.video_utils import VideoClips, unfold from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend
from common_utils import get_list_of_videos, assert_equal
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
@@ -24,7 +23,7 @@ class TestDatasetsSamplers:
        sampler = RandomClipSampler(video_clips, 3)
        assert len(sampler) == 3 * 3
        indices = torch.tensor(list(iter(sampler)))
        videos = torch.div(indices, 5, rounding_mode="floor")
        v_idxs, count = torch.unique(videos, return_counts=True)
        assert_equal(v_idxs, torch.tensor([0, 1, 2]))
        assert_equal(count, torch.tensor([3, 3, 3]))
@@ -41,7 +40,7 @@ class TestDatasetsSamplers:
        indices.remove(0)
        indices.remove(1)
        indices = torch.tensor(indices) - 2
        videos = torch.div(indices, 5, rounding_mode="floor")
        v_idxs, count = torch.unique(videos, return_counts=True)
        assert_equal(v_idxs, torch.tensor([0, 1]))
        assert_equal(count, torch.tensor([3, 3]))
@@ -52,7 +51,7 @@ class TestDatasetsSamplers:
        sampler = UniformClipSampler(video_clips, 3)
        assert len(sampler) == 3 * 3
        indices = torch.tensor(list(iter(sampler)))
        videos = torch.div(indices, 5, rounding_mode="floor")
        v_idxs, count = torch.unique(videos, return_counts=True)
        assert_equal(v_idxs, torch.tensor([0, 1, 2]))
        assert_equal(count, torch.tensor([3, 3, 3]))
@@ -92,5 +91,5 @@ class TestDatasetsSamplers:
        assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))


if __name__ == "__main__":
    pytest.main([__file__])
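The floor division by 5 in the sampler tests above maps a flat clip index back to the video it came from; the divisor apparently matches the number of clips each fixture video contributes. A minimal sketch of that bookkeeping with plain tensors (clips_per_video = 5 and the index values are assumptions mirroring the tests, not torchvision constants):

import torch

clips_per_video = 5  # assumed, to match the divisor used in the tests above
indices = torch.tensor([0, 4, 7, 9, 12])  # hypothetical clip indices drawn by a sampler

videos = torch.div(indices, clips_per_video, rounding_mode="floor")
print(videos)  # tensor([0, 0, 1, 1, 2]): the video each sampled clip belongs to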
import bz2
import contextlib
import gzip
import itertools
import lzma
import os
import tarfile
import warnings
import zipfile
from urllib.error import URLError

import pytest

import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS

TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
def patch_url_redirection(mocker, redirect_url):
@@ -60,16 +61,16 @@ class TestDatasetsUtils:
    def test_check_md5(self):
        fpath = TEST_FILE
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        assert utils.check_integrity(existing_fpath)
@@ -87,31 +88,35 @@ class TestDatasetsUtils:
        assert utils._get_google_drive_file_id(url) is None

    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        assert utils._detect_file_type(file) == expected

    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        # tests detect file type for no extension, unknown compression and unknown partial extension
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    def test_decompress(self, extension, tmpdir):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
@@ -152,8 +157,8 @@ class TestDatasetsUtils:
        assert not os.path.exists(compressed)

    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        filename = "foo"
        file = f"{filename}{extension}"
@@ -182,8 +187,9 @@ class TestDatasetsUtils:
        with open(file, "r") as fh:
            assert fh.read() == content

    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
    def test_extract_tar(self, extension, mode, tmpdir):
        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
@@ -213,5 +219,5 @@ class TestDatasetsUtils:
        pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")


if __name__ == "__main__":
    pytest.main([__file__])
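The archive tests above build a tiny tar/zip file on disk and run it through torchvision.datasets.utils.extract_archive, which picks an opener based on the (suffix, archive type, compression) triple that _detect_file_type returns. A self-contained sketch of the same round trip, assuming only the public extract_archive helper exercised above (file names and content are arbitrary):

import os
import tarfile
import tempfile

from torchvision.datasets.utils import extract_archive

with tempfile.TemporaryDirectory() as root:
    # Pack a small text file into a .tar.gz archive.
    src = os.path.join(root, "src.txt")
    with open(src, "w") as fh:
        fh.write("this is the content")
    archive = os.path.join(root, "archive.tar.gz")
    with tarfile.open(archive, mode="w:gz") as tar:
        tar.add(src, arcname="src.txt")

    # extract_archive infers tar + gzip from the ".tar.gz" suffix.
    extracted = os.path.join(root, "out")
    extract_archive(archive, extracted)
    with open(os.path.join(extracted, "src.txt")) as fh:
        assert fh.read() == "this is the content"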
import contextlib
import os

import pytest
import torch

from common_utils import get_list_of_videos, assert_equal
from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold


class TestVideo:
    def test_unfold(self):
        a = torch.arange(7)

        r = unfold(a, 3, 3, 1)
        expected = torch.tensor(
            [
                [0, 1, 2],
                [3, 4, 5],
            ]
        )
        assert_equal(r, expected)

        r = unfold(a, 3, 2, 1)
        expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]])
        assert_equal(r, expected)

        r = unfold(a, 3, 2, 2)
        expected = torch.tensor(
            [
                [0, 2, 4],
                [2, 4, 6],
            ]
        )
        assert_equal(r, expected)
    @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
@@ -79,8 +77,7 @@ class TestVideo:
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 13
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        assert len(clips) == 1
        assert_equal(clips, idxs)
@@ -91,8 +88,7 @@ class TestVideo:
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 12
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        assert len(clips) == 3
        assert_equal(clips, idxs)
@@ -103,11 +99,10 @@ class TestVideo:
        orig_fps = 30
        new_fps = 13
        with pytest.warns(UserWarning):
            clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        assert len(clips) == 0
        assert len(idxs) == 0


if __name__ == "__main__":
    pytest.main([__file__])
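unfold(tensor, size, step, dilation) from torchvision.datasets.video_utils slices a 1-D tensor of frame indices into overlapping windows, which is what the expected tensors in test_unfold spell out. For the dilation=1 cases the same windows can be produced with the built-in Tensor.unfold; a small sketch (torchvision's dilation argument has no direct Tensor.unfold counterpart):

import torch

a = torch.arange(7)

# Windows of length 3 with stride 2, matching unfold(a, 3, 2, 1) in the test above.
print(a.unfold(0, 3, 2))
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])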
import unittest

import test_datasets_video_utils
from torchvision import set_video_backend

# Disabling the video backend switching temporarily
# set_video_backend('video_reader')


if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
    unittest.TextTestRunner(verbosity=1).run(suite)
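Re-enabling the commented-out backend switch would look roughly like the sketch below. It is hedged with a try/except because the "video_reader" backend is only usable when torchvision was compiled with the native video reader, and the exact failure mode of set_video_backend depends on the build:

from torchvision import get_video_backend, set_video_backend

previous = get_video_backend()  # usually "pyav"
try:
    set_video_backend("video_reader")
except Exception as exc:  # backend unavailable in this build
    print(f"keeping {previous!r}: {exc}")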
import colorsys
import itertools
import math
import os
from functools import partial
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pytest
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional_tensor as F_t
from common_utils import (
    cpu_and_gpu,
    needs_cuda,
@@ -24,15 +22,14 @@ from common_utils import (
    _test_fn_on_batch,
    assert_equal,
)
from torchvision.transforms import InterpolationMode


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels])
def test_image_sizes(device, fn):
    script_F = torch.jit.script(fn)
@@ -57,10 +54,10 @@ def test_scale_channel():
    # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
    # only use bincount and remove that test.
    size = (1_000,)
    img_chan = torch.randint(0, 256, size=size).to("cpu")
    scaled_cpu = F_t._scale_channel(img_chan)
    scaled_cuda = F_t._scale_channel(img_chan.to("cuda"))
    assert_equal(scaled_cpu, scaled_cuda.to("cpu"))
class TestRotate:
@@ -69,18 +66,33 @@ class TestRotate:
    scripted_rotate = torch.jit.script(F.rotate)
    IMG_W = 26

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)])
    @pytest.mark.parametrize(
        "center",
        [
            None,
            (int(IMG_W * 0.3), int(IMG_W * 0.4)),
            [int(IMG_W * 0.5), int(IMG_W * 0.6)],
        ],
    )
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("angle", range(-180, 180, 17))
    @pytest.mark.parametrize("expand", [True, False])
    @pytest.mark.parametrize(
        "fill",
        [
            None,
            [0, 0, 0],
            (1, 2, 3),
            [255, 255, 255],
            [
                1,
            ],
            (2.0,),
        ],
    )
    @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
    def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn):
        tensor, pil_img = _create_data(height, width, device=device)
@@ -101,8 +113,8 @@ class TestRotate:
            out_tensor = out_tensor.to(torch.uint8)

        assert out_tensor.shape == out_pil_tensor.shape, (
            f"{(height, width, NEAREST, dt, angle, expand, center)}: " f"{out_tensor.shape} vs {out_pil_tensor.shape}"
        )

        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
@@ -110,10 +122,11 @@ class TestRotate:
        assert ratio_diff_pixels < 0.03, (
            f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: "
            f"{ratio_diff_pixels}\n{out_tensor[0, :7, :7]} vs \n"
            f"{out_pil_tensor[0, :7, :7]}"
        )

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_rotate_batch(self, device, dt):
        if dt == torch.float16 and device == "cpu":
            # skip float16 on CPU case
@@ -124,9 +137,7 @@ class TestRotate:
            batch_tensors = batch_tensors.to(dtype=dt)

        center = (20, 22)
        _test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center)

    def test_rotate_deprecation_resample(self):
        tensor, _ = _create_data(26, 26)
@@ -150,9 +161,9 @@ class TestAffine:
    ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
    scripted_affine = torch.jit.script(F.affine)

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_identity_map(self, device, height, width, dt):
        # Tests on square and rectangular images
        tensor, pil_img = _create_data(height, width, device=device)
@@ -173,19 +184,22 @@ class TestAffine:
        )

        assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(26, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize(
        "angle, config",
        [
            (90, {"k": 1, "dims": (-1, -2)}),
            (45, None),
            (30, None),
            (-30, None),
            (-45, None),
            (-90, {"k": -1, "dims": (-1, -2)}),
            (180, {"k": 2, "dims": (-1, -2)}),
        ],
    )
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_square_rotations(self, device, height, width, dt, angle, config, fn):
        # 2) Test rotation
        tensor, pil_img = _create_data(height, width, device=device)
@@ -202,9 +216,7 @@ class TestAffine:
        )
        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device)

        out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
        if config is not None:
            assert_equal(torch.rot90(tensor, **config), out_tensor)
@@ -218,11 +230,11 @@ class TestAffine:
            ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
        )

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_rect_rotations(self, device, height, width, dt, angle, fn):
        # Tests on rectangular images
        tensor, pil_img = _create_data(height, width, device=device)
@@ -239,9 +251,7 @@ class TestAffine:
        )
        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

        out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu()

        if out_tensor.dtype != torch.uint8:
            out_tensor = out_tensor.to(torch.uint8)
@@ -253,11 +263,11 @@ class TestAffine:
            angle, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
        )

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("t", [[10, 12], (-12, -13)])
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_translations(self, device, height, width, dt, t, fn):
        # 3) Test translation
        tensor, pil_img = _create_data(height, width, device=device)
@@ -278,22 +288,41 @@ class TestAffine:
        _assert_equal_tensor_to_pil(out_tensor, out_pil_img)
    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize(
        "a, t, s, sh, f",
        [
            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
            (
                85,
                (10, -10),
                0.7,
                [0.0, 0.0],
                [
                    1,
                ],
            ),
            (
                0,
                [0, 0],
                1.0,
                [
                    35.0,
                ],
                (2.0,),
            ),
            (-25, [0, 0], 1.2, [0.0, 15.0], None),
            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ],
    )
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn):
        # 4) Test rotation + translation + scale + shear
        tensor, pil_img = _create_data(height, width, device=device)
@@ -322,8 +351,8 @@ class TestAffine:
            (NEAREST, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
        )

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_batches(self, device, dt):
        if dt == torch.float16 and device == "cpu":
            # skip float16 on CPU case
@@ -333,11 +362,9 @@ class TestAffine:
        if dt is not None:
            batch_tensors = batch_tensors.to(dtype=dt)

        _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])

    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_warnings(self, device):
        tensor, pil_img = _create_data(26, 26, device=device)
@@ -379,18 +406,27 @@ def _get_data_dims_and_points_for_perspective():
    n = 10
    for dim in data_dims:
        points += [(dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) for i in range(n)]
    return dims_and_points


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    (
        None,
        [0, 0, 0],
        [1, 2, 3],
        [255, 255, 255],
        [
            1,
        ],
        (2.0,),
    ),
)
@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
    if dt == torch.float16 and device == "cpu":
@@ -405,8 +441,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
    interpolation = NEAREST
    fill_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill
    out_pil_img = F.perspective(
        pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill_pil
    )
    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill).cpu()
@@ -419,9 +456,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
    assert ratio_diff_pixels < 0.05


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt):
    if dt == torch.float16 and device == "cpu":
@@ -438,8 +475,12 @@ def test_perspective_batch(device, dims_and_points, dt):
    # the border may be entirely different due to small rounding errors.
    scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
    _test_fn_on_batch(
        batch_tensors,
        F.perspective,
        scripted_fn_atol=scripted_fn_atol,
        startpoints=spoints,
        endpoints=epoints,
        interpolation=NEAREST,
    )
@@ -454,11 +495,23 @@ def test_perspective_interpolation_warning():
    assert_equal(res1, res2)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "size",
    [
        32,
        26,
        [
            32,
        ],
        [32, 32],
        (32, 32),
        [26, 35],
    ],
)
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST])
def test_resize(device, dt, size, max_size, interpolation):
    if dt == torch.float16 and device == "cpu":
@@ -483,7 +536,9 @@ def test_resize(device, dt, size, max_size, interpolation):
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    if interpolation not in [
        NEAREST,
    ]:
        # We can not check values if mode = NEAREST, as results are different
        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
@@ -496,21 +551,19 @@ def test_resize(device, dt, size, max_size, interpolation):
        _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)

    if isinstance(size, int):
        script_size = [
            size,
        ]
    else:
        script_size = size

    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
    assert_equal(resized_tensor, resize_result)

    _test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_resize_asserts(device):
    tensor, pil_img = _create_data(26, 36, device=device)
@@ -530,10 +583,10 @@ def test_resize_asserts(device):
    F.resize(img, size=32, max_size=32)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_resize_antialias(device, dt, size, interpolation):
    if dt == torch.float16 and device == "cpu":
@@ -558,9 +611,7 @@ def test_resize_antialias(device, dt, size, interpolation):
    if resized_tensor_f.dtype == torch.uint8:
        resized_tensor_f = resized_tensor_f.to(torch.float)

    _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}")

    accepted_tol = 1.0 + 1e-5
    if interpolation == BICUBIC:
@@ -571,12 +622,13 @@ def test_resize_antialias(device, dt, size, interpolation):
        accepted_tol = 15.0

    _assert_approx_equal_tensor_to_pil(
        resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", msg=f"{size}, {interpolation}, {dt}"
    )

    if isinstance(size, int):
        script_size = [
            size,
        ]
    else:
        script_size = size
@@ -585,7 +637,7 @@ def test_resize_antialias(device, dt, size, interpolation):


@needs_cuda
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_assert_resize_antialias(interpolation):
    # Checks implementation on very large scales
@@ -597,10 +649,10 @@ def test_assert_resize_antialias(interpolation):
        F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(device, dt, size, interpolation):
    if dt == torch.float16 and device == "cpu":
@@ -616,7 +668,6 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, i):
            result = forward_op(i, size, False)
@@ -630,14 +681,10 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
            oshape = result.shape[2:]
            return backward_op(grad_output, oshape, ishape, False)

    x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),)
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

    x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),)
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
@@ -678,10 +725,10 @@ def check_functional_vs_PIL_vs_scripted(
    _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_brightness(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_brightness,
@@ -694,26 +741,18 @@ def test_adjust_brightness(device, dtype, config, channels):
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_invert(device, dtype, channels):
    check_functional_vs_PIL_vs_scripted(
        F.invert, F_pil.invert, F_t.invert, {}, device, dtype, channels, tol=1.0, agg_method="max"
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)])
@pytest.mark.parametrize("channels", [1, 3])
def test_posterize(device, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.posterize,
@@ -728,9 +767,9 @@ def test_posterize(device, config, channels):
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
@pytest.mark.parametrize("channels", [1, 3])
def test_solarize1(device, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.solarize,
@@ -745,10 +784,10 @@ def test_solarize1(device, config, channels):
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_solarize2(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.solarize,
@@ -763,10 +802,10 @@ def test_solarize2(device, dtype, config, channels):
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_sharpness(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_sharpness,
@@ -779,25 +818,17 @@ def test_adjust_sharpness(device, dtype, config, channels):
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast(device, dtype, channels):
    check_functional_vs_PIL_vs_scripted(
        F.autocontrast, F_pil.autocontrast, F_t.autocontrast, {}, device, dtype, channels, tol=1.0, agg_method="max"
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("channels", [1, 3])
def test_equalize(device, channels):
    torch.use_deterministic_algorithms(False)
    check_functional_vs_PIL_vs_scripted(
@@ -813,60 +844,40 @@ def test_equalize(device, channels):
    )
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_contrast(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_contrast, F_pil.adjust_contrast, F_t.adjust_contrast, config, device, dtype, channels
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_saturation(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_saturation, F_pil.adjust_saturation, F_t.adjust_saturation, config, device, dtype, channels
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_hue(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_hue, F_pil.adjust_hue, F_t.adjust_hue, config, device, dtype, channels, tol=16.1, agg_method="max"
    )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_gamma(device, dtype, config, channels):
    check_functional_vs_PIL_vs_scripted(
        F.adjust_gamma,
@@ -879,17 +890,31 @@ def test_adjust_gamma(device, dtype, config, channels):
    )
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "pad",
    [
        2,
        [
            3,
        ],
        [0, 3],
        (3, 3),
        [4, 2, 4, 3],
    ],
)
@pytest.mark.parametrize(
    "config",
    [
        {"padding_mode": "constant", "fill": 0},
        {"padding_mode": "constant", "fill": 10},
        {"padding_mode": "constant", "fill": 20},
        {"padding_mode": "edge"},
        {"padding_mode": "reflect"},
        {"padding_mode": "symmetric"},
    ],
)
def test_pad(device, dt, pad, config):
    script_fn = torch.jit.script(F.pad)
    tensor, pil_img = _create_data(7, 8, device=device)
@@ -915,7 +940,9 @@ def test_pad(device, dt, pad, config):
    _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, config))

    if isinstance(pad, int):
        script_pad = [
            pad,
        ]
    else:
        script_pad = pad
    pad_tensor_script = script_fn(tensor, script_pad, **config)
@@ -924,8 +951,8 @@ def test_pad(device, dt, pad, config):
    _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC])
def test_resized_crop(device, mode):
    # test values of F.resized_crop in several cases:
    # 1) resize to the same size, crop to the same size => should be identity
@@ -950,19 +977,46 @@ def test_resized_crop(device, mode):
    )
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "func, args",
    [
        (F_t.get_image_size, ()),
        (F_t.vflip, ()),
        (F_t.hflip, ()),
        (F_t.crop, (1, 2, 4, 5)),
        (F_t.adjust_brightness, (0.0,)),
        (F_t.adjust_contrast, (1.0,)),
        (F_t.adjust_hue, (-0.5,)),
        (F_t.adjust_saturation, (2.0,)),
        (
            F_t.pad,
            (
                [
                    2,
                ],
                2,
                "constant",
            ),
        ),
        (F_t.resize, ([10, 11],)),
        (
            F_t.perspective,
            (
                [
                    0.2,
                ]
            ),
        ),
        (F_t.gaussian_blur, ((2, 2), (0.7, 0.5))),
        (F_t.invert, ()),
        (F_t.posterize, (0,)),
        (F_t.solarize, (0.3,)),
        (F_t.adjust_sharpness, (0.3,)),
        (F_t.autocontrast, ()),
        (F_t.equalize, ()),
    ],
)
def test_assert_image_tensor(device, func, args):
    shape = (100,)
    tensor = torch.rand(*shape, dtype=torch.float, device=device)
@@ -970,7 +1024,7 @@ def test_assert_image_tensor(device, func, args):
        func(tensor, *args)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_vflip(device):
    script_vflip = torch.jit.script(F.vflip)
@@ -987,7 +1041,7 @@ def test_vflip(device):
    _test_fn_on_batch(batch_tensors, F.vflip)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_hflip(device):
    script_hflip = torch.jit.script(F.hflip)
@@ -1004,13 +1058,16 @@ def test_hflip(device):
    _test_fn_on_batch(batch_tensors, F.hflip)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "top, left, height, width",
    [
        (1, 2, 4, 5),  # crop inside top-left corner
        (2, 12, 3, 4),  # crop inside top-right corner
        (8, 3, 5, 6),  # crop inside bottom-left corner
        (8, 11, 4, 3),  # crop inside bottom-right corner
    ],
)
def test_crop(device, top, left, height, width):
    script_crop = torch.jit.script(F.crop)
@@ -1028,12 +1085,12 @@ def test_crop(device, top, left, height, width):
    _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('image_size', ('small', 'large')) @pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('ksize', [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize('sigma', [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
@pytest.mark.parametrize('fn', [F.gaussian_blur, torch.jit.script(F.gaussian_blur)]) @pytest.mark.parametrize("fn", [F.gaussian_blur, torch.jit.script(F.gaussian_blur)])
def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
# true_cv2_results = { # true_cv2_results = {
...@@ -1050,17 +1107,15 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1050,17 +1107,15 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7) # # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ... # "23_23_1.7": ...
# } # }
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt') p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
true_cv2_results = torch.load(p) true_cv2_results = torch.load(p)
if image_size == 'small': if image_size == "small":
tensor = torch.from_numpy( tensor = (
np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
).permute(2, 0, 1).to(device) )
else: else:
tensor = torch.from_numpy( tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)
np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
).to(device)
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case # skip float16 on CPU case
...@@ -1072,22 +1127,19 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1072,22 +1127,19 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
_ksize = (ksize, ksize) if isinstance(ksize, int) else ksize _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
_sigma = sigma[0] if sigma is not None else None _sigma = sigma[0] if sigma is not None else None
shape = tensor.shape shape = tensor.shape
gt_key = "{}_{}_{}__{}_{}_{}".format( gt_key = "{}_{}_{}__{}_{}_{}".format(shape[-2], shape[-1], shape[-3], _ksize[0], _ksize[1], _sigma)
shape[-2], shape[-1], shape[-3],
_ksize[0], _ksize[1], _sigma
)
if gt_key not in true_cv2_results: if gt_key not in true_cv2_results:
return return
true_out = torch.tensor( true_out = (
true_cv2_results[gt_key] torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) )
out = fn(tensor, kernel_size=ksize, sigma=sigma) out = fn(tensor, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma)) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma))
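The reference file loaded above maps keys of the form "H_W_C__kH_kW_sigma" to flattened OpenCV outputs, as hinted by the commented-out cv2 calls at the top of the test. A minimal sketch of how one such entry could be regenerated, assuming cv2 is installed (the helper name is illustrative):

import cv2
import numpy as np
import torch

def make_reference_entry(np_img, ksize, sigma):
    # np_img is an HWC uint8 array; the key layout mirrors gt_key above.
    h, w, c = np_img.shape
    key = "{}_{}_{}__{}_{}_{}".format(h, w, c, ksize[0], ksize[1], sigma)
    blurred = cv2.GaussianBlur(np_img, ksize=ksize, sigmaX=sigma)
    return key, torch.from_numpy(blurred).reshape(-1)

np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
key, value = make_reference_entry(np_img, ksize=(23, 23), sigma=1.7)
# collecting such pairs into a dict and torch.save()-ing it would produce the .pt asset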
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_hsv2rgb(device): def test_hsv2rgb(device):
scripted_fn = torch.jit.script(F_t._hsv2rgb) scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150) shape = (3, 100, 150)
...@@ -1096,7 +1148,11 @@ def test_hsv2rgb(device): ...@@ -1096,7 +1148,11 @@ def test_hsv2rgb(device):
rgb_img = F_t._hsv2rgb(hsv_img) rgb_img = F_t._hsv2rgb(hsv_img)
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1) ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)
h, s, v, = hsv_img.unbind(0) (
h,
s,
v,
) = hsv_img.unbind(0)
h = h.flatten().cpu().numpy() h = h.flatten().cpu().numpy()
s = s.flatten().cpu().numpy() s = s.flatten().cpu().numpy()
v = v.flatten().cpu().numpy() v = v.flatten().cpu().numpy()
...@@ -1114,7 +1170,7 @@ def test_hsv2rgb(device): ...@@ -1114,7 +1170,7 @@ def test_hsv2rgb(device):
_test_fn_on_batch(batch_tensors, F_t._hsv2rgb) _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_rgb2hsv(device): def test_rgb2hsv(device):
scripted_fn = torch.jit.script(F_t._rgb2hsv) scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100) shape = (3, 150, 100)
...@@ -1123,7 +1179,11 @@ def test_rgb2hsv(device): ...@@ -1123,7 +1179,11 @@ def test_rgb2hsv(device):
hsv_img = F_t._rgb2hsv(rgb_img) hsv_img = F_t._rgb2hsv(rgb_img)
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1) ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
r, g, b, = rgb_img.unbind(dim=-3) (
r,
g,
b,
) = rgb_img.unbind(dim=-3)
r = r.flatten().cpu().numpy() r = r.flatten().cpu().numpy()
g = g.flatten().cpu().numpy() g = g.flatten().cpu().numpy()
b = b.flatten().cpu().numpy() b = b.flatten().cpu().numpy()
...@@ -1149,8 +1209,8 @@ def test_rgb2hsv(device): ...@@ -1149,8 +1209,8 @@ def test_rgb2hsv(device):
_test_fn_on_batch(batch_tensors, F_t._rgb2hsv) _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('num_output_channels', (3, 1)) @pytest.mark.parametrize("num_output_channels", (3, 1))
def test_rgb_to_grayscale(device, num_output_channels): def test_rgb_to_grayscale(device, num_output_channels):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
...@@ -1168,7 +1228,7 @@ def test_rgb_to_grayscale(device, num_output_channels): ...@@ -1168,7 +1228,7 @@ def test_rgb_to_grayscale(device, num_output_channels):
_test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_center_crop(device): def test_center_crop(device):
script_center_crop = torch.jit.script(F.center_crop) script_center_crop = torch.jit.script(F.center_crop)
...@@ -1186,7 +1246,7 @@ def test_center_crop(device): ...@@ -1186,7 +1246,7 @@ def test_center_crop(device):
_test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11]) _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_five_crop(device): def test_five_crop(device):
script_five_crop = torch.jit.script(F.five_crop) script_five_crop = torch.jit.script(F.five_crop)
...@@ -1220,7 +1280,7 @@ def test_five_crop(device): ...@@ -1220,7 +1280,7 @@ def test_five_crop(device):
assert_equal(transformed_batch, s_transformed_batch) assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_ten_crop(device): def test_ten_crop(device):
script_ten_crop = torch.jit.script(F.ten_crop) script_ten_crop = torch.jit.script(F.ten_crop)
...@@ -1254,5 +1314,5 @@ def test_ten_crop(device): ...@@ -1254,5 +1314,5 @@ def test_ten_crop(device):
assert_equal(transformed_batch, s_transformed_batch) assert_equal(transformed_batch, s_transformed_batch)
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import torch.hub as hub
import tempfile
import shutil
import os import os
import shutil
import sys import sys
import tempfile
import pytest import pytest
import torch.hub as hub
def sum_of_model_parameters(model): def sum_of_model_parameters(model):
...@@ -16,8 +17,7 @@ def sum_of_model_parameters(model): ...@@ -16,8 +17,7 @@ def sum_of_model_parameters(model):
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625 SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625
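SUM_OF_PRETRAINED_RESNET18_PARAMS is the reference value that sum_of_model_parameters (whose body is elided above) is checked against for a pretrained resnet18. A plausible sketch of such a helper, assuming it simply accumulates every parameter tensor's sum:

import torch

def sum_of_model_parameters(model):
    # Scalar sum over all parameters; callers below apply .item() before comparing.
    s = torch.zeros(1)
    for p in model.parameters():
        s += p.sum()
    return s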
@pytest.mark.skipif('torchvision' in sys.modules, @pytest.mark.skipif("torchvision" in sys.modules, reason="TestHub must start without torchvision imported")
reason='TestHub must start without torchvision imported')
class TestHub: class TestHub:
# Only run this check ONCE before all tests start. # Only run this check ONCE before all tests start.
# - If torchvision is imported before all tests start, e.g. we might find _C.so # - If torchvision is imported before all tests start, e.g. we might find _C.so
...@@ -26,28 +26,20 @@ class TestHub: ...@@ -26,28 +26,20 @@ class TestHub:
# Python cache as we run all hub tests in the same python process. # Python cache as we run all hub tests in the same python process.
def test_load_from_github(self): def test_load_from_github(self):
hub_model = hub.load( hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
'pytorch/vision',
'resnet18',
pretrained=True,
progress=False)
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
def test_set_dir(self): def test_set_dir(self):
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
hub.set_dir(temp_dir) hub.set_dir(temp_dir)
hub_model = hub.load( hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
'pytorch/vision',
'resnet18',
pretrained=True,
progress=False)
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
assert os.path.exists(temp_dir + '/pytorch_vision_master') assert os.path.exists(temp_dir + "/pytorch_vision_master")
shutil.rmtree(temp_dir + '/pytorch_vision_master') shutil.rmtree(temp_dir + "/pytorch_vision_master")
def test_list_entrypoints(self): def test_list_entrypoints(self):
entry_lists = hub.list('pytorch/vision', force_reload=True) entry_lists = hub.list("pytorch/vision", force_reload=True)
assert 'resnet18' in entry_lists assert "resnet18" in entry_lists
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,25 +4,34 @@ import os ...@@ -4,25 +4,34 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
import pytest
import numpy as np import numpy as np
import pytest
import torch import torch
from PIL import Image, __version__ as PILLOW_VERSION
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from common_utils import needs_cuda, assert_equal from common_utils import needs_cuda, assert_equal
from PIL import Image, __version__ as PILLOW_VERSION
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png,
encode_png, write_png, write_file, ImageReadMode, read_image) decode_jpeg,
encode_jpeg,
write_jpeg,
decode_image,
read_file,
encode_png,
write_png,
write_file,
ImageReadMode,
read_image,
)
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png") INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
IS_WINDOWS = sys.platform in ('win32', 'cygwin') IS_WINDOWS = sys.platform in ("win32", "cygwin")
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
def _get_safe_image_name(name): def _get_safe_image_name(name):
...@@ -35,9 +44,9 @@ def _get_safe_image_name(name): ...@@ -35,9 +44,9 @@ def _get_safe_image_name(name):
def get_images(directory, img_ext): def get_images(directory, img_ext):
assert os.path.isdir(directory) assert os.path.isdir(directory)
image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True) image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
for path in image_paths: for path in image_paths:
if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']: if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
yield path yield path
...@@ -54,15 +63,18 @@ def normalize_dimensions(img_pil): ...@@ -54,15 +63,18 @@ def normalize_dimensions(img_pil):
return img_pil return img_pil
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) "img_path",
for jpeg_path in get_images(IMAGE_ROOT, ".jpg") [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
]) )
@pytest.mark.parametrize('pil_mode, mode', [ @pytest.mark.parametrize(
(None, ImageReadMode.UNCHANGED), "pil_mode, mode",
("L", ImageReadMode.GRAY), [
("RGB", ImageReadMode.RGB), (None, ImageReadMode.UNCHANGED),
]) ("L", ImageReadMode.GRAY),
("RGB", ImageReadMode.RGB),
],
)
def test_decode_jpeg(img_path, pil_mode, mode): def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img: with Image.open(img_path) as img:
...@@ -100,18 +112,21 @@ def test_decode_jpeg_errors(): ...@@ -100,18 +112,21 @@ def test_decode_jpeg_errors():
def test_decode_bad_huffman_images(): def test_decode_bad_huffman_images():
# sanity check: make sure we can decode the bad Huffman encoding # sanity check: make sure we can decode the bad Huffman encoding
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
decode_jpeg(bad_huff) decode_jpeg(bad_huff)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) "img_path",
for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) [
]) pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
],
)
def test_damaged_corrupt_images(img_path): def test_damaged_corrupt_images(img_path):
# Truncated images should raise an exception # Truncated images should raise an exception
data = read_file(img_path) data = read_file(img_path)
if 'corrupt34' in img_path: if "corrupt34" in img_path:
match_message = "Image is incomplete or truncated" match_message = "Image is incomplete or truncated"
else: else:
match_message = "Unsupported marker type" match_message = "Unsupported marker type"
...@@ -119,17 +134,20 @@ def test_damaged_corrupt_images(img_path): ...@@ -119,17 +134,20 @@ def test_damaged_corrupt_images(img_path):
decode_jpeg(data) decode_jpeg(data)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(png_path, id=_get_safe_image_name(png_path)) "img_path",
for png_path in get_images(FAKEDATA_DIR, ".png") [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
]) )
@pytest.mark.parametrize('pil_mode, mode', [ @pytest.mark.parametrize(
(None, ImageReadMode.UNCHANGED), "pil_mode, mode",
("L", ImageReadMode.GRAY), [
("LA", ImageReadMode.GRAY_ALPHA), (None, ImageReadMode.UNCHANGED),
("RGB", ImageReadMode.RGB), ("L", ImageReadMode.GRAY),
("RGBA", ImageReadMode.RGB_ALPHA), ("LA", ImageReadMode.GRAY_ALPHA),
]) ("RGB", ImageReadMode.RGB),
("RGBA", ImageReadMode.RGB_ALPHA),
],
)
def test_decode_png(img_path, pil_mode, mode): def test_decode_png(img_path, pil_mode, mode):
with Image.open(img_path) as img: with Image.open(img_path) as img:
...@@ -160,10 +178,10 @@ def test_decode_png_errors(): ...@@ -160,10 +178,10 @@ def test_decode_png_errors():
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(png_path, id=_get_safe_image_name(png_path)) "img_path",
for png_path in get_images(IMAGE_DIR, ".png") [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
]) )
def test_encode_png(img_path): def test_encode_png(img_path):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
...@@ -182,28 +200,26 @@ def test_encode_png_errors(): ...@@ -182,28 +200,26 @@ def test_encode_png_errors():
encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) encode_png(torch.empty((3, 100, 100), dtype=torch.float32))
with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)
compression_level=-1)
with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)
compression_level=10)
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(png_path, id=_get_safe_image_name(png_path)) "img_path",
for png_path in get_images(IMAGE_DIR, ".png") [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
]) )
def test_write_png(img_path, tmpdir): def test_write_png(img_path, tmpdir):
pil_image = Image.open(img_path) pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image)) img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(tmpdir, '{0}_torch.png'.format(filename)) torch_png = os.path.join(tmpdir, "{0}_torch.png".format(filename))
write_png(img_pil, torch_png, compression_level=6) write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1) saved_image = saved_image.permute(2, 0, 1)
...@@ -212,9 +228,9 @@ def test_write_png(img_path, tmpdir): ...@@ -212,9 +228,9 @@ def test_write_png(img_path, tmpdir):
def test_read_file(tmpdir): def test_read_file(tmpdir):
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
with open(fpath, 'wb') as f: with open(fpath, "wb") as f:
f.write(content) f.write(content)
data = read_file(fpath) data = read_file(fpath)
...@@ -223,13 +239,13 @@ def test_read_file(tmpdir): ...@@ -223,13 +239,13 @@ def test_read_file(tmpdir):
assert_equal(data, expected) assert_equal(data, expected)
with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
read_file('tst') read_file("tst")
def test_read_file_non_ascii(tmpdir): def test_read_file_non_ascii(tmpdir):
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
with open(fpath, 'wb') as f: with open(fpath, "wb") as f:
f.write(content) f.write(content)
data = read_file(fpath) data = read_file(fpath)
...@@ -239,37 +255,40 @@ def test_read_file_non_ascii(tmpdir): ...@@ -239,37 +255,40 @@ def test_read_file_non_ascii(tmpdir):
def test_write_file(tmpdir): def test_write_file(tmpdir):
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8) content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor) write_file(fpath, content_tensor)
with open(fpath, 'rb') as f: with open(fpath, "rb") as f:
saved_content = f.read() saved_content = f.read()
os.unlink(fpath) os.unlink(fpath)
assert content == saved_content assert content == saved_content
def test_write_file_non_ascii(tmpdir): def test_write_file_non_ascii(tmpdir):
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname) fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8) content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor) write_file(fpath, content_tensor)
with open(fpath, 'rb') as f: with open(fpath, "rb") as f:
saved_content = f.read() saved_content = f.read()
os.unlink(fpath) os.unlink(fpath)
assert content == saved_content assert content == saved_content
@pytest.mark.parametrize('shape', [ @pytest.mark.parametrize(
(27, 27), "shape",
(60, 60), [
(105, 105), (27, 27),
]) (60, 60),
(105, 105),
],
)
def test_read_1_bit_png(shape, tmpdir): def test_read_1_bit_png(shape, tmpdir):
np_rng = np.random.RandomState(0) np_rng = np.random.RandomState(0)
image_path = os.path.join(tmpdir, f'test_{shape}.png') image_path = os.path.join(tmpdir, f"test_{shape}.png")
pixels = np_rng.rand(*shape) > 0.5 pixels = np_rng.rand(*shape) > 0.5
img = Image.fromarray(pixels) img = Image.fromarray(pixels)
img.save(image_path) img.save(image_path)
...@@ -278,18 +297,24 @@ def test_read_1_bit_png(shape, tmpdir): ...@@ -278,18 +297,24 @@ def test_read_1_bit_png(shape, tmpdir):
assert_equal(img1, img2) assert_equal(img1, img2)
@pytest.mark.parametrize('shape', [ @pytest.mark.parametrize(
(27, 27), "shape",
(60, 60), [
(105, 105), (27, 27),
]) (60, 60),
@pytest.mark.parametrize('mode', [ (105, 105),
ImageReadMode.UNCHANGED, ],
ImageReadMode.GRAY, )
]) @pytest.mark.parametrize(
"mode",
[
ImageReadMode.UNCHANGED,
ImageReadMode.GRAY,
],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir): def test_read_1_bit_png_consistency(shape, mode, tmpdir):
np_rng = np.random.RandomState(0) np_rng = np.random.RandomState(0)
image_path = os.path.join(tmpdir, f'test_{shape}.png') image_path = os.path.join(tmpdir, f"test_{shape}.png")
pixels = np_rng.rand(*shape) > 0.5 pixels = np_rng.rand(*shape) > 0.5
img = Image.fromarray(pixels) img = Image.fromarray(pixels)
img.save(image_path) img.save(image_path)
...@@ -308,30 +333,30 @@ def test_read_interlaced_png(): ...@@ -308,30 +333,30 @@ def test_read_interlaced_png():
@needs_cuda @needs_cuda
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) "img_path",
for jpeg_path in get_images(IMAGE_ROOT, ".jpg") [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
]) )
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('scripted', (False, True)) @pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted): def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path: if "cmyk" in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported") pytest.xfail("Decoding a CMYK jpeg isn't supported")
data = read_file(img_path) data = read_file(img_path)
img = decode_image(data, mode=mode) img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda') img_nvjpeg = f(data, mode=mode, device="cuda")
# Some difference expected between jpeg implementations # Some difference expected between jpeg implementations
assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
@needs_cuda @needs_cuda
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) @pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device): def test_decode_jpeg_cuda_device_param(cuda_device):
"""Make sure we can pass a string or a torch.device as device param""" """Make sure we can pass a string or a torch.device as device param"""
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if 'cmyk' not in path) path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
data = read_file(path) data = read_file(path)
decode_jpeg(data, device=cuda_device) decode_jpeg(data, device=cuda_device)
...@@ -340,13 +365,13 @@ def test_decode_jpeg_cuda_device_param(cuda_device): ...@@ -340,13 +365,13 @@ def test_decode_jpeg_cuda_device_param(cuda_device):
def test_decode_jpeg_cuda_errors(): def test_decode_jpeg_cuda_errors():
data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(data.reshape(-1, 1), device='cuda') decode_jpeg(data.reshape(-1, 1), device="cuda")
with pytest.raises(RuntimeError, match="input tensor must be on CPU"): with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
decode_jpeg(data.to('cuda'), device='cuda') decode_jpeg(data.to("cuda"), device="cuda")
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(data.to(torch.float), device='cuda') decode_jpeg(data.to(torch.float), device="cuda")
with pytest.raises(RuntimeError, match="Expected a cuda device"): with pytest.raises(RuntimeError, match="Expected a cuda device"):
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
def test_encode_jpeg_errors(): def test_encode_jpeg_errors():
...@@ -354,12 +379,10 @@ def test_encode_jpeg_errors(): ...@@ -354,12 +379,10 @@ def test_encode_jpeg_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
with pytest.raises(ValueError, match="Image quality should be a positive number " with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with pytest.raises(ValueError, match="Image quality should be a positive number " with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"):
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
...@@ -380,14 +403,15 @@ def _collect_if(cond): ...@@ -380,14 +403,15 @@ def _collect_if(cond):
return test_func return test_func
else: else:
return pytest.mark.dont_collect(test_func) return pytest.mark.dont_collect(test_func)
return _inner return _inner
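Only the tail of _collect_if is visible in this hunk. The helper is a decorator factory: when cond holds, the test function is returned untouched; otherwise it is tagged with the custom dont_collect marker. A plausible sketch of the full pattern (everything above the visible return statements is an assumption):

def _collect_if(cond):
    # Decide at decoration time whether the decorated test should be collected.
    def _inner(test_func):
        if cond:
            return test_func
        else:
            return pytest.mark.dont_collect(test_func)

    return _inner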
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) "img_path",
for jpeg_path in get_images(ENCODE_JPEG, ".jpg") [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
]) )
def test_encode_jpeg_reference(img_path): def test_encode_jpeg_reference(img_path):
# This test is *wrong*. # This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
...@@ -401,12 +425,11 @@ def test_encode_jpeg_reference(img_path): ...@@ -401,12 +425,11 @@ def test_encode_jpeg_reference(img_path):
# FIXME: make the correct tests pass on windows and remove this. # FIXME: make the correct tests pass on windows and remove this.
dirname = os.path.dirname(img_path) dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write') write_folder = os.path.join(dirname, "jpeg_write")
expected_file = os.path.join( expected_file = os.path.join(write_folder, "{0}_pil.jpg".format(filename))
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path)) img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f: with open(expected_file, "rb") as f:
pil_bytes = f.read() pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]: for src_img in [img, img.contiguous()]:
...@@ -416,10 +439,10 @@ def test_encode_jpeg_reference(img_path): ...@@ -416,10 +439,10 @@ def test_encode_jpeg_reference(img_path):
@_collect_if(cond=IS_WINDOWS) @_collect_if(cond=IS_WINDOWS)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize(
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) "img_path",
for jpeg_path in get_images(ENCODE_JPEG, ".jpg") [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
]) )
def test_write_jpeg_reference(img_path, tmpdir): def test_write_jpeg_reference(img_path, tmpdir):
# FIXME: Remove this eventually, see test_encode_jpeg_reference # FIXME: Remove this eventually, see test_encode_jpeg_reference
data = read_file(img_path) data = read_file(img_path)
...@@ -427,35 +450,31 @@ def test_write_jpeg_reference(img_path, tmpdir): ...@@ -427,35 +450,31 @@ def test_write_jpeg_reference(img_path, tmpdir):
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join( torch_jpeg = os.path.join(tmpdir, "{0}_torch.jpg".format(filename))
tmpdir, '{0}_torch.jpg'.format(filename)) pil_jpeg = os.path.join(basedir, "jpeg_write", "{0}_pil.jpg".format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
write_jpeg(img, torch_jpeg, quality=75) write_jpeg(img, torch_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f: with open(torch_jpeg, "rb") as f:
torch_bytes = f.read() torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f: with open(pil_jpeg, "rb") as f:
pil_bytes = f.read() pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes) assert_equal(torch_bytes, pil_bytes)
@pytest.mark.skipif(IS_WINDOWS, reason=( @pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows"))
'this test fails on windows because PIL uses libjpeg-turbo on windows' @pytest.mark.parametrize(
)) "img_path",
@pytest.mark.parametrize('img_path', [ [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) )
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg(img_path): def test_encode_jpeg(img_path):
img = read_image(img_path) img = read_image(img_path)
pil_img = F.to_pil_image(img) pil_img = F.to_pil_image(img)
buf = io.BytesIO() buf = io.BytesIO()
pil_img.save(buf, format='JPEG', quality=75) pil_img.save(buf, format="JPEG", quality=75)
# pytorch can't read from raw bytes so we go through numpy # pytorch can't read from raw bytes so we go through numpy
pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
...@@ -466,28 +485,26 @@ def test_encode_jpeg(img_path): ...@@ -466,28 +485,26 @@ def test_encode_jpeg(img_path):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
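The comment above ("pytorch can't read from raw bytes so we go through numpy") describes a small but easy-to-miss workaround: the encoded bytes are viewed as a uint8 numpy array first and only then turned into a tensor. A minimal, self-contained sketch of that round trip (the byte values are arbitrary):

import io
import numpy as np
import torch

buf = io.BytesIO(b"\xff\xd8\xff\xe0")        # arbitrary raw bytes, here a JPEG header prefix
as_np = np.frombuffer(buf.getvalue(), dtype=np.uint8)
as_tensor = torch.as_tensor(as_np.copy())    # copy: frombuffer returns a read-only view
assert as_tensor.dtype == torch.uint8 and as_tensor.numel() == 4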
@pytest.mark.skipif(IS_WINDOWS, reason=( @pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows"))
'this test fails on windows because PIL uses libjpeg-turbo on windows' @pytest.mark.parametrize(
)) "img_path",
@pytest.mark.parametrize('img_path', [ [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) )
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path, tmpdir): def test_write_jpeg(img_path, tmpdir):
tmpdir = Path(tmpdir) tmpdir = Path(tmpdir)
img = read_image(img_path) img = read_image(img_path)
pil_img = F.to_pil_image(img) pil_img = F.to_pil_image(img)
torch_jpeg = str(tmpdir / 'torch.jpg') torch_jpeg = str(tmpdir / "torch.jpg")
pil_jpeg = str(tmpdir / 'pil.jpg') pil_jpeg = str(tmpdir / "pil.jpg")
write_jpeg(img, torch_jpeg, quality=75) write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75)
with open(torch_jpeg, 'rb') as f: with open(torch_jpeg, "rb") as f:
torch_bytes = f.read() torch_bytes = f.read()
with open(pil_jpeg, 'rb') as f: with open(pil_jpeg, "rb") as f:
pil_bytes = f.read() pil_bytes = f.read()
assert_equal(torch_bytes, pil_bytes) assert_equal(torch_bytes, pil_bytes)
......
...@@ -6,10 +6,10 @@ cleanly ignored in FB internal test infra. ...@@ -6,10 +6,10 @@ cleanly ignored in FB internal test infra.
""" """
import os import os
import pytest
import warnings import warnings
from urllib.error import URLError from urllib.error import URLError
import pytest
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
...@@ -42,11 +42,11 @@ class TestDatasetUtils: ...@@ -42,11 +42,11 @@ class TestDatasetUtils:
filename = "filename" filename = "filename"
md5 = "md5" md5 = "md5"
mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive') mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
utils.download_url(url, tmpdir, filename, md5) utils.download_url(url, tmpdir, filename, md5)
mocked.assert_called_once_with(id, tmpdir, filename, md5) mocked.assert_called_once_with(id, tmpdir, filename, md5)
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import pytest
import os
import contextlib import contextlib
import os
import sys import sys
import tempfile import tempfile
import torch
import torchvision.io as io
from torchvision import get_video_backend
import warnings import warnings
from urllib.error import URLError from urllib.error import URLError
import pytest
import torch
import torchvision.io as io
from common_utils import assert_equal from common_utils import assert_equal
from torchvision import get_video_backend
try: try:
import av import av
# Do a version test too # Do a version test too
io.video._check_av_available() io.video._check_av_available()
except ImportError: except ImportError:
...@@ -42,29 +43,30 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, ...@@ -42,29 +43,30 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
raise ValueError("video_codec can't be specified together with lossless") raise ValueError("video_codec can't be specified together with lossless")
if options is not None: if options is not None:
raise ValueError("options can't be specified together with lossless") raise ValueError("options can't be specified together with lossless")
video_codec = 'libx264rgb' video_codec = "libx264rgb"
options = {'crf': '0'} options = {"crf": "0"}
if video_codec is None: if video_codec is None:
if get_video_backend() == "pyav": if get_video_backend() == "pyav":
video_codec = 'libx264' video_codec = "libx264"
else: else:
# when video_codec is not set, we assume it is libx264rgb which accepts # when video_codec is not set, we assume it is libx264rgb which accepts
# RGB pixel formats as input instead of YUV # RGB pixel formats as input instead of YUV
video_codec = 'libx264rgb' video_codec = "libx264rgb"
if options is None: if options is None:
options = {} options = {}
data = _create_video_frames(num_frames, height, width) data = _create_video_frames(num_frames, height, width)
with tempfile.NamedTemporaryFile(suffix='.mp4') as f: with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
f.close() f.close()
io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
yield f.name, data yield f.name, data
os.unlink(f.name) os.unlink(f.name)
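temp_video above is a context manager: it synthesizes frames, writes them to a temporary .mp4 with the requested codec and options, yields the file name together with the original frames, and deletes the file on exit. A minimal usage example, mirroring the round-trip checks in the tests below:

with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
    # lossless libx264rgb encoding, so decoded frames should match the input exactly
    lv, _, _ = io.read_video(f_name)
    assert_equal(data, lv)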
@pytest.mark.skipif(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, @pytest.mark.skipif(
reason="video_reader backend not available") get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, reason="video_reader backend not available"
)
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
class TestVideo: class TestVideo:
# compression adds artifacts, thus we add a tolerance of # compression adds artifacts, thus we add a tolerance of
...@@ -107,14 +109,14 @@ class TestVideo: ...@@ -107,14 +109,14 @@ class TestVideo:
assert pts == expected_pts assert pts == expected_pts
@pytest.mark.parametrize('start', range(5)) @pytest.mark.parametrize("start", range(5))
@pytest.mark.parametrize('offset', range(1, 4)) @pytest.mark.parametrize("offset", range(1, 4))
def test_read_partial_video(self, start, offset): def test_read_partial_video(self, start, offset):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name) pts, _ = io.read_video_timestamps(f_name)
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)] s_data = data[start : (start + offset)]
assert len(lv) == offset assert len(lv) == offset
assert_equal(s_data, lv) assert_equal(s_data, lv)
...@@ -125,22 +127,22 @@ class TestVideo: ...@@ -125,22 +127,22 @@ class TestVideo:
assert len(lv) == 4 assert len(lv) == 4
assert_equal(data[4:8], lv) assert_equal(data[4:8], lv)
@pytest.mark.parametrize('start', range(0, 80, 20)) @pytest.mark.parametrize("start", range(0, 80, 20))
@pytest.mark.parametrize('offset', range(1, 4)) @pytest.mark.parametrize("offset", range(1, 4))
def test_read_partial_video_bframes(self, start, offset): def test_read_partial_video_bframes(self, start, offset):
# do not use lossless encoding, to test the presence of B-frames # do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'} options = {"bframes": "16", "keyint": "10", "min-keyint": "4"}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data): with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name) pts, _ = io.read_video_timestamps(f_name)
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)] s_data = data[start : (start + offset)]
assert len(lv) == offset assert len(lv) == offset
assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE) assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
# TODO fix this # TODO fix this
if get_video_backend() == 'pyav': if get_video_backend() == "pyav":
assert len(lv) == 4 assert len(lv) == 4
assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE) assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
else: else:
...@@ -156,7 +158,7 @@ class TestVideo: ...@@ -156,7 +158,7 @@ class TestVideo:
assert fps == 30 assert fps == 30
def test_read_timestamps_from_packet(self): def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data): with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data):
pts, _ = io.read_video_timestamps(f_name) pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the # note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available, # timestamps. For the format that we use here, this information is available,
...@@ -164,7 +166,7 @@ class TestVideo: ...@@ -164,7 +166,7 @@ class TestVideo:
with av.open(f_name) as container: with av.open(f_name) as container:
stream = container.streams[0] stream = container.streams[0]
# make sure we went through the optimized codepath # make sure we went through the optimized codepath
assert b'Lavc' in stream.codec_context.extradata assert b"Lavc" in stream.codec_context.extradata
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)] expected_pts = [i * pts_step for i in range(num_frames)]
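As a concrete illustration of the arithmetic above (the numbers are assumed, not taken from the test): with stream.average_rate = 5 fps and stream.time_base = 1/1280, pts_step = round(1 / (5 * 1/1280)) = 256; if stream.duration = 2560 ticks, then num_frames = round(5 * (1/1280) * 2560) = 10 and expected_pts becomes [0, 256, 512, ..., 2304].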
...@@ -173,7 +175,7 @@ class TestVideo: ...@@ -173,7 +175,7 @@ class TestVideo:
def test_read_video_pts_unit_sec(self): def test_read_video_pts_unit_sec(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name, pts_unit='sec') lv, _, info = io.read_video(f_name, pts_unit="sec")
assert_equal(data, lv) assert_equal(data, lv)
assert info["video_fps"] == 5 assert info["video_fps"] == 5
...@@ -181,7 +183,7 @@ class TestVideo: ...@@ -181,7 +183,7 @@ class TestVideo:
def test_read_timestamps_pts_unit_sec(self): def test_read_timestamps_pts_unit_sec(self):
with temp_video(10, 300, 300, 5) as (f_name, data): with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') pts, _ = io.read_video_timestamps(f_name, pts_unit="sec")
with av.open(f_name) as container: with av.open(f_name) as container:
stream = container.streams[0] stream = container.streams[0]
...@@ -191,22 +193,22 @@ class TestVideo: ...@@ -191,22 +193,22 @@ class TestVideo:
assert pts == expected_pts assert pts == expected_pts
@pytest.mark.parametrize('start', range(5)) @pytest.mark.parametrize("start", range(5))
@pytest.mark.parametrize('offset', range(1, 4)) @pytest.mark.parametrize("offset", range(1, 4))
def test_read_partial_video_pts_unit_sec(self, start, offset): def test_read_partial_video_pts_unit_sec(self, start, offset):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') pts, _ = io.read_video_timestamps(f_name, pts_unit="sec")
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec') lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec")
s_data = data[start:(start + offset)] s_data = data[start : (start + offset)]
assert len(lv) == offset assert len(lv) == offset
assert_equal(s_data, lv) assert_equal(s_data, lv)
with av.open(f_name) as container: with av.open(f_name) as container:
stream = container.streams[0] stream = container.streams[0]
lv, _, _ = io.read_video(f_name, lv, _, _ = io.read_video(
int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec"
pts_unit='sec') )
if get_video_backend() == "pyav": if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame # for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts # when the given start pts is not matching any frame pts
...@@ -214,8 +216,8 @@ class TestVideo: ...@@ -214,8 +216,8 @@ class TestVideo:
assert_equal(data[4:8], lv) assert_equal(data[4:8], lv)
def test_read_video_corrupted_file(self): def test_read_video_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f: with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
f.write(b'This is not an mpg4 file') f.write(b"This is not an mpg4 file")
video, audio, info = io.read_video(f.name) video, audio, info = io.read_video(f.name)
assert isinstance(video, torch.Tensor) assert isinstance(video, torch.Tensor)
assert isinstance(audio, torch.Tensor) assert isinstance(audio, torch.Tensor)
...@@ -224,8 +226,8 @@ class TestVideo: ...@@ -224,8 +226,8 @@ class TestVideo:
assert info == {} assert info == {}
def test_read_video_timestamps_corrupted_file(self): def test_read_video_timestamps_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f: with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
f.write(b'This is not an mpg4 file') f.write(b"This is not an mpg4 file")
video_pts, video_fps = io.read_video_timestamps(f.name) video_pts, video_fps = io.read_video_timestamps(f.name)
assert video_pts == [] assert video_pts == []
assert video_fps is None assert video_fps is None
...@@ -233,18 +235,18 @@ class TestVideo: ...@@ -233,18 +235,18 @@ class TestVideo:
@pytest.mark.skip(reason="Temporarily disabled due to new pyav") @pytest.mark.skip(reason="Temporarily disabled due to new pyav")
def test_read_video_partially_corrupted_file(self): def test_read_video_partially_corrupted_file(self):
with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
with open(f_name, 'r+b') as f: with open(f_name, "r+b") as f:
size = os.path.getsize(f_name) size = os.path.getsize(f_name)
bytes_to_overwrite = size // 10 bytes_to_overwrite = size // 10
# seek to the middle of the file # seek to the middle of the file
f.seek(5 * bytes_to_overwrite) f.seek(5 * bytes_to_overwrite)
# corrupt 10% of the file from the middle # corrupt 10% of the file from the middle
f.write(b'\xff' * bytes_to_overwrite) f.write(b"\xff" * bytes_to_overwrite)
# this exercises the container.decode assertion check # this exercises the container.decode assertion check
video, audio, info = io.read_video(f.name, pts_unit='sec') video, audio, info = io.read_video(f.name, pts_unit="sec")
# check that size is not equal to 5, but 3 # check that size is not equal to 5, but 3
# TODO fix this # TODO fix this
if get_video_backend() == 'pyav': if get_video_backend() == "pyav":
assert len(video) == 3 assert len(video) == 3
else: else:
assert len(video) == 4 assert len(video) == 4
...@@ -254,7 +256,7 @@ class TestVideo: ...@@ -254,7 +256,7 @@ class TestVideo:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_equal(video, data) assert_equal(video, data)
@pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
def test_write_video_with_audio(self, tmpdir): def test_write_video_with_audio(self, tmpdir):
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
...@@ -265,15 +267,13 @@ class TestVideo: ...@@ -265,15 +267,13 @@ class TestVideo:
video_tensor, video_tensor,
round(info["video_fps"]), round(info["video_fps"]),
video_codec="libx264rgb", video_codec="libx264rgb",
options={'crf': '0'}, options={"crf": "0"},
audio_array=audio_tensor, audio_array=audio_tensor,
audio_fps=info["audio_fps"], audio_fps=info["audio_fps"],
audio_codec="aac", audio_codec="aac",
) )
out_video_tensor, out_audio_tensor, out_info = io.read_video( out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec")
out_f_name, pts_unit="sec"
)
assert info["video_fps"] == out_info["video_fps"] assert info["video_fps"] == out_info["video_fps"]
assert_equal(video_tensor, out_video_tensor) assert_equal(video_tensor, out_video_tensor)
...@@ -289,5 +289,5 @@ class TestVideo: ...@@ -289,5 +289,5 @@ class TestVideo:
# TODO add tests for audio # TODO add tests for audio
if __name__ == '__main__': if __name__ == "__main__":
pytest.main(__file__) pytest.main(__file__)
import unittest import unittest
from torchvision import set_video_backend
import test_io import test_io
from torchvision import set_video_backend
# Disabling the video backend switching temporarily # Disabling the video backend switching temporarily
# set_video_backend('video_reader') # set_video_backend('video_reader')
if __name__ == '__main__': if __name__ == "__main__":
suite = unittest.TestLoader().loadTestsFromModule(test_io) suite = unittest.TestLoader().loadTestsFromModule(test_io)
unittest.TextTestRunner(verbosity=1).run(suite) unittest.TextTestRunner(verbosity=1).run(suite)
import os import functools
import io import io
import operator
import os
import sys import sys
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda import traceback
from _utils_internal import get_relative_path import warnings
from collections import OrderedDict from collections import OrderedDict
import functools
import operator import pytest
import torch import torch
import torch.fx import torch.fx
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from _utils_internal import get_relative_path
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
from torchvision import models from torchvision import models
import pytest
import warnings
import traceback
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1' ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
def get_available_classification_models(): def get_available_classification_models():
...@@ -50,7 +51,7 @@ def _get_expected_file(name=None): ...@@ -50,7 +51,7 @@ def _get_expected_file(name=None):
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names # Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files # We hardcode it here to avoid having to re-generate the reference files
expected_file = expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name) expected_file = expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
expected_file += "_expect.pkl" expected_file += "_expect.pkl"
if not ACCEPT and not os.path.exists(expected_file): if not ACCEPT and not os.path.exists(expected_file):
...@@ -92,6 +93,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -92,6 +93,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
def assert_export_import_module(m, args): def assert_export_import_module(m, args):
"""Check that the results of a model are the same after saving and loading""" """Check that the results of a model are the same after saving and loading"""
def get_export_import_copy(m): def get_export_import_copy(m):
"""Save and load a TorchScript model""" """Save and load a TorchScript model"""
buffer = io.BytesIO() buffer = io.BytesIO()
...@@ -115,15 +117,17 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -115,15 +117,17 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
if a is not None: if a is not None:
torch.testing.assert_close(a, b, atol=tol, rtol=tol) torch.testing.assert_close(a, b, atol=tol, rtol=tol)
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
if not TEST_WITH_SLOW or skip: if not TEST_WITH_SLOW or skip:
# TorchScript is not enabled, skip these tests # TorchScript is not enabled, skip these tests
msg = "The check_jit_scriptable test for {} was skipped. " \ msg = (
"This test checks if the module's results in TorchScript " \ "The check_jit_scriptable test for {} was skipped. "
"match eager and that it can be exported. To run these " \ "This test checks if the module's results in TorchScript "
"tests make sure you set the environment variable " \ "match eager and that it can be exported. To run these "
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \ "tests make sure you set the environment variable "
"manually skipped.".format(nn_module.__class__.__name__) "PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
"manually skipped.".format(nn_module.__class__.__name__)
)
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
return None return None
...@@ -181,8 +185,8 @@ def _check_input_backprop(model, inputs): ...@@ -181,8 +185,8 @@ def _check_input_backprop(model, inputs):
# before they are compared to the eager model outputs. This is useful if the # before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode # model outputs are different between TorchScript / Eager mode
script_model_unwrapper = { script_model_unwrapper = {
'googlenet': lambda x: x.logits, "googlenet": lambda x: x.logits,
'inception_v3': lambda x: x.logits, "inception_v3": lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1], "fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
...@@ -221,43 +225,41 @@ autocast_flaky_numerics = ( ...@@ -221,43 +225,41 @@ autocast_flaky_numerics = (
# The following contains configuration parameters for all models which are used by # The following contains configuration parameters for all models which are used by
# the _test_*_model methods. # the _test_*_model methods.
_model_params = { _model_params = {
'inception_v3': { "inception_v3": {"input_shape": (1, 3, 299, 299)},
'input_shape': (1, 3, 299, 299) "retinanet_resnet50_fpn": {
"num_classes": 20,
"score_thresh": 0.01,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
}, },
'retinanet_resnet50_fpn': { "keypointrcnn_resnet50_fpn": {
'num_classes': 20, "num_classes": 2,
'score_thresh': 0.01, "min_size": 224,
'min_size': 224, "max_size": 224,
'max_size': 224, "box_score_thresh": 0.15,
'input_shape': (3, 224, 224), "input_shape": (3, 224, 224),
}, },
'keypointrcnn_resnet50_fpn': { "fasterrcnn_resnet50_fpn": {
'num_classes': 2, "num_classes": 20,
'min_size': 224, "min_size": 224,
'max_size': 224, "max_size": 224,
'box_score_thresh': 0.15, "input_shape": (3, 224, 224),
'input_shape': (3, 224, 224),
}, },
'fasterrcnn_resnet50_fpn': { "maskrcnn_resnet50_fpn": {
'num_classes': 20, "num_classes": 10,
'min_size': 224, "min_size": 224,
'max_size': 224, "max_size": 224,
'input_shape': (3, 224, 224), "input_shape": (3, 224, 224),
}, },
'maskrcnn_resnet50_fpn': { "fasterrcnn_mobilenet_v3_large_fpn": {
'num_classes': 10, "box_score_thresh": 0.02076,
'min_size': 224,
'max_size': 224,
'input_shape': (3, 224, 224),
}, },
'fasterrcnn_mobilenet_v3_large_fpn': { "fasterrcnn_mobilenet_v3_large_320_fpn": {
'box_score_thresh': 0.02076, "box_score_thresh": 0.02076,
"rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_test": 1000,
}, },
'fasterrcnn_mobilenet_v3_large_320_fpn': {
'box_score_thresh': 0.02076,
'rpn_pre_nms_top_n_test': 1000,
'rpn_post_nms_top_n_test': 1000,
}
} }
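_model_params holds per-model constructor overrides; the helpers that consume it are outside this hunk. A hypothetical sketch of the usual lookup pattern (the helper name and the fallback input shape are assumptions):

def _build_model_for_test(name):
    # Pull overrides for this model, defaulting to an empty dict when there is no entry.
    kwargs = dict(_model_params.get(name, {}))
    input_shape = kwargs.pop("input_shape", (1, 3, 224, 224))
    model = models.__dict__[name](**kwargs)
    return model, torch.rand(input_shape)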
...@@ -271,7 +273,7 @@ def _make_sliced_model(model, stop_layer): ...@@ -271,7 +273,7 @@ def _make_sliced_model(model, stop_layer):
return new_model return new_model
@pytest.mark.parametrize('model_name', ['densenet121', 'densenet169', 'densenet201', 'densenet161']) @pytest.mark.parametrize("model_name", ["densenet121", "densenet169", "densenet201", "densenet161"])
def test_memory_efficient_densenet(model_name): def test_memory_efficient_densenet(model_name):
input_shape = (1, 3, 300, 300) input_shape = (1, 3, 300, 300)
x = torch.rand(input_shape) x = torch.rand(input_shape)
...@@ -296,9 +298,9 @@ def test_memory_efficient_densenet(model_name): ...@@ -296,9 +298,9 @@ def test_memory_efficient_densenet(model_name):
_check_input_backprop(model2, x) _check_input_backprop(model2, x)
@pytest.mark.parametrize('dilate_layer_2', (True, False)) @pytest.mark.parametrize("dilate_layer_2", (True, False))
@pytest.mark.parametrize('dilate_layer_3', (True, False)) @pytest.mark.parametrize("dilate_layer_3", (True, False))
@pytest.mark.parametrize('dilate_layer_4', (True, False)) @pytest.mark.parametrize("dilate_layer_4", (True, False))
def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4): def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
# TODO improve tests to also check that each layer has the right dimensionality # TODO improve tests to also check that each layer has the right dimensionality
model = models.__dict__["resnet50"](replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4)) model = models.__dict__["resnet50"](replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4))
@@ -318,7 +320,7 @@ def test_mobilenet_v2_residual_setting():
assert out.shape[-1] == 1000 assert out.shape[-1] == 1000
@pytest.mark.parametrize('model_name', ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]) @pytest.mark.parametrize("model_name", ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"])
def test_mobilenet_norm_layer(model_name): def test_mobilenet_norm_layer(model_name):
model = models.__dict__[model_name]() model = models.__dict__[model_name]()
assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
@@ -327,16 +329,16 @@ def test_mobilenet_norm_layer(model_name):
return nn.GroupNorm(32, num_channels) return nn.GroupNorm(32, num_channels)
model = models.__dict__[model_name](norm_layer=get_gn) model = models.__dict__[model_name](norm_layer=get_gn)
assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
assert any(isinstance(x, nn.GroupNorm) for x in model.modules()) assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
def test_inception_v3_eval(): def test_inception_v3_eval():
# replacement for models.inception_v3(pretrained=True) that does not download weights # replacement for models.inception_v3(pretrained=True) that does not download weights
kwargs = {} kwargs = {}
kwargs['transform_input'] = True kwargs["transform_input"] = True
kwargs['aux_logits'] = True kwargs["aux_logits"] = True
kwargs['init_weights'] = False kwargs["init_weights"] = False
name = "inception_v3" name = "inception_v3"
model = models.Inception3(**kwargs) model = models.Inception3(**kwargs)
model.aux_logits = False model.aux_logits = False
@@ -366,9 +368,9 @@ def test_fasterrcnn_double():
def test_googlenet_eval(): def test_googlenet_eval():
# replacement for models.googlenet(pretrained=True) that does not download weights # replacement for models.googlenet(pretrained=True) that does not download weights
kwargs = {} kwargs = {}
kwargs['transform_input'] = True kwargs["transform_input"] = True
kwargs['aux_logits'] = True kwargs["aux_logits"] = True
kwargs['init_weights'] = False kwargs["init_weights"] = False
name = "googlenet" name = "googlenet"
model = models.GoogLeNet(**kwargs) model = models.GoogLeNet(**kwargs)
model.aux_logits = False model.aux_logits = False
@@ -392,7 +394,7 @@ def test_fasterrcnn_switch_devices():
model.cuda() model.cuda()
model.eval() model.eval()
input_shape = (3, 300, 300) input_shape = (3, 300, 300)
x = torch.rand(input_shape, device='cuda') x = torch.rand(input_shape, device="cuda")
model_input = [x] model_input = [x]
out = model(model_input) out = model(model_input)
assert model_input[0] is x assert model_input[0] is x
@@ -422,30 +424,29 @@ def test_generalizedrcnn_transform_repr():
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
t = models.detection.transform.GeneralizedRCNNTransform(min_size=min_size, t = models.detection.transform.GeneralizedRCNNTransform(
max_size=max_size, min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std
image_mean=image_mean, )
image_std=image_std)
# Check integrity of object __repr__ attribute # Check integrity of object __repr__ attribute
expected_string = 'GeneralizedRCNNTransform(' expected_string = "GeneralizedRCNNTransform("
_indent = '\n ' _indent = "\n "
expected_string += '{0}Normalize(mean={1}, std={2})'.format(_indent, image_mean, image_std) expected_string += "{0}Normalize(mean={1}, std={2})".format(_indent, image_mean, image_std)
expected_string += '{0}Resize(min_size=({1},), max_size={2}, '.format(_indent, min_size, max_size) expected_string += "{0}Resize(min_size=({1},), max_size={2}, ".format(_indent, min_size, max_size)
expected_string += "mode='bilinear')\n)" expected_string += "mode='bilinear')\n)"
assert t.__repr__() == expected_string assert t.__repr__() == expected_string
@pytest.mark.parametrize('model_name', get_available_classification_models()) @pytest.mark.parametrize("model_name", get_available_classification_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_name, dev): def test_classification_model(model_name, dev):
set_rng_seed(0) set_rng_seed(0)
defaults = { defaults = {
'num_classes': 50, "num_classes": 50,
'input_shape': (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
} }
kwargs = {**defaults, **_model_params.get(model_name, {})} kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape') input_shape = kwargs.pop("input_shape")
model = models.__dict__[model_name](**kwargs) model = models.__dict__[model_name](**kwargs)
model.eval().to(device=dev) model.eval().to(device=dev)
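Illustration (not part of the commit diff): per-model entries in _model_params override the shared defaults when the two dicts are merged; with values taken from the table above:

    defaults = {"num_classes": 50, "input_shape": (1, 3, 224, 224)}
    kwargs = {**defaults, **_model_params.get("inception_v3", {})}
    # -> {"num_classes": 50, "input_shape": (1, 3, 299, 299)}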
@@ -468,17 +469,17 @@ def test_classification_model(model_name, dev):
_check_input_backprop(model, x) _check_input_backprop(model, x)
@pytest.mark.parametrize('model_name', get_available_segmentation_models()) @pytest.mark.parametrize("model_name", get_available_segmentation_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_segmentation_model(model_name, dev): def test_segmentation_model(model_name, dev):
set_rng_seed(0) set_rng_seed(0)
defaults = { defaults = {
'num_classes': 10, "num_classes": 10,
'pretrained_backbone': False, "pretrained_backbone": False,
'input_shape': (1, 3, 32, 32), "input_shape": (1, 3, 32, 32),
} }
kwargs = {**defaults, **_model_params.get(model_name, {})} kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape') input_shape = kwargs.pop("input_shape")
model = models.segmentation.__dict__[model_name](**kwargs) model = models.segmentation.__dict__[model_name](**kwargs)
model.eval().to(device=dev) model.eval().to(device=dev)
@@ -517,27 +518,29 @@ def test_segmentation_model(model_name, dev):
full_validation &= check_out(out) full_validation &= check_out(out)
if not full_validation: if not full_validation:
msg = "The output of {} could only be partially validated. " \ msg = (
"This is likely due to unit-test flakiness, but you may " \ "The output of {} could only be partially validated. "
"want to do additional manual checks if you made " \ "This is likely due to unit-test flakiness, but you may "
"significant changes to the codebase.".format(test_segmentation_model.__name__) "want to do additional manual checks if you made "
"significant changes to the codebase.".format(test_segmentation_model.__name__)
)
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
_check_input_backprop(model, x) _check_input_backprop(model, x)
@pytest.mark.parametrize('model_name', get_available_detection_models()) @pytest.mark.parametrize("model_name", get_available_detection_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_detection_model(model_name, dev): def test_detection_model(model_name, dev):
set_rng_seed(0) set_rng_seed(0)
defaults = { defaults = {
'num_classes': 50, "num_classes": 50,
'pretrained_backbone': False, "pretrained_backbone": False,
'input_shape': (3, 300, 300), "input_shape": (3, 300, 300),
} }
kwargs = {**defaults, **_model_params.get(model_name, {})} kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape') input_shape = kwargs.pop("input_shape")
model = models.detection.__dict__[model_name](**kwargs) model = models.detection.__dict__[model_name](**kwargs)
model.eval().to(device=dev) model.eval().to(device=dev)
@@ -565,7 +568,7 @@ def test_detection_model(model_name, dev):
return tensor return tensor
ith_index = num_elems // num_samples ith_index = num_elems // num_samples
return tensor[ith_index - 1::ith_index] return tensor[ith_index - 1 :: ith_index]
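Illustration (not part of the commit diff): a quick worked example of the subsampling slice above. For 20 elements and 5 samples, ith_index = 20 // 5 = 4, and the slice keeps every fourth element starting at index 3:

    import torch

    t = torch.arange(20)
    assert t[4 - 1 :: 4].tolist() == [3, 7, 11, 15, 19]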
def compute_mean_std(tensor): def compute_mean_std(tensor):
# can't compute mean of integral tensor # can't compute mean of integral tensor
@@ -588,8 +591,9 @@ def test_detection_model(model_name, dev):
# scores. # scores.
expected_file = _get_expected_file(model_name) expected_file = _get_expected_file(model_name)
expected = torch.load(expected_file) expected = torch.load(expected_file)
torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, torch.testing.assert_close(
check_device=False, check_dtype=False) output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
)
# Note: Fmassa proposed turning off NMS by adapting the threshold # Note: Fmassa proposed turning off NMS by adapting the threshold
# and then using the Hungarian algorithm as in DETR to find the # and then using the Hungarian algorithm as in DETR to find the
@@ -610,17 +614,19 @@ def test_detection_model(model_name, dev):
full_validation &= check_out(out) full_validation &= check_out(out)
if not full_validation: if not full_validation:
msg = "The output of {} could only be partially validated. " \ msg = (
"This is likely due to unit-test flakiness, but you may " \ "The output of {} could only be partially validated. "
"want to do additional manual checks if you made " \ "This is likely due to unit-test flakiness, but you may "
"significant changes to the codebase.".format(test_detection_model.__name__) "want to do additional manual checks if you made "
"significant changes to the codebase.".format(test_detection_model.__name__)
)
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
_check_input_backprop(model, model_input) _check_input_backprop(model, model_input)
@pytest.mark.parametrize('model_name', get_available_detection_models()) @pytest.mark.parametrize("model_name", get_available_detection_models())
def test_detection_model_validation(model_name): def test_detection_model_validation(model_name):
set_rng_seed(0) set_rng_seed(0)
model = models.detection.__dict__[model_name](num_classes=50, pretrained_backbone=False) model = models.detection.__dict__[model_name](num_classes=50, pretrained_backbone=False)
@@ -632,25 +638,25 @@ def test_detection_model_validation(model_name):
model(x) model(x)
# validate type # validate type
targets = [{'boxes': 0.}] targets = [{"boxes": 0.0}]
with pytest.raises(ValueError): with pytest.raises(ValueError):
model(x, targets=targets) model(x, targets=targets)
# validate boxes shape # validate boxes shape
for boxes in (torch.rand((4,)), torch.rand((1, 5))): for boxes in (torch.rand((4,)), torch.rand((1, 5))):
targets = [{'boxes': boxes}] targets = [{"boxes": boxes}]
with pytest.raises(ValueError): with pytest.raises(ValueError):
model(x, targets=targets) model(x, targets=targets)
# validate that no degenerate boxes are present # validate that no degenerate boxes are present
boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]]) boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
targets = [{'boxes': boxes}] targets = [{"boxes": boxes}]
with pytest.raises(ValueError): with pytest.raises(ValueError):
model(x, targets=targets) model(x, targets=targets)
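Note (not part of the commit diff): in (x1, y1, x2, y2) form the two boxes above have zero width (x1 == x2 == 1) and zero height (y1 == y2 == 4) respectively, which are exactly the degenerate cases the model is expected to reject with a ValueError.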
@pytest.mark.parametrize('model_name', get_available_video_models()) @pytest.mark.parametrize("model_name", get_available_video_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_video_model(model_name, dev): def test_video_model(model_name, dev):
# the default input shape is # the default input shape is
# bs * num_channels * clip_len * h * w # bs * num_channels * clip_len * h * w
@@ -673,25 +679,29 @@ def test_video_model(model_name, dev):
_check_input_backprop(model, x) _check_input_backprop(model, x)
@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and @pytest.mark.skipif(
'qnnpack' in torch.backends.quantized.supported_engines), not (
reason="This Pytorch Build has not been built with fbgemm and qnnpack") "fbgemm" in torch.backends.quantized.supported_engines
@pytest.mark.parametrize('model_name', get_available_quantizable_models()) and "qnnpack" in torch.backends.quantized.supported_engines
),
reason="This Pytorch Build has not been built with fbgemm and qnnpack",
)
@pytest.mark.parametrize("model_name", get_available_quantizable_models())
def test_quantized_classification_model(model_name): def test_quantized_classification_model(model_name):
defaults = { defaults = {
'input_shape': (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
'pretrained': False, "pretrained": False,
'quantize': True, "quantize": True,
} }
kwargs = {**defaults, **_model_params.get(model_name, {})} kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop('input_shape') input_shape = kwargs.pop("input_shape")
# First check if quantize=True provides models that can run with input data # First check if quantize=True provides models that can run with input data
model = torchvision.models.quantization.__dict__[model_name](**kwargs) model = torchvision.models.quantization.__dict__[model_name](**kwargs)
x = torch.rand(input_shape) x = torch.rand(input_shape)
model(x) model(x)
kwargs['quantize'] = False kwargs["quantize"] = False
for eval_mode in [True, False]: for eval_mode in [True, False]:
model = torchvision.models.quantization.__dict__[model_name](**kwargs) model = torchvision.models.quantization.__dict__[model_name](**kwargs)
if eval_mode: if eval_mode:
@@ -717,5 +727,5 @@ def test_quantized_classification_model(model_name):
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import pytest
import torch import torch
from common_utils import assert_equal from common_utils import assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.image_list import ImageList
import pytest
class Tester: class Tester:
def test_incorrect_anchors(self): def test_incorrect_anchors(self):
incorrect_sizes = ((2, 4, 8), (32, 8), ) incorrect_sizes = (
(2, 4, 8),
(32, 8),
)
incorrect_aspects = (0.5, 1.0) incorrect_aspects = (0.5, 1.0)
anc = AnchorGenerator(incorrect_sizes, incorrect_aspects) anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)
image1 = torch.randn(3, 800, 800) image1 = torch.randn(3, 800, 800)
@@ -49,15 +52,19 @@ class Tester:
for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()): for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
anchors_output = torch.tensor([[-5., -5., 5., 5.], anchors_output = torch.tensor(
[0., -5., 10., 5.], [
[5., -5., 15., 5.], [-5.0, -5.0, 5.0, 5.0],
[-5., 0., 5., 10.], [0.0, -5.0, 10.0, 5.0],
[0., 0., 10., 10.], [5.0, -5.0, 15.0, 5.0],
[5., 0., 15., 10.], [-5.0, 0.0, 5.0, 10.0],
[-5., 5., 5., 15.], [0.0, 0.0, 10.0, 10.0],
[0., 5., 10., 15.], [5.0, 0.0, 15.0, 10.0],
[5., 5., 15., 15.]]) [-5.0, 5.0, 5.0, 15.0],
[0.0, 5.0, 10.0, 15.0],
[5.0, 5.0, 15.0, 15.0],
]
)
assert num_anchors_estimated == 9 assert num_anchors_estimated == 9
assert len(anchors) == 2 assert len(anchors) == 2
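Note (not part of the commit diff): the nine expected anchors above lie on a 3 x 3 grid of centers (x, y in {0, 5, 10}), one 10 x 10 box per location, so the sum of sizes[0] * sizes[1] * num_anchors_per_loc evaluates to 3 * 3 * 1 = 9, matching the assertion.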
@@ -76,12 +83,14 @@ class Tester:
model.eval() model.eval()
dboxes = model(images, features) dboxes = model(images, features)
dboxes_output = torch.tensor([ dboxes_output = torch.tensor(
[6.3750, 6.3750, 8.6250, 8.6250], [
[4.7443, 4.7443, 10.2557, 10.2557], [6.3750, 6.3750, 8.6250, 8.6250],
[5.9090, 6.7045, 9.0910, 8.2955], [4.7443, 4.7443, 10.2557, 10.2557],
[6.7045, 5.9090, 8.2955, 9.0910] [5.9090, 6.7045, 9.0910, 8.2955],
]) [6.7045, 5.9090, 8.2955, 9.0910],
]
)
assert len(dboxes) == 2 assert len(dboxes) == 2
assert tuple(dboxes[0].shape) == (4, 4) assert tuple(dboxes[0].shape) == (4, 4)
...
import pytest
import torch import torch
import torchvision.models import torchvision.models
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
import pytest
from common_utils import assert_equal from common_utils import assert_equal
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.ops import MultiScaleRoIAlign
class TestModelsDetectionNegativeSamples: class TestModelsDetectionNegativeSamples:
def _make_empty_sample(self, add_masks=False, add_keypoints=False): def _make_empty_sample(self, add_masks=False, add_keypoints=False):
images = [torch.rand((3, 100, 100), dtype=torch.float32)] images = [torch.rand((3, 100, 100), dtype=torch.float32)]
boxes = torch.zeros((0, 4), dtype=torch.float32) boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {"boxes": boxes, negative_target = {
"labels": torch.zeros(0, dtype=torch.int64), "boxes": boxes,
"image_id": 4, "labels": torch.zeros(0, dtype=torch.int64),
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), "image_id": 4,
"iscrowd": torch.zeros((0,), dtype=torch.int64)} "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
"iscrowd": torch.zeros((0,), dtype=torch.int64),
}
if add_masks: if add_masks:
negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8) negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)
@@ -36,16 +35,10 @@ class TestModelsDetectionNegativeSamples:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator( rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
anchor_sizes, aspect_ratios
)
rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0]) rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0])
head = RegionProposalNetwork( head = RegionProposalNetwork(rpn_anchor_generator, rpn_head, 0.5, 0.3, 256, 0.5, 2000, 2000, 0.7, 0.05)
rpn_anchor_generator, rpn_head,
0.5, 0.3,
256, 0.5,
2000, 2000, 0.7, 0.05)
labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets) labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
@@ -63,29 +56,29 @@ class TestModelsDetectionNegativeSamples:
gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)] gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]
gt_labels = [torch.tensor([[0]], dtype=torch.int64)] gt_labels = [torch.tensor([[0]], dtype=torch.int64)]
box_roi_pool = MultiScaleRoIAlign( box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
resolution = box_roi_pool.output_size[0] resolution = box_roi_pool.output_size[0]
representation_size = 1024 representation_size = 1024
box_head = TwoMLPHead( box_head = TwoMLPHead(4 * resolution ** 2, representation_size)
4 * resolution ** 2,
representation_size)
representation_size = 1024 representation_size = 1024
box_predictor = FastRCNNPredictor( box_predictor = FastRCNNPredictor(representation_size, 2)
representation_size,
2)
roi_heads = RoIHeads( roi_heads = RoIHeads(
# Box # Box
box_roi_pool, box_head, box_predictor, box_roi_pool,
0.5, 0.5, box_head,
512, 0.25, box_predictor,
0.5,
0.5,
512,
0.25,
None, None,
0.05, 0.5, 100) 0.05,
0.5,
100,
)
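Note (an assumption based on the RoIHeads signature in torchvision at the time, not stated in the diff): the positional arguments above correspond to fg_iou_thresh=0.5, bg_iou_thresh=0.5, batch_size_per_image=512, positive_fraction=0.25, bbox_reg_weights=None, score_thresh=0.05, nms_thresh=0.5, detections_per_img=100.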
matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
@@ -97,61 +90,61 @@ class TestModelsDetectionNegativeSamples:
assert labels[0].shape == torch.Size([proposals[0].shape[0]]) assert labels[0].shape == torch.Size([proposals[0].shape[0]])
assert labels[0].dtype == torch.int64 assert labels[0].dtype == torch.int64
@pytest.mark.parametrize('name', [ @pytest.mark.parametrize(
"fasterrcnn_resnet50_fpn", "name",
"fasterrcnn_mobilenet_v3_large_fpn", [
"fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_resnet50_fpn",
]) "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
],
)
def test_forward_negative_sample_frcnn(self, name): def test_forward_negative_sample_frcnn(self, name):
model = torchvision.models.detection.__dict__[name]( model = torchvision.models.detection.__dict__[name](num_classes=2, min_size=100, max_size=100)
num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))
def test_forward_negative_sample_mrcnn(self): def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn( model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100)
num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample(add_masks=True) images, targets = self._make_empty_sample(add_masks=True)
loss_dict = model(images, targets) loss_dict = model(images, targets)
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))
assert_equal(loss_dict["loss_mask"], torch.tensor(0.)) assert_equal(loss_dict["loss_mask"], torch.tensor(0.0))
def test_forward_negative_sample_krcnn(self): def test_forward_negative_sample_krcnn(self):
model = torchvision.models.detection.keypointrcnn_resnet50_fpn( model = torchvision.models.detection.keypointrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100)
num_classes=2, min_size=100, max_size=100)
images, targets = self._make_empty_sample(add_keypoints=True) images, targets = self._make_empty_sample(add_keypoints=True)
loss_dict = model(images, targets) loss_dict = model(images, targets)
assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0))
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0))
assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.)) assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0))
def test_forward_negative_sample_retinanet(self): def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn( model = torchvision.models.detection.retinanet_resnet50_fpn(
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
)
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
def test_forward_negative_sample_ssd(self): def test_forward_negative_sample_ssd(self):
model = torchvision.models.detection.ssd300_vgg16( model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)
num_classes=2, pretrained_backbone=False)
images, targets = self._make_empty_sample() images, targets = self._make_empty_sample()
loss_dict = model(images, targets) loss_dict = model(images, targets)
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])