Unverified Commit 5f0edb97 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent e45489b1
import copy import copy
import pytest
import torch import torch
from common_utils import assert_equal
from torchvision.models.detection import _utils from torchvision.models.detection import _utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform
import pytest
from torchvision.models.detection import backbone_utils from torchvision.models.detection import backbone_utils
from common_utils import assert_equal from torchvision.models.detection.transform import GeneralizedRCNNTransform
class TestModelsDetectionUtils: class TestModelsDetectionUtils:
def test_balanced_positive_negative_sampler(self): def test_balanced_positive_negative_sampler(self):
sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25) sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25)
# keep all 6 negatives first, then add 3 positives, last two are ignore # keep all 6 negatives first, then add 3 positives, last two are ignore
...@@ -22,16 +22,13 @@ class TestModelsDetectionUtils: ...@@ -22,16 +22,13 @@ class TestModelsDetectionUtils:
assert neg[0].sum() == 3 assert neg[0].sum() == 3
assert neg[0][0:6].sum() == 3 assert neg[0][0:6].sum() == 3
@pytest.mark.parametrize('train_layers, exp_froz_params', [ @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)
])
def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
# we know how many initial layers and parameters of the network should # we know how many initial layers and parameters of the network should
# be frozen for each trainable_backbone_layers parameter value # be frozen for each trainable_backbone_layers parameter value
# i.e all 53 params are frozen if trainable_backbone_layers=0 # i.e all 53 params are frozen if trainable_backbone_layers=0
# ad first 24 params are frozen if trainable_backbone_layers=2 # ad first 24 params are frozen if trainable_backbone_layers=2
model = backbone_utils.resnet_fpn_backbone( model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers)
'resnet50', pretrained=False, trainable_layers=train_layers)
# boolean list that is true if the param at that index is frozen # boolean list that is true if the param at that index is frozen
is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
# check that expected initial number of layers are frozen # check that expected initial number of layers are frozen
...@@ -40,34 +37,37 @@ class TestModelsDetectionUtils: ...@@ -40,34 +37,37 @@ class TestModelsDetectionUtils:
def test_validate_resnet_inputs_detection(self): def test_validate_resnet_inputs_detection(self):
# default number of backbone layers to train # default number of backbone layers to train
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3
)
assert ret == 3 assert ret == 3
# can't go beyond 5 # can't go beyond 5
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3
)
# if not pretrained, should use all trainable layers and warn # if not pretrained, should use all trainable layers and warn
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3
)
assert ret == 5 assert ret == 5
def test_transform_copy_targets(self): def test_transform_copy_targets(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)]
targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}] targets = [{"boxes": torch.rand(3, 4)}, {"boxes": torch.rand(2, 4)}]
targets_copy = copy.deepcopy(targets) targets_copy = copy.deepcopy(targets)
out = transform(image, targets) # noqa: F841 out = transform(image, targets) # noqa: F841
assert_equal(targets[0]['boxes'], targets_copy[0]['boxes']) assert_equal(targets[0]["boxes"], targets_copy[0]["boxes"])
assert_equal(targets[1]['boxes'], targets_copy[1]['boxes']) assert_equal(targets[1]["boxes"], targets_copy[1]["boxes"])
def test_not_float_normalize(self): def test_not_float_normalize(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
targets = [{'boxes': torch.rand(3, 4)}] targets = [{"boxes": torch.rand(3, 4)}]
with pytest.raises(TypeError): with pytest.raises(TypeError):
out = transform(image, targets) # noqa: F841 out = transform(image, targets) # noqa: F841
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
from common_utils import set_rng_seed, assert_equal
import io import io
from collections import OrderedDict
from typing import List, Tuple
import pytest import pytest
import torch import torch
from torchvision import ops from common_utils import set_rng_seed, assert_equal
from torchvision import models from torchvision import models
from torchvision import ops
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.roi_heads import RoIHeads from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from collections import OrderedDict
from typing import List, Tuple
from torchvision.ops._register_onnx_ops import _onnx_opset_version from torchvision.ops._register_onnx_ops import _onnx_opset_version
# In environments without onnxruntime we prefer to # In environments without onnxruntime we prefer to
...@@ -25,8 +24,16 @@ class TestONNXExporter: ...@@ -25,8 +24,16 @@ class TestONNXExporter:
def setup_class(cls): def setup_class(cls):
torch.manual_seed(123) torch.manual_seed(123)
def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, def run_model(
output_names=None, input_names=None): self,
model,
inputs_list,
tolerate_small_mismatch=False,
do_constant_folding=True,
dynamic_axes=None,
output_names=None,
input_names=None,
):
model.eval() model.eval()
onnx_io = io.BytesIO() onnx_io = io.BytesIO()
...@@ -35,14 +42,20 @@ class TestONNXExporter: ...@@ -35,14 +42,20 @@ class TestONNXExporter:
else: else:
torch_onnx_input = inputs_list[0] torch_onnx_input = inputs_list[0]
# export to onnx with the first input # export to onnx with the first input
torch.onnx.export(model, torch_onnx_input, onnx_io, torch.onnx.export(
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, model,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) torch_onnx_input,
onnx_io,
do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
)
# validate the exported model with onnx runtime # validate the exported model with onnx runtime
for test_inputs in inputs_list: for test_inputs in inputs_list:
with torch.no_grad(): with torch.no_grad():
if isinstance(test_inputs, torch.Tensor) or \ if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list):
isinstance(test_inputs, list):
test_inputs = (test_inputs,) test_inputs = (test_inputs,)
test_ouputs = model(*test_inputs) test_ouputs = model(*test_inputs)
if isinstance(test_ouputs, torch.Tensor): if isinstance(test_ouputs, torch.Tensor):
...@@ -113,9 +126,9 @@ class TestONNXExporter: ...@@ -113,9 +126,9 @@ class TestONNXExporter:
def forward(self, boxes, size): def forward(self, boxes, size):
return ops.boxes.clip_boxes_to_image(boxes, size.shape) return ops.boxes.clip_boxes_to_image(boxes, size.shape)
self.run_model(Module(), [(boxes, size), (boxes, size_2)], self.run_model(
input_names=["boxes", "size"], Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]}
dynamic_axes={"size": [0, 1]}) )
def test_roi_align(self): def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32) x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
...@@ -180,11 +193,11 @@ class TestONNXExporter: ...@@ -180,11 +193,11 @@ class TestONNXExporter:
input = torch.rand(3, 10, 20) input = torch.rand(3, 10, 20)
input_test = torch.rand(3, 100, 150) input_test = torch.rand(3, 100, 150)
self.run_model(TransformModule(), [(input,), (input_test,)], self.run_model(
input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}) TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}
)
def test_transform_images(self): def test_transform_images(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
def __init__(self_module): def __init__(self_module):
super(TransformModule, self_module).__init__() super(TransformModule, self_module).__init__()
...@@ -221,11 +234,17 @@ class TestONNXExporter: ...@@ -221,11 +234,17 @@ class TestONNXExporter:
rpn_score_thresh = 0.0 rpn_score_thresh = 0.0
rpn = RegionProposalNetwork( rpn = RegionProposalNetwork(
rpn_anchor_generator, rpn_head, rpn_anchor_generator,
rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_head,
rpn_batch_size_per_image, rpn_positive_fraction, rpn_fg_iou_thresh,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, rpn_bg_iou_thresh,
score_thresh=rpn_score_thresh) rpn_batch_size_per_image,
rpn_positive_fraction,
rpn_pre_nms_top_n,
rpn_post_nms_top_n,
rpn_nms_thresh,
score_thresh=rpn_score_thresh,
)
return rpn return rpn
def _init_test_roi_heads_faster_rcnn(self): def _init_test_roi_heads_faster_rcnn(self):
...@@ -241,38 +260,38 @@ class TestONNXExporter: ...@@ -241,38 +260,38 @@ class TestONNXExporter:
box_nms_thresh = 0.5 box_nms_thresh = 0.5
box_detections_per_img = 100 box_detections_per_img = 100
box_roi_pool = ops.MultiScaleRoIAlign( box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
resolution = box_roi_pool.output_size[0] resolution = box_roi_pool.output_size[0]
representation_size = 1024 representation_size = 1024
box_head = TwoMLPHead( box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
out_channels * resolution ** 2,
representation_size)
representation_size = 1024 representation_size = 1024
box_predictor = FastRCNNPredictor( box_predictor = FastRCNNPredictor(representation_size, num_classes)
representation_size,
num_classes)
roi_heads = RoIHeads( roi_heads = RoIHeads(
box_roi_pool, box_head, box_predictor, box_roi_pool,
box_fg_iou_thresh, box_bg_iou_thresh, box_head,
box_batch_size_per_image, box_positive_fraction, box_predictor,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights, bbox_reg_weights,
box_score_thresh, box_nms_thresh, box_detections_per_img) box_score_thresh,
box_nms_thresh,
box_detections_per_img,
)
return roi_heads return roi_heads
def get_features(self, images): def get_features(self, images):
s0, s1 = images.shape[-2:] s0, s1 = images.shape[-2:]
features = [ features = [
('0', torch.rand(2, 256, s0 // 4, s1 // 4)), ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
('1', torch.rand(2, 256, s0 // 8, s1 // 8)), ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
('2', torch.rand(2, 256, s0 // 16, s1 // 16)), ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
('3', torch.rand(2, 256, s0 // 32, s1 // 32)), ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
('4', torch.rand(2, 256, s0 // 64, s1 // 64)), ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
] ]
features = OrderedDict(features) features = OrderedDict(features)
return features return features
...@@ -298,36 +317,56 @@ class TestONNXExporter: ...@@ -298,36 +317,56 @@ class TestONNXExporter:
model.eval() model.eval()
model(images, features) model(images, features)
self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, self.run_model(
input_names=["input1", "input2", "input3", "input4", "input5", "input6"], model,
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], [(images, features), (images2, test_features)],
"input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], tolerate_small_mismatch=True,
"input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={
"input1": [0, 1, 2, 3],
"input2": [0, 1, 2, 3],
"input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3],
"input5": [0, 1, 2, 3],
"input6": [0, 1, 2, 3],
},
)
def test_multi_scale_roi_align(self): def test_multi_scale_roi_align(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
def __init__(self): def __init__(self):
super(TransformModule, self).__init__() super(TransformModule, self).__init__()
self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2) self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
self.image_sizes = [(512, 512)] self.image_sizes = [(512, 512)]
def forward(self, input, boxes): def forward(self, input, boxes):
return self.model(input, boxes, self.image_sizes) return self.model(input, boxes, self.image_sizes)
i = OrderedDict() i = OrderedDict()
i['feat1'] = torch.rand(1, 5, 64, 64) i["feat1"] = torch.rand(1, 5, 64, 64)
i['feat2'] = torch.rand(1, 5, 16, 16) i["feat2"] = torch.rand(1, 5, 16, 16)
boxes = torch.rand(6, 4) * 256 boxes = torch.rand(6, 4) * 256
boxes[:, 2:] += boxes[:, :2] boxes[:, 2:] += boxes[:, :2]
i1 = OrderedDict() i1 = OrderedDict()
i1['feat1'] = torch.rand(1, 5, 64, 64) i1["feat1"] = torch.rand(1, 5, 64, 64)
i1['feat2'] = torch.rand(1, 5, 16, 16) i1["feat2"] = torch.rand(1, 5, 16, 16)
boxes1 = torch.rand(6, 4) * 256 boxes1 = torch.rand(6, 4) * 256
boxes1[:, 2:] += boxes1[:, :2] boxes1[:, 2:] += boxes1[:, :2]
self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)]) self.run_model(
TransformModule(),
[
(
i,
[boxes],
),
(
i1,
[boxes1],
),
],
)
def test_roi_heads(self): def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module): class RoiHeadsModule(torch.nn.Module):
...@@ -342,9 +381,7 @@ class TestONNXExporter: ...@@ -342,9 +381,7 @@ class TestONNXExporter:
images = ImageList(images, [i.shape[-2:] for i in images]) images = ImageList(images, [i.shape[-2:] for i in images])
proposals, _ = self_module.rpn(images, features) proposals, _ = self_module.rpn(images, features)
detections, _ = self_module.roi_heads(features, proposals, images.image_sizes) detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
detections = self_module.transform.postprocess(detections, detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes)
images.image_sizes,
original_image_sizes)
return detections return detections
images = torch.rand(2, 3, 100, 100) images = torch.rand(2, 3, 100, 100)
...@@ -356,13 +393,24 @@ class TestONNXExporter: ...@@ -356,13 +393,24 @@ class TestONNXExporter:
model.eval() model.eval()
model(images, features) model(images, features)
self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, self.run_model(
input_names=["input1", "input2", "input3", "input4", "input5", "input6"], model,
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], [(images, features), (images2, test_features)],
"input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) tolerate_small_mismatch=True,
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={
"input1": [0, 1, 2, 3],
"input2": [0, 1, 2, 3],
"input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3],
"input5": [0, 1, 2, 3],
"input6": [0, 1, 2, 3],
},
)
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
import os import os
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
...@@ -373,8 +421,10 @@ class TestONNXExporter: ...@@ -373,8 +421,10 @@ class TestONNXExporter:
return transforms.ToTensor()(image) return transforms.ToTensor()(image)
def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return ([self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))], return (
[self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))]) [self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
[self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))],
)
def test_faster_rcnn(self): def test_faster_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
...@@ -383,15 +433,23 @@ class TestONNXExporter: ...@@ -383,15 +433,23 @@ class TestONNXExporter:
model.eval() model.eval()
model(images) model(images)
# Test exported model on images of different size, or dummy input # Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"], self.run_model(
output_names=["outputs"], model,
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, [(images,), (test_images,), (dummy_image,)],
tolerate_small_mismatch=True) input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True,
)
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"], self.run_model(
output_names=["outputs"], model,
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, [(dummy_image,), (images,)],
tolerate_small_mismatch=True) input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True,
)
# Verify that paste_mask_in_image beahves the same in tracing. # Verify that paste_mask_in_image beahves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
...@@ -403,11 +461,11 @@ class TestONNXExporter: ...@@ -403,11 +461,11 @@ class TestONNXExporter:
boxes *= 50 boxes *= 50
o_im_s = (100, 100) o_im_s = (100, 100)
from torchvision.models.detection.roi_heads import paste_masks_in_image from torchvision.models.detection.roi_heads import paste_masks_in_image
out = paste_masks_in_image(masks, boxes, o_im_s) out = paste_masks_in_image(masks, boxes, o_im_s)
jit_trace = torch.jit.trace(paste_masks_in_image, jit_trace = torch.jit.trace(
(masks, boxes, paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])
[torch.tensor(o_im_s[0]), )
torch.tensor(o_im_s[1])]))
out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])
assert torch.all(out.eq(out_trace)) assert torch.all(out.eq(out_trace))
...@@ -418,6 +476,7 @@ class TestONNXExporter: ...@@ -418,6 +476,7 @@ class TestONNXExporter:
boxes2 *= 100 boxes2 *= 100
o_im_s2 = (200, 200) o_im_s2 = (200, 200)
from torchvision.models.detection.roi_heads import paste_masks_in_image from torchvision.models.detection.roi_heads import paste_masks_in_image
out2 = paste_masks_in_image(masks2, boxes2, o_im_s2) out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]) out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])
...@@ -430,19 +489,35 @@ class TestONNXExporter: ...@@ -430,19 +489,35 @@ class TestONNXExporter:
model.eval() model.eval()
model(images) model(images)
# Test exported model on images of different size, or dummy input # Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)], self.run_model(
input_names=["images_tensors"], model,
output_names=["boxes", "labels", "scores", "masks"], [(images,), (test_images,), (dummy_image,)],
dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], input_names=["images_tensors"],
"scores": [0], "masks": [0, 1, 2]}, output_names=["boxes", "labels", "scores", "masks"],
tolerate_small_mismatch=True) dynamic_axes={
"images_tensors": [0, 1, 2],
"boxes": [0, 1],
"labels": [0],
"scores": [0],
"masks": [0, 1, 2],
},
tolerate_small_mismatch=True,
)
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], self.run_model(
input_names=["images_tensors"], model,
output_names=["boxes", "labels", "scores", "masks"], [(dummy_image,), (images,)],
dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], input_names=["images_tensors"],
"scores": [0], "masks": [0, 1, 2]}, output_names=["boxes", "labels", "scores", "masks"],
tolerate_small_mismatch=True) dynamic_axes={
"images_tensors": [0, 1, 2],
"boxes": [0, 1],
"labels": [0],
"scores": [0],
"masks": [0, 1, 2],
},
tolerate_small_mismatch=True,
)
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
...@@ -451,6 +526,7 @@ class TestONNXExporter: ...@@ -451,6 +526,7 @@ class TestONNXExporter:
maps = torch.rand(10, 1, 26, 26) maps = torch.rand(10, 1, 26, 26)
rois = torch.rand(10, 4) rois = torch.rand(10, 4)
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
out = heatmaps_to_keypoints(maps, rois) out = heatmaps_to_keypoints(maps, rois)
jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
out_trace = jit_trace(maps, rois) out_trace = jit_trace(maps, rois)
...@@ -461,6 +537,7 @@ class TestONNXExporter: ...@@ -461,6 +537,7 @@ class TestONNXExporter:
maps2 = torch.rand(20, 2, 21, 21) maps2 = torch.rand(20, 2, 21, 21)
rois2 = torch.rand(20, 4) rois2 = torch.rand(20, 4)
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
out2 = heatmaps_to_keypoints(maps2, rois2) out2 = heatmaps_to_keypoints(maps2, rois2)
out_trace2 = jit_trace(maps2, rois2) out_trace2 = jit_trace(maps2, rois2)
...@@ -473,29 +550,38 @@ class TestONNXExporter: ...@@ -473,29 +550,38 @@ class TestONNXExporter:
model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval() model.eval()
model(images) model(images)
self.run_model(model, [(images,), (test_images,), (dummy_images,)], self.run_model(
input_names=["images_tensors"], model,
output_names=["outputs1", "outputs2", "outputs3", "outputs4"], [(images,), (test_images,), (dummy_images,)],
dynamic_axes={"images_tensors": [0, 1, 2]}, input_names=["images_tensors"],
tolerate_small_mismatch=True) output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2]},
self.run_model(model, [(dummy_images,), (test_images,)], tolerate_small_mismatch=True,
input_names=["images_tensors"], )
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2]}, self.run_model(
tolerate_small_mismatch=True) model,
[(dummy_images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2]},
tolerate_small_mismatch=True,
)
def test_shufflenet_v2_dynamic_axes(self): def test_shufflenet_v2_dynamic_axes(self):
model = models.shufflenet_v2_x0_5(pretrained=True) model = models.shufflenet_v2_x0_5(pretrained=True)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)
self.run_model(model, [(dummy_input,), (test_inputs,)], self.run_model(
input_names=["input_images"], model,
output_names=["output"], [(dummy_input,), (test_inputs,)],
dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}}, input_names=["input_images"],
tolerate_small_mismatch=True) output_names=["output"],
dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
tolerate_small_mismatch=True,
)
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
import math import math
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import pytest from functools import lru_cache
from typing import Tuple
import numpy as np import numpy as np
import os import pytest
from PIL import Image
import torch import torch
from functools import lru_cache from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import Tensor from torch import Tensor
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision import ops from torchvision import ops
from typing import Tuple
class RoIOpTester(ABC): class RoIOpTester(ABC):
dtype = torch.float64 dtype = torch.float64
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype
...@@ -30,33 +29,33 @@ class RoIOpTester(ABC): ...@@ -30,33 +29,33 @@ class RoIOpTester(ABC):
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
if not contiguous: if not contiguous:
x = x.permute(0, 1, 3, 2) x = x.permute(0, 1, 3, 2)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) rois = torch.tensor(
[0, 0, 5, 4, 9], [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
[0, 5, 5, 9, 9], dtype=rois_dtype,
[1, 0, 0, 9, 9]], device=device,
dtype=rois_dtype, device=device) )
pool_h, pool_w = pool_size, pool_size pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
# the following should be true whether we're running an autocast test or not. # the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype assert y.dtype == x.dtype
gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1, gt_y = self.expected_fn(
sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs) x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
)
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, device, contiguous): def test_backward(self, device, contiguous):
pool_size = 2 pool_size = 2
x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
if not contiguous: if not contiguous:
x = x.permute(0, 1, 3, 2) x = x.permute(0, 1, 3, 2)
rois = torch.tensor([[0, 0, 0, 4, 4], # format is (xyxy) rois = torch.tensor(
[0, 0, 2, 3, 4], [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy)
[0, 2, 2, 4, 4]], )
dtype=self.dtype, device=device)
def func(z): def func(z):
return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
...@@ -67,8 +66,8 @@ class RoIOpTester(ABC): ...@@ -67,8 +66,8 @@ class RoIOpTester(ABC):
gradcheck(script_func, (x,)) gradcheck(script_func, (x,))
@needs_cuda @needs_cuda
@pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
def test_autocast(self, x_dtype, rois_dtype): def test_autocast(self, x_dtype, rois_dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
...@@ -107,8 +106,9 @@ class TestRoiPool(RoIOpTester): ...@@ -107,8 +106,9 @@ class TestRoiPool(RoIOpTester):
scriped = torch.jit.script(ops.roi_pool) scriped = torch.jit.script(ops.roi_pool)
return lambda x: scriped(x, rois, pool_size) return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, def expected_fn(
device=None, dtype=torch.float64): self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
...@@ -121,7 +121,7 @@ class TestRoiPool(RoIOpTester): ...@@ -121,7 +121,7 @@ class TestRoiPool(RoIOpTester):
for roi_idx, roi in enumerate(rois): for roi_idx, roi in enumerate(rois):
batch_idx = int(roi[0]) batch_idx = int(roi[0])
j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
roi_h, roi_w = roi_x.shape[-2:] roi_h, roi_w = roi_x.shape[-2:]
bin_h = roi_h / pool_h bin_h = roi_h / pool_h
...@@ -146,8 +146,9 @@ class TestPSRoIPool(RoIOpTester): ...@@ -146,8 +146,9 @@ class TestPSRoIPool(RoIOpTester):
scriped = torch.jit.script(ops.ps_roi_pool) scriped = torch.jit.script(ops.ps_roi_pool)
return lambda x: scriped(x, rois, pool_size) return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, def expected_fn(
device=None, dtype=torch.float64): self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
n_input_channels = x.size(1) n_input_channels = x.size(1)
...@@ -161,7 +162,7 @@ class TestPSRoIPool(RoIOpTester): ...@@ -161,7 +162,7 @@ class TestPSRoIPool(RoIOpTester):
for roi_idx, roi in enumerate(rois): for roi_idx, roi in enumerate(rois):
batch_idx = int(roi[0]) batch_idx = int(roi[0])
j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
roi_height = max(i_end - i_begin, 1) roi_height = max(i_end - i_begin, 1)
roi_width = max(j_end - j_begin, 1) roi_width = max(j_end - j_begin, 1)
...@@ -216,21 +217,32 @@ def bilinear_interpolate(data, y, x, snap_border=False): ...@@ -216,21 +217,32 @@ def bilinear_interpolate(data, y, x, snap_border=False):
class TestRoIAlign(RoIOpTester): class TestRoIAlign(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, return ops.RoIAlign(
sampling_ratio=sampling_ratio, aligned=aligned)(x, rois) (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
)(x, rois)
def get_script_fn(self, rois, pool_size): def get_script_fn(self, rois, pool_size):
scriped = torch.jit.script(ops.roi_align) scriped = torch.jit.script(ops.roi_align)
return lambda x: scriped(x, rois, pool_size) return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, def expected_fn(
device=None, dtype=torch.float64): self,
in_data,
rois,
pool_h,
pool_w,
spatial_scale=1,
sampling_ratio=-1,
aligned=False,
device=None,
dtype=torch.float64,
):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
n_channels = in_data.size(1) n_channels = in_data.size(1)
out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device) out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
offset = 0.5 if aligned else 0. offset = 0.5 if aligned else 0.0
for r, roi in enumerate(rois): for r, roi in enumerate(rois):
batch_idx = int(roi[0]) batch_idx = int(roi[0])
...@@ -264,21 +276,23 @@ class TestRoIAlign(RoIOpTester): ...@@ -264,21 +276,23 @@ class TestRoIAlign(RoIOpTester):
def test_boxes_shape(self): def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align) self._helper_boxes_shape(ops.roi_align)
@pytest.mark.parametrize('aligned', (True, False)) @pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
super().test_forward(device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, super().test_forward(
aligned=aligned) device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
)
@needs_cuda @needs_cuda
@pytest.mark.parametrize('aligned', (True, False)) @pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
def test_autocast(self, aligned, x_dtype, rois_dtype): def test_autocast(self, aligned, x_dtype, rois_dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, self.test_forward(
rois_dtype=rois_dtype) torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
)
def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
...@@ -286,9 +300,9 @@ class TestRoIAlign(RoIOpTester): ...@@ -286,9 +300,9 @@ class TestRoIAlign(RoIOpTester):
rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate
return rois return rois
@pytest.mark.parametrize('aligned', (True, False)) @pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize('scale, zero_point', ((1, 0), (2, 10), (0.1, 50))) @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
@pytest.mark.parametrize('qdtype', (torch.qint8, torch.quint8, torch.qint32)) @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
def test_qroialign(self, aligned, scale, zero_point, qdtype): def test_qroialign(self, aligned, scale, zero_point, qdtype):
"""Make sure quantized version of RoIAlign is close to float version""" """Make sure quantized version of RoIAlign is close to float version"""
pool_size = 5 pool_size = 5
...@@ -338,7 +352,7 @@ class TestRoIAlign(RoIOpTester): ...@@ -338,7 +352,7 @@ class TestRoIAlign(RoIOpTester):
# - any difference between qy and quantized_float_y is == scale # - any difference between qy and quantized_float_y is == scale
diff_idx = torch.where(qy != quantized_float_y) diff_idx = torch.where(qy != quantized_float_y)
num_diff = diff_idx[0].numel() num_diff = diff_idx[0].numel()
assert num_diff / qy.numel() < .05 assert num_diff / qy.numel() < 0.05
abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize()) abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
t_scale = torch.full_like(abs_diff, fill_value=scale) t_scale = torch.full_like(abs_diff, fill_value=scale)
...@@ -356,15 +370,15 @@ class TestRoIAlign(RoIOpTester): ...@@ -356,15 +370,15 @@ class TestRoIAlign(RoIOpTester):
class TestPSRoIAlign(RoIOpTester): class TestPSRoIAlign(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
sampling_ratio=sampling_ratio)(x, rois)
def get_script_fn(self, rois, pool_size): def get_script_fn(self, rois, pool_size):
scriped = torch.jit.script(ops.ps_roi_align) scriped = torch.jit.script(ops.ps_roi_align)
return lambda x: scriped(x, rois, pool_size) return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, def expected_fn(
sampling_ratio=-1, dtype=torch.float64): self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
n_input_channels = in_data.size(1) n_input_channels = in_data.size(1)
...@@ -407,15 +421,17 @@ class TestPSRoIAlign(RoIOpTester): ...@@ -407,15 +421,17 @@ class TestPSRoIAlign(RoIOpTester):
class TestMultiScaleRoIAlign: class TestMultiScaleRoIAlign:
def test_msroialign_repr(self): def test_msroialign_repr(self):
fmap_names = ['0'] fmap_names = ["0"]
output_size = (7, 7) output_size = (7, 7)
sampling_ratio = 2 sampling_ratio = 2
# Pass mock feature map names # Pass mock feature map names
t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
# Check integrity of object __repr__ attribute # Check integrity of object __repr__ attribute
expected_string = (f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " expected_string = (
f"sampling_ratio={sampling_ratio})") f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
f"sampling_ratio={sampling_ratio})"
)
assert repr(t) == expected_string assert repr(t) == expected_string
...@@ -460,9 +476,9 @@ class TestNMS: ...@@ -460,9 +476,9 @@ class TestNMS:
scores = torch.rand(N) scores = torch.rand(N)
return boxes, scores return boxes, scores
@pytest.mark.parametrize("iou", (.2, .5, .8)) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_ref(self, iou): def test_nms_ref(self, iou):
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}' err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou) boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou) keep_ref = self._reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou)
...@@ -478,13 +494,13 @@ class TestNMS: ...@@ -478,13 +494,13 @@ class TestNMS:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
ops.nms(torch.rand(3, 4), torch.rand(4), 0.5) ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)
@pytest.mark.parametrize("iou", (.2, .5, .8)) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10))) @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
def test_qnms(self, iou, scale, zero_point): def test_qnms(self, iou, scale, zero_point):
# Note: we compare qnms vs nms instead of qnms vs reference implementation. # Note: we compare qnms vs nms instead of qnms vs reference implementation.
# This is because with the int convertion, the trick used in _create_tensors_with_iou # This is because with the int convertion, the trick used in _create_tensors_with_iou
# doesn't really work (in fact, nms vs reference implem will also fail with ints) # doesn't really work (in fact, nms vs reference implem will also fail with ints)
err_msg = 'NMS and QNMS give different results for IoU={}' err_msg = "NMS and QNMS give different results for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou) boxes, scores = self._create_tensors_with_iou(1000, iou)
scores *= 100 # otherwise most scores would be 0 or 1 after int convertion scores *= 100 # otherwise most scores would be 0 or 1 after int convertion
...@@ -500,10 +516,10 @@ class TestNMS: ...@@ -500,10 +516,10 @@ class TestNMS:
assert torch.allclose(qkeep, keep), err_msg.format(iou) assert torch.allclose(qkeep, keep), err_msg.format(iou)
@needs_cuda @needs_cuda
@pytest.mark.parametrize("iou", (.2, .5, .8)) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_cuda(self, iou, dtype=torch.float64): def test_nms_cuda(self, iou, dtype=torch.float64):
tol = 1e-3 if dtype is torch.half else 1e-5 tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}' err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou) boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou) r_cpu = ops.nms(boxes, scores, iou)
...@@ -517,7 +533,7 @@ class TestNMS: ...@@ -517,7 +533,7 @@ class TestNMS:
assert is_eq, err_msg.format(iou) assert is_eq, err_msg.format(iou)
@needs_cuda @needs_cuda
@pytest.mark.parametrize("iou", (.2, .5, .8)) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half)) @pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype): def test_autocast(self, iou, dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
...@@ -525,9 +541,13 @@ class TestNMS: ...@@ -525,9 +541,13 @@ class TestNMS:
@needs_cuda @needs_cuda
def test_nms_cuda_float16(self): def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], boxes = torch.tensor(
[285.1472, 188.7374, 1192.4984, 851.0669], [
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() [285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]
).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
iou_thres = 0.2 iou_thres = 0.2
...@@ -539,7 +559,7 @@ class TestNMS: ...@@ -539,7 +559,7 @@ class TestNMS:
"""Make sure that both implementations of batched_nms yield identical results""" """Make sure that both implementations of batched_nms yield identical results"""
num_boxes = 1000 num_boxes = 1000
iou_threshold = .9 iou_threshold = 0.9
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
...@@ -603,8 +623,11 @@ class TestDeformConv: ...@@ -603,8 +623,11 @@ class TestDeformConv:
if mask is not None: if mask is not None:
mask_value = mask[b, mask_idx, i, j] mask_value = mask[b, mask_idx, i, j]
out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] * out[b, c_out, i, j] += (
bilinear_interpolate(x[b, c_in, :, :], pi, pj)) mask_value
* weight[c_out, c, di, dj]
* bilinear_interpolate(x[b, c_in, :, :], pi, pj)
)
out += bias.view(1, n_out_channels, 1, 1) out += bias.view(1, n_out_channels, 1, 1)
return out return out
...@@ -630,14 +653,29 @@ class TestDeformConv: ...@@ -630,14 +653,29 @@ class TestDeformConv:
x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True) x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w, offset = torch.randn(
device=device, dtype=dtype, requires_grad=True) batch_sz,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w,
device=device,
dtype=dtype,
requires_grad=True,
)
mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, mask = torch.randn(
device=device, dtype=dtype, requires_grad=True) batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
)
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, weight = torch.randn(
device=device, dtype=dtype, requires_grad=True) n_out_channels,
n_in_channels // n_weight_grps,
weight_h,
weight_w,
device=device,
dtype=dtype,
requires_grad=True,
)
bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True) bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)
...@@ -649,9 +687,9 @@ class TestDeformConv: ...@@ -649,9 +687,9 @@ class TestDeformConv:
return x, weight, offset, mask, bias, stride, pad, dilation return x, weight, offset, mask, bias, stride, pad, dilation
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize('batch_sz', (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
def test_forward(self, device, contiguous, batch_sz, dtype=None): def test_forward(self, device, contiguous, batch_sz, dtype=None):
dtype = dtype or self.dtype dtype = dtype or self.dtype
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
...@@ -661,8 +699,9 @@ class TestDeformConv: ...@@ -661,8 +699,9 @@ class TestDeformConv:
groups = 2 groups = 2
tol = 2e-3 if dtype is torch.half else 1e-5 tol = 2e-3 if dtype is torch.half else 1e-5
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, layer = ops.DeformConv2d(
dilation=dilation, groups=groups).to(device=x.device, dtype=dtype) in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
).to(device=x.device, dtype=dtype)
res = layer(x, offset, mask) res = layer(x, offset, mask)
weight = layer.weight.data weight = layer.weight.data
...@@ -670,7 +709,7 @@ class TestDeformConv: ...@@ -670,7 +709,7 @@ class TestDeformConv:
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
torch.testing.assert_close( torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected)
) )
# no modulation test # no modulation test
...@@ -678,7 +717,7 @@ class TestDeformConv: ...@@ -678,7 +717,7 @@ class TestDeformConv:
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
torch.testing.assert_close( torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected)
) )
def test_wrong_sizes(self): def test_wrong_sizes(self):
...@@ -686,57 +725,72 @@ class TestDeformConv: ...@@ -686,57 +725,72 @@ class TestDeformConv:
out_channels = 2 out_channels = 2
kernel_size = (3, 2) kernel_size = (3, 2)
groups = 2 groups = 2
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args('cpu', contiguous=True, x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
batch_sz=10, dtype=self.dtype) "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, )
dilation=dilation, groups=groups) layer = ops.DeformConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
)
with pytest.raises(RuntimeError, match="the shape of the offset"): with pytest.raises(RuntimeError, match="the shape of the offset"):
wrong_offset = torch.rand_like(offset[:, :2]) wrong_offset = torch.rand_like(offset[:, :2])
layer(x, wrong_offset) layer(x, wrong_offset)
with pytest.raises(RuntimeError, match=r'mask.shape\[1\] is not valid'): with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
wrong_mask = torch.rand_like(mask[:, :2]) wrong_mask = torch.rand_like(mask[:, :2])
layer(x, offset, wrong_mask) layer(x, offset, wrong_mask)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize('batch_sz', (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
def test_backward(self, device, contiguous, batch_sz): def test_backward(self, device, contiguous, batch_sz):
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
batch_sz, self.dtype) device, contiguous, batch_sz, self.dtype
)
def func(x_, offset_, mask_, weight_, bias_): def func(x_, offset_, mask_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, return ops.deform_conv2d(
padding=padding, dilation=dilation, mask=mask_) x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
)
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)
def func_no_mask(x_, offset_, weight_, bias_): def func_no_mask(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, return ops.deform_conv2d(
padding=padding, dilation=dilation, mask=None) x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
)
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)
@torch.jit.script @torch.jit.script
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, return ops.deform_conv2d(
padding=pad_, dilation=dilation_, mask=mask_) x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
)
gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
(x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) gradcheck(
lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
(x, offset, mask, weight, bias),
nondet_tol=1e-5,
fast_mode=True,
)
@torch.jit.script @torch.jit.script
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, return ops.deform_conv2d(
padding=pad_, dilation=dilation_, mask=None) x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
)
gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) gradcheck(
lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias),
nondet_tol=1e-5,
fast_mode=True,
)
@needs_cuda @needs_cuda
@pytest.mark.parametrize('contiguous', (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_compare_cpu_cuda_grads(self, contiguous): def test_compare_cpu_cuda_grads(self, contiguous):
# Test from https://github.com/pytorch/vision/issues/2598 # Test from https://github.com/pytorch/vision/issues/2598
# Run on CUDA only # Run on CUDA only
...@@ -770,8 +824,8 @@ class TestDeformConv: ...@@ -770,8 +824,8 @@ class TestDeformConv:
torch.testing.assert_close(true_cpu_grads, res_grads) torch.testing.assert_close(true_cpu_grads, res_grads)
@needs_cuda @needs_cuda
@pytest.mark.parametrize('batch_sz', (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.parametrize('dtype', (torch.float, torch.half)) @pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, batch_sz, dtype): def test_autocast(self, batch_sz, dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
...@@ -794,11 +848,13 @@ class TestFrozenBNT: ...@@ -794,11 +848,13 @@ class TestFrozenBNT:
def test_frozenbatchnorm2d_eps(self): def test_frozenbatchnorm2d_eps(self):
sample_size = (4, 32, 28, 28) sample_size = (4, 32, 28, 28)
x = torch.rand(sample_size) x = torch.rand(sample_size)
state_dict = dict(weight=torch.rand(sample_size[1]), state_dict = dict(
bias=torch.rand(sample_size[1]), weight=torch.rand(sample_size[1]),
running_mean=torch.rand(sample_size[1]), bias=torch.rand(sample_size[1]),
running_var=torch.rand(sample_size[1]), running_mean=torch.rand(sample_size[1]),
num_batches_tracked=torch.tensor(100)) running_var=torch.rand(sample_size[1]),
num_batches_tracked=torch.tensor(100),
)
# Check that default eps is equal to the one of BN # Check that default eps is equal to the one of BN
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
...@@ -826,17 +882,19 @@ class TestBoxConversion: ...@@ -826,17 +882,19 @@ class TestBoxConversion:
def _get_box_sequences(): def _get_box_sequences():
# Define here the argument type of `boxes` supported by region pooling operations # Define here the argument type of `boxes` supported by region pooling operations
box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float), box_list = [
torch.tensor([[0, 0, 100, 100]], dtype=torch.float)] torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
]
box_tuple = tuple(box_list) box_tuple = tuple(box_list)
return box_tensor, box_list, box_tuple return box_tensor, box_list, box_tuple
@pytest.mark.parametrize('box_sequence', _get_box_sequences()) @pytest.mark.parametrize("box_sequence", _get_box_sequences())
def test_check_roi_boxes_shape(self, box_sequence): def test_check_roi_boxes_shape(self, box_sequence):
# Ensure common sequences of tensors are supported # Ensure common sequences of tensors are supported
ops._utils.check_roi_boxes_shape(box_sequence) ops._utils.check_roi_boxes_shape(box_sequence)
@pytest.mark.parametrize('box_sequence', _get_box_sequences()) @pytest.mark.parametrize("box_sequence", _get_box_sequences())
def test_convert_boxes_to_roi_format(self, box_sequence): def test_convert_boxes_to_roi_format(self, box_sequence):
# Ensure common sequences of tensors yield the same result # Ensure common sequences of tensors yield the same result
ref_tensor = None ref_tensor = None
...@@ -848,11 +906,11 @@ class TestBoxConversion: ...@@ -848,11 +906,11 @@ class TestBoxConversion:
class TestBox: class TestBox:
def test_bbox_same(self): def test_bbox_same(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
)
exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
assert exp_xyxy.size() == torch.Size([4, 4]) assert exp_xyxy.size() == torch.Size([4, 4])
assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy) assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
...@@ -862,10 +920,10 @@ class TestBox: ...@@ -862,10 +920,10 @@ class TestBox:
def test_bbox_xyxy_xywh(self): def test_bbox_xyxy_xywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same. # Simple test convert boxes to xywh and back. Make sure they are same.
# box_tensor is in x1 y1 x2 y2 format. # box_tensor is in x1 y1 x2 y2 format.
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], )
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
assert exp_xywh.size() == torch.Size([4, 4]) assert exp_xywh.size() == torch.Size([4, 4])
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
...@@ -878,10 +936,12 @@ class TestBox: ...@@ -878,10 +936,12 @@ class TestBox:
def test_bbox_xyxy_cxcywh(self): def test_bbox_xyxy_cxcywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same. # Simple test convert boxes to xywh and back. Make sure they are same.
# box_tensor is in x1 y1 x2 y2 format. # box_tensor is in x1 y1 x2 y2 format.
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], )
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) exp_cxcywh = torch.tensor(
[[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
)
assert exp_cxcywh.size() == torch.Size([4, 4]) assert exp_cxcywh.size() == torch.Size([4, 4])
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
...@@ -892,12 +952,14 @@ class TestBox: ...@@ -892,12 +952,14 @@ class TestBox:
assert_equal(box_xyxy, box_tensor) assert_equal(box_xyxy, box_tensor)
def test_bbox_xywh_cxcywh(self): def test_bbox_xywh_cxcywh(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
)
# This is wrong # This is wrong
exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], exp_cxcywh = torch.tensor(
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
)
assert exp_cxcywh.size() == torch.Size([4, 4]) assert exp_cxcywh.size() == torch.Size([4, 4])
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh") box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
...@@ -907,28 +969,30 @@ class TestBox: ...@@ -907,28 +969,30 @@ class TestBox:
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh") box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
assert_equal(box_xywh, box_tensor) assert_equal(box_xywh, box_tensor)
@pytest.mark.parametrize('inv_infmt', ["xwyh", "cxwyh"]) @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
@pytest.mark.parametrize('inv_outfmt', ["xwcx", "xhwcy"]) @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
def test_bbox_invalid(self, inv_infmt, inv_outfmt): def test_bbox_invalid(self, inv_infmt, inv_outfmt):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
ops.box_convert(box_tensor, inv_infmt, inv_outfmt) ops.box_convert(box_tensor, inv_infmt, inv_outfmt)
def test_bbox_convert_jit(self): def test_bbox_convert_jit(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], box_tensor = torch.tensor(
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
)
scripted_fn = torch.jit.script(ops.box_convert) scripted_fn = torch.jit.script(ops.box_convert)
TOLERANCE = 1e-3 TOLERANCE = 1e-3
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh') scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE)
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh') scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
...@@ -946,16 +1010,22 @@ class TestBoxArea: ...@@ -946,16 +1010,22 @@ class TestBoxArea:
# Check for float32 and float64 boxes # Check for float32 and float64 boxes
for dtype in [torch.float32, torch.float64]: for dtype in [torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], box_tensor = torch.tensor(
[285.1472, 188.7374, 1192.4984, 851.0669], [
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) [285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
],
dtype=dtype,
)
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
area_check(box_tensor, expected, tolerance=0.05) area_check(box_tensor, expected, tolerance=0.05)
# Check for float16 box # Check for float16 box
box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5], box_tensor = torch.tensor(
[285.25, 188.75, 1192.0, 851.0], [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]],
[279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16) dtype=torch.float16,
)
expected = torch.tensor([605113.875, 600495.1875, 592247.25]) expected = torch.tensor([605113.875, 600495.1875, 592247.25])
area_check(box_tensor, expected) area_check(box_tensor, expected)
...@@ -982,9 +1052,14 @@ class TestBoxIou: ...@@ -982,9 +1052,14 @@ class TestBoxIou:
# Check for float boxes # Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]: for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], box_tensor = torch.tensor(
[285.1472, 188.7374, 1192.4984, 851.0669], [
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) [285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
],
dtype=dtype,
)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)
...@@ -1011,9 +1086,14 @@ class TestGenBoxIou: ...@@ -1011,9 +1086,14 @@ class TestGenBoxIou:
# Check for float boxes # Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]: for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], box_tensor = torch.tensor(
[285.1472, 188.7374, 1192.4984, 851.0669], [
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) [285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
],
dtype=dtype,
)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
...@@ -1048,8 +1128,18 @@ class TestMasksToBoxes: ...@@ -1048,8 +1128,18 @@ class TestMasksToBoxes:
return masks return masks
expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104], expected = torch.tensor(
[160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float) [
[127, 2, 165, 40],
[2, 50, 44, 92],
[56, 63, 98, 100],
[139, 68, 175, 104],
[160, 112, 198, 145],
[49, 138, 99, 182],
[108, 148, 152, 213],
],
dtype=torch.float,
)
image = _get_image() image = _get_image()
for dtype in [torch.float16, torch.float32, torch.float64]: for dtype in [torch.float16, torch.float32, torch.float64]:
...@@ -1059,8 +1149,8 @@ class TestMasksToBoxes: ...@@ -1059,8 +1149,8 @@ class TestMasksToBoxes:
class TestStochasticDepth: class TestStochasticDepth:
@pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
@pytest.mark.parametrize('mode', ["batch", "row"]) @pytest.mark.parametrize("mode", ["batch", "row"])
def test_stochastic_depth(self, mode, p): def test_stochastic_depth(self, mode, p):
stats = pytest.importorskip("scipy.stats") stats = pytest.importorskip("scipy.stats")
batch_size = 5 batch_size = 5
...@@ -1086,5 +1176,5 @@ class TestStochasticDepth: ...@@ -1086,5 +1176,5 @@ class TestStochasticDepth:
assert p_value > 0.0001 assert p_value > 0.0001
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import math
import os import os
import random
import numpy as np
import pytest
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
import math
import random
import numpy as np
import pytest
from PIL import Image from PIL import Image
from torch._utils_internal import get_file_path_2
try: try:
import accimage import accimage
except ImportError: except ImportError:
...@@ -23,17 +25,18 @@ from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal ...@@ -23,17 +25,18 @@ from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal
GRACE_HOPPER = get_file_path_2( GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
def _get_grayscale_test_image(img, fill=None): def _get_grayscale_test_image(img, fill=None):
img = img.convert('L') img = img.convert("L")
fill = (fill[0], ) if isinstance(fill, tuple) else fill fill = (fill[0],) if isinstance(fill, tuple) else fill
return img, fill return img, fill
class TestConvertImageDtype: class TestConvertImageDtype:
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes())) @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(float_dtypes()))
def test_float_to_float(self, input_dtype, output_dtype): def test_float_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
...@@ -50,15 +53,15 @@ class TestConvertImageDtype: ...@@ -50,15 +53,15 @@ class TestConvertImageDtype:
assert abs(actual_min - desired_min) < 1e-7 assert abs(actual_min - desired_min) < 1e-7
assert abs(actual_max - desired_max) < 1e-7 assert abs(actual_max - desired_max) < 1e-7
@pytest.mark.parametrize('input_dtype', float_dtypes()) @pytest.mark.parametrize("input_dtype", float_dtypes())
@pytest.mark.parametrize('output_dtype', int_dtypes()) @pytest.mark.parametrize("output_dtype", int_dtypes())
def test_float_to_int(self, input_dtype, output_dtype): def test_float_to_int(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype) transform_script = torch.jit.script(F.convert_image_dtype)
if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64 input_dtype == torch.float64 and output_dtype == torch.int64
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
transform(input_image) transform(input_image)
...@@ -74,8 +77,8 @@ class TestConvertImageDtype: ...@@ -74,8 +77,8 @@ class TestConvertImageDtype:
assert actual_min == desired_min assert actual_min == desired_min
assert actual_max == desired_max assert actual_max == desired_max
@pytest.mark.parametrize('input_dtype', int_dtypes()) @pytest.mark.parametrize("input_dtype", int_dtypes())
@pytest.mark.parametrize('output_dtype', float_dtypes()) @pytest.mark.parametrize("output_dtype", float_dtypes())
def test_int_to_float(self, input_dtype, output_dtype): def test_int_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype) transform = transforms.ConvertImageDtype(output_dtype)
...@@ -94,7 +97,7 @@ class TestConvertImageDtype: ...@@ -94,7 +97,7 @@ class TestConvertImageDtype:
assert abs(actual_max - desired_max) < 1e-7 assert abs(actual_max - desired_max) < 1e-7
assert actual_max <= desired_max assert actual_max <= desired_max
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
def test_dtype_int_to_int(self, input_dtype, output_dtype): def test_dtype_int_to_int(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype) input_image = torch.tensor((0, input_max), dtype=input_dtype)
...@@ -126,7 +129,7 @@ class TestConvertImageDtype: ...@@ -126,7 +129,7 @@ class TestConvertImageDtype:
assert actual_min == desired_min assert actual_min == desired_min
assert actual_max == (desired_max + error_term) assert actual_max == (desired_max + error_term)
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
def test_int_to_int_consistency(self, input_dtype, output_dtype): def test_int_to_int_consistency(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype) input_image = torch.tensor((0, input_max), dtype=input_dtype)
...@@ -148,11 +151,10 @@ class TestConvertImageDtype: ...@@ -148,11 +151,10 @@ class TestConvertImageDtype:
@pytest.mark.skipif(accimage is None, reason="accimage not available") @pytest.mark.skipif(accimage is None, reason="accimage not available")
class TestAccImage: class TestAccImage:
def test_accimage_to_tensor(self): def test_accimage_to_tensor(self):
trans = transforms.ToTensor() trans = transforms.ToTensor()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
torch.testing.assert_close(output, expected_output) torch.testing.assert_close(output, expected_output)
...@@ -160,22 +162,24 @@ class TestAccImage: ...@@ -160,22 +162,24 @@ class TestAccImage:
def test_accimage_pil_to_tensor(self): def test_accimage_pil_to_tensor(self):
trans = transforms.PILToTensor() trans = transforms.PILToTensor()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
assert expected_output.size() == output.size() assert expected_output.size() == output.size()
torch.testing.assert_close(output, expected_output) torch.testing.assert_close(output, expected_output)
def test_accimage_resize(self): def test_accimage_resize(self):
trans = transforms.Compose([ trans = transforms.Compose(
transforms.Resize(256, interpolation=Image.LINEAR), [
transforms.ToTensor(), transforms.Resize(256, interpolation=Image.LINEAR),
]) transforms.ToTensor(),
]
)
# Checking if Compose, Resize and ToTensor can be printed as string # Checking if Compose, Resize and ToTensor can be printed as string
trans.__repr__() trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
assert expected_output.size() == output.size() assert expected_output.size() == output.size()
...@@ -185,15 +189,17 @@ class TestAccImage: ...@@ -185,15 +189,17 @@ class TestAccImage:
torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2) torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2)
def test_accimage_crop(self): def test_accimage_crop(self):
trans = transforms.Compose([ trans = transforms.Compose(
transforms.CenterCrop(256), [
transforms.ToTensor(), transforms.CenterCrop(256),
]) transforms.ToTensor(),
]
)
# Checking if Compose, CenterCrop and ToTensor can be printed as string # Checking if Compose, CenterCrop and ToTensor can be printed as string
trans.__repr__() trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
assert expected_output.size() == output.size() assert expected_output.size() == output.size()
...@@ -201,8 +207,7 @@ class TestAccImage: ...@@ -201,8 +207,7 @@ class TestAccImage:
class TestToTensor: class TestToTensor:
@pytest.mark.parametrize("channels", [1, 3, 4])
@pytest.mark.parametrize('channels', [1, 3, 4])
def test_to_tensor(self, channels): def test_to_tensor(self, channels):
height, width = 4, 4 height, width = 4, 4
trans = transforms.ToTensor() trans = transforms.ToTensor()
...@@ -225,7 +230,7 @@ class TestToTensor: ...@@ -225,7 +230,7 @@ class TestToTensor:
# separate test for mode '1' PIL images # separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_() input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1') img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
output = trans(img) output = trans(img)
torch.testing.assert_close(input_data, output, check_dtype=False) torch.testing.assert_close(input_data, output, check_dtype=False)
...@@ -243,7 +248,7 @@ class TestToTensor: ...@@ -243,7 +248,7 @@ class TestToTensor:
with pytest.raises(ValueError): with pytest.raises(ValueError):
trans(np_rng.rand(1, 1, height, width)) trans(np_rng.rand(1, 1, height, width))
@pytest.mark.parametrize('dtype', [torch.float16, torch.float, torch.double]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float, torch.double])
def test_to_tensor_with_other_default_dtypes(self, dtype): def test_to_tensor_with_other_default_dtypes(self, dtype):
np_rng = np.random.RandomState(0) np_rng = np.random.RandomState(0)
current_def_dtype = torch.get_default_dtype() current_def_dtype = torch.get_default_dtype()
...@@ -258,7 +263,7 @@ class TestToTensor: ...@@ -258,7 +263,7 @@ class TestToTensor:
torch.set_default_dtype(current_def_dtype) torch.set_default_dtype(current_def_dtype)
@pytest.mark.parametrize('channels', [1, 3, 4]) @pytest.mark.parametrize("channels", [1, 3, 4])
def test_pil_to_tensor(self, channels): def test_pil_to_tensor(self, channels):
height, width = 4, 4 height, width = 4, 4
trans = transforms.PILToTensor() trans = transforms.PILToTensor()
...@@ -283,7 +288,7 @@ class TestToTensor: ...@@ -283,7 +288,7 @@ class TestToTensor:
# separate test for mode '1' PIL images # separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_() input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1') img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
output = trans(img).view(torch.uint8).bool().to(torch.uint8) output = trans(img).view(torch.uint8).bool().to(torch.uint8)
torch.testing.assert_close(input_data, output) torch.testing.assert_close(input_data, output)
...@@ -316,34 +321,47 @@ def test_randomresized_params(): ...@@ -316,34 +321,47 @@ def test_randomresized_params():
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range) randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range) i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
aspect_ratio_obtained = w / h aspect_ratio_obtained = w / h
assert((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and assert (
aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained
aspect_ratio_obtained == 1.0) and aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon
) or aspect_ratio_obtained == 1.0
assert isinstance(i, int) assert isinstance(i, int)
assert isinstance(j, int) assert isinstance(j, int)
assert isinstance(h, int) assert isinstance(h, int)
assert isinstance(w, int) assert isinstance(w, int)
@pytest.mark.parametrize('height, width', [ @pytest.mark.parametrize(
# height, width "height, width",
# square image [
(28, 28), # height, width
(27, 27), # square image
# rectangular image: h < w (28, 28),
(28, 34), (27, 27),
(29, 35), # rectangular image: h < w
# rectangular image: h > w (28, 34),
(34, 28), (29, 35),
(35, 29), # rectangular image: h > w
]) (34, 28),
@pytest.mark.parametrize('osize', [ (35, 29),
# single integer ],
22, 27, 28, 36, )
# single integer in tuple/list @pytest.mark.parametrize(
[22, ], (27, ), "osize",
]) [
@pytest.mark.parametrize('max_size', (None, 37, 1000)) # single integer
22,
27,
28,
36,
# single integer in tuple/list
[
22,
],
(27,),
],
)
@pytest.mark.parametrize("max_size", (None, 37, 1000))
def test_resize(height, width, osize, max_size): def test_resize(height, width, osize, max_size):
img = Image.new("RGB", size=(width, height), color=127) img = Image.new("RGB", size=(width, height), color=127)
...@@ -371,24 +389,36 @@ def test_resize(height, width, osize, max_size): ...@@ -371,24 +389,36 @@ def test_resize(height, width, osize, max_size):
assert result.size == (exp_w, exp_h), msg assert result.size == (exp_w, exp_h), msg
@pytest.mark.parametrize('height, width', [ @pytest.mark.parametrize(
# height, width "height, width",
# square image [
(28, 28), # height, width
(27, 27), # square image
# rectangular image: h < w (28, 28),
(28, 34), (27, 27),
(29, 35), # rectangular image: h < w
# rectangular image: h > w (28, 34),
(34, 28), (29, 35),
(35, 29), # rectangular image: h > w
]) (34, 28),
@pytest.mark.parametrize('osize', [ (35, 29),
# two integers sequence output ],
[22, 22], [22, 28], [22, 36], )
[27, 22], [36, 22], [28, 28], @pytest.mark.parametrize(
[28, 37], [37, 27], [37, 37] "osize",
]) [
# two integers sequence output
[22, 22],
[22, 28],
[22, 36],
[27, 22],
[36, 22],
[28, 28],
[28, 37],
[37, 27],
[37, 37],
],
)
def test_resize_sequence_output(height, width, osize): def test_resize_sequence_output(height, width, osize):
img = Image.new("RGB", size=(width, height), color=127) img = Image.new("RGB", size=(width, height), color=127)
oheight, owidth = osize oheight, owidth = osize
...@@ -409,18 +439,19 @@ def test_resize_antialias_error(): ...@@ -409,18 +439,19 @@ def test_resize_antialias_error():
class TestPad: class TestPad:
def test_pad(self): def test_pad(self):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
padding = random.randint(1, 20) padding = random.randint(1, 20)
fill = random.randint(1, 50) fill = random.randint(1, 50)
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.Pad(padding, fill=fill), transforms.ToPILImage(),
transforms.ToTensor(), transforms.Pad(padding, fill=fill),
])(img) transforms.ToTensor(),
]
)(img)
assert result.size(1) == height + 2 * padding assert result.size(1) == height + 2 * padding
assert result.size(2) == width + 2 * padding assert result.size(2) == width + 2 * padding
# check that all elements in the padded region correspond # check that all elements in the padded region correspond
...@@ -429,14 +460,9 @@ class TestPad: ...@@ -429,14 +460,9 @@ class TestPad:
eps = 1e-5 eps = 1e-5
h_padded = result[:, :padding, :] h_padded = result[:, :padding, :]
w_padded = result[:, :, :padding] w_padded = result[:, :, :padding]
torch.testing.assert_close( torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps)
h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps)
) pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img))
torch.testing.assert_close(
w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps
)
pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)),
transforms.ToPILImage()(img))
def test_pad_with_tuple_of_pad_values(self): def test_pad_with_tuple_of_pad_values(self):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
...@@ -463,7 +489,7 @@ class TestPad: ...@@ -463,7 +489,7 @@ class TestPad:
img = F.pad(img, 1, (200, 200, 200)) img = F.pad(img, 1, (200, 200, 200))
# pad 3 to all sidess # pad 3 to all sidess
edge_padded_img = F.pad(img, 3, padding_mode='edge') edge_padded_img = F.pad(img, 3, padding_mode="edge")
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
...@@ -471,7 +497,7 @@ class TestPad: ...@@ -471,7 +497,7 @@ class TestPad:
assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35) assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35)
# Pad 3 to left/right, 2 to top/bottom # Pad 3 to left/right, 2 to top/bottom
reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect') reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect")
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
...@@ -479,7 +505,7 @@ class TestPad: ...@@ -479,7 +505,7 @@ class TestPad:
assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35) assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35)
# Pad 3 to left, 2 to top, 2 to right, 1 to bottom # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric') symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric")
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
...@@ -489,7 +515,7 @@ class TestPad: ...@@ -489,7 +515,7 @@ class TestPad:
# Check negative padding explicitly for symmetric case, since it is not # Check negative padding explicitly for symmetric case, since it is not
# implemented for tensor case to compare to # implemented for tensor case to compare to
# Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom # Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom
symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric') symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode="symmetric")
symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3] symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3]
symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:] symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8)) assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
...@@ -516,14 +542,18 @@ class TestPad: ...@@ -516,14 +542,18 @@ class TestPad:
@pytest.mark.skipif(stats is None, reason="scipy.stats not available") @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize('fn, trans, config', [ @pytest.mark.parametrize(
(F.invert, transforms.RandomInvert, {}), "fn, trans, config",
(F.posterize, transforms.RandomPosterize, {"bits": 4}), [
(F.solarize, transforms.RandomSolarize, {"threshold": 192}), (F.invert, transforms.RandomInvert, {}),
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), (F.posterize, transforms.RandomPosterize, {"bits": 4}),
(F.autocontrast, transforms.RandomAutocontrast, {}), (F.solarize, transforms.RandomSolarize, {"threshold": 192}),
(F.equalize, transforms.RandomEqualize, {})]) (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
@pytest.mark.parametrize('p', (.5, .7)) (F.autocontrast, transforms.RandomAutocontrast, {}),
(F.equalize, transforms.RandomEqualize, {}),
],
)
@pytest.mark.parametrize("p", (0.5, 0.7))
def test_randomness(fn, trans, config, p): def test_randomness(fn, trans, config, p):
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -546,43 +576,42 @@ def test_randomness(fn, trans, config, p): ...@@ -546,43 +576,42 @@ def test_randomness(fn, trans, config, p):
class TestToPil: class TestToPil:
def _get_1_channel_tensor_various_types(): def _get_1_channel_tensor_various_types():
img_data_float = torch.Tensor(1, 4, 4).uniform_() img_data_float = torch.Tensor(1, 4, 4).uniform_()
expected_output = img_data_float.mul(255).int().float().div(255).numpy() expected_output = img_data_float.mul(255).int().float().div(255).numpy()
yield img_data_float, expected_output, 'L' yield img_data_float, expected_output, "L"
img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255) img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
expected_output = img_data_byte.float().div(255.0).numpy() expected_output = img_data_byte.float().div(255.0).numpy()
yield img_data_byte, expected_output, 'L' yield img_data_byte, expected_output, "L"
img_data_short = torch.ShortTensor(1, 4, 4).random_() img_data_short = torch.ShortTensor(1, 4, 4).random_()
expected_output = img_data_short.numpy() expected_output = img_data_short.numpy()
yield img_data_short, expected_output, 'I;16' yield img_data_short, expected_output, "I;16"
img_data_int = torch.IntTensor(1, 4, 4).random_() img_data_int = torch.IntTensor(1, 4, 4).random_()
expected_output = img_data_int.numpy() expected_output = img_data_int.numpy()
yield img_data_int, expected_output, 'I' yield img_data_int, expected_output, "I"
def _get_2d_tensor_various_types(): def _get_2d_tensor_various_types():
img_data_float = torch.Tensor(4, 4).uniform_() img_data_float = torch.Tensor(4, 4).uniform_()
expected_output = img_data_float.mul(255).int().float().div(255).numpy() expected_output = img_data_float.mul(255).int().float().div(255).numpy()
yield img_data_float, expected_output, 'L' yield img_data_float, expected_output, "L"
img_data_byte = torch.ByteTensor(4, 4).random_(0, 255) img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
expected_output = img_data_byte.float().div(255.0).numpy() expected_output = img_data_byte.float().div(255.0).numpy()
yield img_data_byte, expected_output, 'L' yield img_data_byte, expected_output, "L"
img_data_short = torch.ShortTensor(4, 4).random_() img_data_short = torch.ShortTensor(4, 4).random_()
expected_output = img_data_short.numpy() expected_output = img_data_short.numpy()
yield img_data_short, expected_output, 'I;16' yield img_data_short, expected_output, "I;16"
img_data_int = torch.IntTensor(4, 4).random_() img_data_int = torch.IntTensor(4, 4).random_()
expected_output = img_data_int.numpy() expected_output = img_data_int.numpy()
yield img_data_int, expected_output, 'I' yield img_data_int, expected_output, "I"
@pytest.mark.parametrize('with_mode', [False, True]) @pytest.mark.parametrize("with_mode", [False, True])
@pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_1_channel_tensor_various_types()) @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_1_channel_tensor_various_types())
def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
...@@ -594,19 +623,22 @@ class TestToPil: ...@@ -594,19 +623,22 @@ class TestToPil:
def test_1_channel_float_tensor_to_pil_image(self):
    """A 1xHxW float tensor converts to a PIL image in 'F' (32-bit float) mode."""
    tensor = torch.Tensor(1, 4, 4).uniform_()
    pil_img = transforms.ToPILImage(mode="F")(tensor)
    assert pil_img.mode == "F"
    # The conversion must agree pixel-for-pixel with PIL's own float-mode construction.
    reference = Image.fromarray(tensor.squeeze(0).numpy(), mode="F")
    torch.testing.assert_close(np.array(reference), np.array(pil_img))
@pytest.mark.parametrize('with_mode', [False, True]) @pytest.mark.parametrize("with_mode", [False, True])
@pytest.mark.parametrize('img_data, expected_mode', [ @pytest.mark.parametrize(
(torch.Tensor(4, 4, 1).uniform_().numpy(), 'F'), "img_data, expected_mode",
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), 'L'), [
(torch.ShortTensor(4, 4, 1).random_().numpy(), 'I;16'), (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
(torch.IntTensor(4, 4, 1).random_().numpy(), 'I'), (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
]) (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
],
)
def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data) img = transform(img_data)
...@@ -615,13 +647,13 @@ class TestToPil: ...@@ -615,13 +647,13 @@ class TestToPil:
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
@pytest.mark.parametrize('expected_mode', [None, 'LA']) @pytest.mark.parametrize("expected_mode", [None, "LA"])
def test_2_channel_ndarray_to_pil_image(self, expected_mode): def test_2_channel_ndarray_to_pil_image(self, expected_mode):
img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'LA' # default should assume LA assert img.mode == "LA" # default should assume LA
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -635,19 +667,19 @@ class TestToPil: ...@@ -635,19 +667,19 @@ class TestToPil:
# should raise if we try a mode for 4 or 1 or 3 channel images # should raise if we try a mode for 4 or 1 or 3 channel images
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode="RGBA")(img_data)
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode="RGB")(img_data)
@pytest.mark.parametrize('expected_mode', [None, 'LA']) @pytest.mark.parametrize("expected_mode", [None, "LA"])
def test_2_channel_tensor_to_pil_image(self, expected_mode): def test_2_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(2, 4, 4).uniform_() img_data = torch.Tensor(2, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255) expected_output = img_data.mul(255).int().float().div(255)
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'LA' # default should assume LA assert img.mode == "LA" # default should assume LA
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -661,14 +693,14 @@ class TestToPil: ...@@ -661,14 +693,14 @@ class TestToPil:
# should raise if we try a mode for 4 or 1 or 3 channel images # should raise if we try a mode for 4 or 1 or 3 channel images
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode="RGBA")(img_data)
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode="RGB")(img_data)
@pytest.mark.parametrize('with_mode', [False, True]) @pytest.mark.parametrize("with_mode", [False, True])
@pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_2d_tensor_various_types()) @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_2d_tensor_various_types())
def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
...@@ -677,27 +709,30 @@ class TestToPil: ...@@ -677,27 +709,30 @@ class TestToPil:
assert img.mode == expected_mode assert img.mode == expected_mode
torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0]) torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0])
@pytest.mark.parametrize("with_mode", [False, True])
@pytest.mark.parametrize(
    "img_data, expected_mode",
    [
        (torch.Tensor(4, 4).uniform_().numpy(), "F"),
        (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
        (torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
        (torch.IntTensor(4, 4).random_().numpy(), "I"),
    ],
)
def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
    """2D ndarrays of several dtypes convert to PIL images with the dtype-appropriate mode."""
    if with_mode:
        to_pil = transforms.ToPILImage(mode=expected_mode)
    else:
        # With no explicit mode, ToPILImage must infer the mode from the array dtype.
        to_pil = transforms.ToPILImage()
    img = to_pil(img_data)
    assert img.mode == expected_mode
    # Pixel values must survive the conversion unchanged.
    np.testing.assert_allclose(img_data, img)
@pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
def test_3_channel_tensor_to_pil_image(self, expected_mode): def test_3_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(3, 4, 4).uniform_() img_data = torch.Tensor(3, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255) expected_output = img_data.mul(255).int().float().div(255)
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGB' # default should assume RGB assert img.mode == "RGB" # default should assume RGB
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -710,22 +745,22 @@ class TestToPil: ...@@ -710,22 +745,22 @@ class TestToPil:
error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
# should raise if we try a mode for 4 or 1 or 2 channel images # should raise if we try a mode for 4 or 1 or 2 channel images
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode="RGBA")(img_data)
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='LA')(img_data) transforms.ToPILImage(mode="LA")(img_data)
with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_()) transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())
@pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
def test_3_channel_ndarray_to_pil_image(self, expected_mode): def test_3_channel_ndarray_to_pil_image(self, expected_mode):
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGB' # default should assume RGB assert img.mode == "RGB" # default should assume RGB
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -742,20 +777,20 @@ class TestToPil: ...@@ -742,20 +777,20 @@ class TestToPil:
error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
# should raise if we try a mode for 4 or 1 or 2 channel images # should raise if we try a mode for 4 or 1 or 2 channel images
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode="RGBA")(img_data)
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=error_message_3d): with pytest.raises(ValueError, match=error_message_3d):
transforms.ToPILImage(mode='LA')(img_data) transforms.ToPILImage(mode="LA")(img_data)
@pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
def test_4_channel_tensor_to_pil_image(self, expected_mode): def test_4_channel_tensor_to_pil_image(self, expected_mode):
img_data = torch.Tensor(4, 4, 4).uniform_() img_data = torch.Tensor(4, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255) expected_output = img_data.mul(255).int().float().div(255)
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGBA' # default should assume RGBA assert img.mode == "RGBA" # default should assume RGBA
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -770,19 +805,19 @@ class TestToPil: ...@@ -770,19 +805,19 @@ class TestToPil:
error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
# should raise if we try a mode for 3 or 1 or 2 channel images # should raise if we try a mode for 3 or 1 or 2 channel images
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode="RGB")(img_data)
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='LA')(img_data) transforms.ToPILImage(mode="LA")(img_data)
@pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
def test_4_channel_ndarray_to_pil_image(self, expected_mode): def test_4_channel_ndarray_to_pil_image(self, expected_mode):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
if expected_mode is None: if expected_mode is None:
img = transforms.ToPILImage()(img_data) img = transforms.ToPILImage()(img_data)
assert img.mode == 'RGBA' # default should assume RGBA assert img.mode == "RGBA" # default should assume RGBA
else: else:
img = transforms.ToPILImage(mode=expected_mode)(img_data) img = transforms.ToPILImage(mode=expected_mode)(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
...@@ -796,15 +831,15 @@ class TestToPil: ...@@ -796,15 +831,15 @@ class TestToPil:
error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
# should raise if we try a mode for 3 or 1 or 2 channel images # should raise if we try a mode for 3 or 1 or 2 channel images
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode="RGB")(img_data)
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode="P")(img_data)
with pytest.raises(ValueError, match=error_message_4d): with pytest.raises(ValueError, match=error_message_4d):
transforms.ToPILImage(mode='LA')(img_data) transforms.ToPILImage(mode="LA")(img_data)
def test_ndarray_bad_types_to_pil_image(self): def test_ndarray_bad_types_to_pil_image(self):
trans = transforms.ToPILImage() trans = transforms.ToPILImage()
reg_msg = r'Input type \w+ is not supported' reg_msg = r"Input type \w+ is not supported"
with pytest.raises(TypeError, match=reg_msg): with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.int64)) trans(np.ones([4, 4, 1], np.int64))
with pytest.raises(TypeError, match=reg_msg): with pytest.raises(TypeError, match=reg_msg):
...@@ -814,15 +849,15 @@ class TestToPil: ...@@ -814,15 +849,15 @@ class TestToPil:
with pytest.raises(TypeError, match=reg_msg): with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.float64)) trans(np.ones([4, 4, 1], np.float64))
with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(np.ones([1, 4, 4, 3])) transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'): with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
transforms.ToPILImage()(np.ones([4, 4, 6])) transforms.ToPILImage()(np.ones([4, 4, 6]))
def test_tensor_bad_types_to_pil_image(self):
    """ToPILImage rejects tensors with the wrong rank or with too many channels."""
    trans = transforms.ToPILImage()
    # A 4D tensor is neither a 2D nor a 3D picture.
    with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
        trans(torch.ones(1, 3, 4, 4))
    # More than four channels cannot be mapped to any supported PIL mode.
    with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
        trans(torch.ones(6, 4, 4))
...@@ -830,7 +865,7 @@ def test_adjust_brightness(): ...@@ -830,7 +865,7 @@ def test_adjust_brightness():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
# test 0 # test 0
y_pil = F.adjust_brightness(x_pil, 1) y_pil = F.adjust_brightness(x_pil, 1)
...@@ -856,7 +891,7 @@ def test_adjust_contrast(): ...@@ -856,7 +891,7 @@ def test_adjust_contrast():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
# test 0 # test 0
y_pil = F.adjust_contrast(x_pil, 1) y_pil = F.adjust_contrast(x_pil, 1)
...@@ -878,12 +913,12 @@ def test_adjust_contrast(): ...@@ -878,12 +913,12 @@ def test_adjust_contrast():
torch.testing.assert_close(y_np, y_ans) torch.testing.assert_close(y_np, y_ans)
@pytest.mark.skipif(Image.__version__ >= '7', reason="Temporarily disabled") @pytest.mark.skipif(Image.__version__ >= "7", reason="Temporarily disabled")
def test_adjust_saturation(): def test_adjust_saturation():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
# test 0 # test 0
y_pil = F.adjust_saturation(x_pil, 1) y_pil = F.adjust_saturation(x_pil, 1)
...@@ -909,7 +944,7 @@ def test_adjust_hue(): ...@@ -909,7 +944,7 @@ def test_adjust_hue():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
with pytest.raises(ValueError): with pytest.raises(ValueError):
F.adjust_hue(x_pil, -0.7) F.adjust_hue(x_pil, -0.7)
...@@ -940,11 +975,58 @@ def test_adjust_hue(): ...@@ -940,11 +975,58 @@ def test_adjust_hue():
def test_adjust_sharpness(): def test_adjust_sharpness():
x_shape = [4, 4, 3] x_shape = [4, 4, 3]
x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, x_data = [
0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, 75,
111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] 121,
114,
105,
97,
107,
105,
32,
66,
111,
117,
114,
99,
104,
97,
0,
0,
65,
108,
101,
120,
97,
110,
100,
101,
114,
32,
86,
114,
121,
110,
105,
111,
116,
105,
115,
0,
0,
73,
32,
108,
111,
118,
101,
32,
121,
111,
117,
]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
# test 0 # test 0
y_pil = F.adjust_sharpness(x_pil, 1) y_pil = F.adjust_sharpness(x_pil, 1)
...@@ -954,18 +1036,112 @@ def test_adjust_sharpness(): ...@@ -954,18 +1036,112 @@ def test_adjust_sharpness():
# test 1 # test 1
y_pil = F.adjust_sharpness(x_pil, 0.5) y_pil = F.adjust_sharpness(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, y_ans = [
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, 75,
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] 121,
114,
105,
97,
107,
105,
32,
66,
111,
117,
114,
99,
104,
97,
30,
30,
74,
103,
96,
114,
97,
110,
100,
101,
114,
32,
81,
103,
108,
102,
101,
107,
116,
105,
115,
0,
0,
73,
32,
108,
111,
118,
101,
32,
121,
111,
117,
]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
torch.testing.assert_close(y_np, y_ans) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_sharpness(x_pil, 2) y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, y_ans = [
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, 75,
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] 121,
114,
105,
97,
107,
105,
32,
66,
111,
117,
114,
99,
104,
97,
0,
0,
46,
118,
111,
132,
97,
110,
100,
101,
114,
32,
95,
135,
146,
126,
112,
119,
116,
105,
115,
0,
0,
73,
32,
108,
111,
118,
101,
32,
121,
111,
117,
]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
torch.testing.assert_close(y_np, y_ans) torch.testing.assert_close(y_np, y_ans)
...@@ -973,7 +1149,7 @@ def test_adjust_sharpness(): ...@@ -973,7 +1149,7 @@ def test_adjust_sharpness():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_th = torch.tensor(x_np.transpose(2, 0, 1)) x_th = torch.tensor(x_np.transpose(2, 0, 1))
y_pil = F.adjust_sharpness(x_pil, 2) y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil).transpose(2, 0, 1) y_np = np.array(y_pil).transpose(2, 0, 1)
...@@ -985,7 +1161,7 @@ def test_adjust_gamma(): ...@@ -985,7 +1161,7 @@ def test_adjust_gamma():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
# test 0 # test 0
y_pil = F.adjust_gamma(x_pil, 1) y_pil = F.adjust_gamma(x_pil, 1)
...@@ -1011,15 +1187,15 @@ def test_adjusts_L_mode(): ...@@ -1011,15 +1187,15 @@ def test_adjusts_L_mode():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_rgb = Image.fromarray(x_np, mode='RGB') x_rgb = Image.fromarray(x_np, mode="RGB")
x_l = x_rgb.convert('L') x_l = x_rgb.convert("L")
assert F.adjust_brightness(x_l, 2).mode == 'L' assert F.adjust_brightness(x_l, 2).mode == "L"
assert F.adjust_saturation(x_l, 2).mode == 'L' assert F.adjust_saturation(x_l, 2).mode == "L"
assert F.adjust_contrast(x_l, 2).mode == 'L' assert F.adjust_contrast(x_l, 2).mode == "L"
assert F.adjust_hue(x_l, 0.4).mode == 'L' assert F.adjust_hue(x_l, 0.4).mode == "L"
assert F.adjust_sharpness(x_l, 2).mode == 'L' assert F.adjust_sharpness(x_l, 2).mode == "L"
assert F.adjust_gamma(x_l, 0.5).mode == 'L' assert F.adjust_gamma(x_l, 0.5).mode == "L"
def test_rotate(): def test_rotate():
...@@ -1058,7 +1234,7 @@ def test_rotate(): ...@@ -1058,7 +1234,7 @@ def test_rotate():
assert_equal(np.array(result_a), np.array(result_b)) assert_equal(np.array(result_a), np.array(result_b))
@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) @pytest.mark.parametrize("mode", ["L", "RGB", "F"])
def test_rotate_fill(mode): def test_rotate_fill(mode):
img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB") img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")
...@@ -1141,8 +1317,8 @@ def test_to_grayscale(): ...@@ -1141,8 +1317,8 @@ def test_to_grayscale():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert('L') x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2) gray_np = np.array(x_pil_2)
# Test Set: Grayscale an image with desired number of output channels # Test Set: Grayscale an image with desired number of output channels
...@@ -1150,16 +1326,16 @@ def test_to_grayscale(): ...@@ -1150,16 +1326,16 @@ def test_to_grayscale():
trans1 = transforms.Grayscale(num_output_channels=1) trans1 = transforms.Grayscale(num_output_channels=1)
gray_pil_1 = trans1(x_pil) gray_pil_1 = trans1(x_pil)
gray_np_1 = np.array(gray_pil_1) gray_np_1 = np.array(gray_pil_1)
assert gray_pil_1.mode == 'L', 'mode should be L' assert gray_pil_1.mode == "L", "mode should be L"
assert gray_np_1.shape == tuple(x_shape[0:2]), 'should be 1 channel' assert gray_np_1.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_1) assert_equal(gray_np, gray_np_1)
# Case 2: RGB -> 3 channel grayscale # Case 2: RGB -> 3 channel grayscale
trans2 = transforms.Grayscale(num_output_channels=3) trans2 = transforms.Grayscale(num_output_channels=3)
gray_pil_2 = trans2(x_pil) gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == 'RGB', 'mode should be RGB' assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
assert_equal(gray_np, gray_np_2[:, :, 0]) assert_equal(gray_np, gray_np_2[:, :, 0])
...@@ -1168,16 +1344,16 @@ def test_to_grayscale(): ...@@ -1168,16 +1344,16 @@ def test_to_grayscale():
trans3 = transforms.Grayscale(num_output_channels=1) trans3 = transforms.Grayscale(num_output_channels=1)
gray_pil_3 = trans3(x_pil_2) gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == 'L', 'mode should be L' assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Case 4: 1 channel grayscale -> 3 channel grayscale # Case 4: 1 channel grayscale -> 3 channel grayscale
trans4 = transforms.Grayscale(num_output_channels=3) trans4 = transforms.Grayscale(num_output_channels=3)
gray_pil_4 = trans4(x_pil_2) gray_pil_4 = trans4(x_pil_2)
gray_np_4 = np.array(gray_pil_4) gray_np_4 = np.array(gray_pil_4)
assert gray_pil_4.mode == 'RGB', 'mode should be RGB' assert gray_pil_4.mode == "RGB", "mode should be RGB"
assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel' assert gray_np_4.shape == tuple(x_shape), "should be 3 channel"
assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1]) assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
assert_equal(gray_np, gray_np_4[:, :, 0]) assert_equal(gray_np, gray_np_4[:, :, 0])
...@@ -1196,8 +1372,8 @@ def test_random_grayscale(): ...@@ -1196,8 +1372,8 @@ def test_random_grayscale():
random.seed(42) random.seed(42)
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8) x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert('L') x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2) gray_np = np.array(x_pil_2)
num_samples = 250 num_samples = 250
...@@ -1205,9 +1381,11 @@ def test_random_grayscale(): ...@@ -1205,9 +1381,11 @@ def test_random_grayscale():
for _ in range(num_samples): for _ in range(num_samples):
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil) gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \ if (
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \ np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
np.array_equal(gray_np, gray_np_2[:, :, 0]): and np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
and np.array_equal(gray_np, gray_np_2[:, :, 0])
):
num_gray = num_gray + 1 num_gray = num_gray + 1
p_value = stats.binom_test(num_gray, num_samples, p=0.5) p_value = stats.binom_test(num_gray, num_samples, p=0.5)
...@@ -1219,8 +1397,8 @@ def test_random_grayscale(): ...@@ -1219,8 +1397,8 @@ def test_random_grayscale():
random.seed(42) random.seed(42)
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_np = np_rng.randint(0, 256, x_shape, np.uint8) x_np = np_rng.randint(0, 256, x_shape, np.uint8)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert('L') x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2) gray_np = np.array(x_pil_2)
num_samples = 250 num_samples = 250
...@@ -1239,16 +1417,16 @@ def test_random_grayscale(): ...@@ -1239,16 +1417,16 @@ def test_random_grayscale():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert('L') x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2) gray_np = np.array(x_pil_2)
# Case 3a: RGB -> 3 channel grayscale (grayscaled) # Case 3a: RGB -> 3 channel grayscale (grayscaled)
trans2 = transforms.RandomGrayscale(p=1.0) trans2 = transforms.RandomGrayscale(p=1.0)
gray_pil_2 = trans2(x_pil) gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == 'RGB', 'mode should be RGB' assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
assert_equal(gray_np, gray_np_2[:, :, 0]) assert_equal(gray_np, gray_np_2[:, :, 0])
...@@ -1257,31 +1435,31 @@ def test_random_grayscale(): ...@@ -1257,31 +1435,31 @@ def test_random_grayscale():
trans2 = transforms.RandomGrayscale(p=0.0) trans2 = transforms.RandomGrayscale(p=0.0)
gray_pil_2 = trans2(x_pil) gray_pil_2 = trans2(x_pil)
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
assert gray_pil_2.mode == 'RGB', 'mode should be RGB' assert gray_pil_2.mode == "RGB", "mode should be RGB"
assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
assert_equal(x_np, gray_np_2) assert_equal(x_np, gray_np_2)
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
trans3 = transforms.RandomGrayscale(p=1.0) trans3 = transforms.RandomGrayscale(p=1.0)
gray_pil_3 = trans3(x_pil_2) gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == 'L', 'mode should be L' assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
trans3 = transforms.RandomGrayscale(p=0.0) trans3 = transforms.RandomGrayscale(p=0.0)
gray_pil_3 = trans3(x_pil_2) gray_pil_3 = trans3(x_pil_2)
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
assert gray_pil_3.mode == 'L', 'mode should be L' assert gray_pil_3.mode == "L", "mode should be L"
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Checking if RandomGrayscale can be printed as string # Checking if RandomGrayscale can be printed as string
trans3.__repr__() trans3.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_apply(): def test_random_apply():
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -1290,7 +1468,8 @@ def test_random_apply(): ...@@ -1290,7 +1468,8 @@ def test_random_apply():
transforms.RandomRotation((-45, 45)), transforms.RandomRotation((-45, 45)),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(), transforms.RandomVerticalFlip(),
], p=0.75 ],
p=0.75,
) )
img = transforms.ToPILImage()(torch.rand(3, 10, 10)) img = transforms.ToPILImage()(torch.rand(3, 10, 10))
num_samples = 250 num_samples = 250
...@@ -1308,17 +1487,12 @@ def test_random_apply(): ...@@ -1308,17 +1487,12 @@ def test_random_apply():
random_apply_transform.__repr__() random_apply_transform.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_choice(): def test_random_choice():
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
random_choice_transform = transforms.RandomChoice( random_choice_transform = transforms.RandomChoice(
[ [transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3]
transforms.Resize(15),
transforms.Resize(20),
transforms.CenterCrop(10)
],
[1 / 3, 1 / 3, 1 / 3]
) )
img = transforms.ToPILImage()(torch.rand(3, 25, 25)) img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250 num_samples = 250
...@@ -1346,16 +1520,11 @@ def test_random_choice(): ...@@ -1346,16 +1520,11 @@ def test_random_choice():
random_choice_transform.__repr__() random_choice_transform.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_order(): def test_random_order():
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
random_order_transform = transforms.RandomOrder( random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)])
[
transforms.Resize(20),
transforms.CenterCrop(10)
]
)
img = transforms.ToPILImage()(torch.rand(3, 25, 25)) img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250 num_samples = 250
num_normal_order = 0 num_normal_order = 0
...@@ -1381,10 +1550,10 @@ def test_linear_transformation(): ...@@ -1381,10 +1550,10 @@ def test_linear_transformation():
sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0) sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
u, s, _ = np.linalg.svd(sigma.numpy()) u, s, _ = np.linalg.svd(sigma.numpy())
zca_epsilon = 1e-10 # avoid division by 0 zca_epsilon = 1e-10 # avoid division by 0
d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon))) d = torch.Tensor(np.diag(1.0 / np.sqrt(s + zca_epsilon)))
u = torch.Tensor(u) u = torch.Tensor(u)
principal_components = torch.mm(torch.mm(u, d), u.t()) principal_components = torch.mm(torch.mm(u, d), u.t())
mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0)) mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0)
# initialize whitening matrix # initialize whitening matrix
whitening = transforms.LinearTransformation(principal_components, mean_vector) whitening = transforms.LinearTransformation(principal_components, mean_vector)
# estimate covariance and mean using weak law of large number # estimate covariance and mean using weak law of large number
...@@ -1397,16 +1566,18 @@ def test_linear_transformation(): ...@@ -1397,16 +1566,18 @@ def test_linear_transformation():
cov += np.dot(xwhite, xwhite.T) / num_features cov += np.dot(xwhite, xwhite.T) / num_features
mean += np.sum(xwhite) / num_features mean += np.sum(xwhite) / num_features
# if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
torch.testing.assert_close(cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, torch.testing.assert_close(
msg="cov not close to 1") cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, msg="cov not close to 1"
torch.testing.assert_close(mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, )
msg="mean not close to 0") torch.testing.assert_close(
mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, msg="mean not close to 0"
)
# Checking if LinearTransformation can be printed as string # Checking if LinearTransformation can be printed as string
whitening.__repr__() whitening.__repr__()
@pytest.mark.parametrize('dtype', int_dtypes()) @pytest.mark.parametrize("dtype", int_dtypes())
def test_max_value(dtype): def test_max_value(dtype):
assert F_t._max_value(dtype) == torch.iinfo(dtype).max assert F_t._max_value(dtype) == torch.iinfo(dtype).max
...@@ -1416,8 +1587,8 @@ def test_max_value(dtype): ...@@ -1416,8 +1587,8 @@ def test_max_value(dtype):
# self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max) # self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)
@pytest.mark.parametrize('should_vflip', [True, False]) @pytest.mark.parametrize("should_vflip", [True, False])
@pytest.mark.parametrize('single_dim', [True, False]) @pytest.mark.parametrize("single_dim", [True, False])
def test_ten_crop(should_vflip, single_dim): def test_ten_crop(should_vflip, single_dim):
to_pil_image = transforms.ToPILImage() to_pil_image = transforms.ToPILImage()
h = random.randint(5, 25) h = random.randint(5, 25)
...@@ -1427,12 +1598,10 @@ def test_ten_crop(should_vflip, single_dim): ...@@ -1427,12 +1598,10 @@ def test_ten_crop(should_vflip, single_dim):
if single_dim: if single_dim:
crop_h = min(crop_h, crop_w) crop_h = min(crop_h, crop_w)
crop_w = crop_h crop_w = crop_h
transform = transforms.TenCrop(crop_h, transform = transforms.TenCrop(crop_h, vertical_flip=should_vflip)
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop(crop_h) five_crop = transforms.FiveCrop(crop_h)
else: else:
transform = transforms.TenCrop((crop_h, crop_w), transform = transforms.TenCrop((crop_h, crop_w), vertical_flip=should_vflip)
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop((crop_h, crop_w)) five_crop = transforms.FiveCrop((crop_h, crop_w))
img = to_pil_image(torch.FloatTensor(3, h, w).uniform_()) img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
...@@ -1454,7 +1623,7 @@ def test_ten_crop(should_vflip, single_dim): ...@@ -1454,7 +1623,7 @@ def test_ten_crop(should_vflip, single_dim):
assert results == expected_output assert results == expected_output
@pytest.mark.parametrize('single_dim', [True, False]) @pytest.mark.parametrize("single_dim", [True, False])
def test_five_crop(single_dim): def test_five_crop(single_dim):
to_pil_image = transforms.ToPILImage() to_pil_image = transforms.ToPILImage()
h = random.randint(5, 25) h = random.randint(5, 25)
...@@ -1478,17 +1647,17 @@ def test_five_crop(single_dim): ...@@ -1478,17 +1647,17 @@ def test_five_crop(single_dim):
to_pil_image = transforms.ToPILImage() to_pil_image = transforms.ToPILImage()
tl = to_pil_image(img[:, 0:crop_h, 0:crop_w]) tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
tr = to_pil_image(img[:, 0:crop_h, w - crop_w:]) tr = to_pil_image(img[:, 0:crop_h, w - crop_w :])
bl = to_pil_image(img[:, h - crop_h:, 0:crop_w]) bl = to_pil_image(img[:, h - crop_h :, 0:crop_w])
br = to_pil_image(img[:, h - crop_h:, w - crop_w:]) br = to_pil_image(img[:, h - crop_h :, w - crop_w :])
center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img)) center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
expected_output = (tl, tr, bl, br, center) expected_output = (tl, tr, bl, br, center)
assert results == expected_output assert results == expected_output
@pytest.mark.parametrize('policy', transforms.AutoAugmentPolicy) @pytest.mark.parametrize("policy", transforms.AutoAugmentPolicy)
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('grayscale', [True, False]) @pytest.mark.parametrize("grayscale", [True, False])
def test_autoaugment(policy, fill, grayscale): def test_autoaugment(policy, fill, grayscale):
random.seed(42) random.seed(42)
img = Image.open(GRACE_HOPPER) img = Image.open(GRACE_HOPPER)
...@@ -1500,10 +1669,10 @@ def test_autoaugment(policy, fill, grayscale): ...@@ -1500,10 +1669,10 @@ def test_autoaugment(policy, fill, grayscale):
transform.__repr__() transform.__repr__()
@pytest.mark.parametrize('num_ops', [1, 2, 3]) @pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11]) @pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('grayscale', [True, False]) @pytest.mark.parametrize("grayscale", [True, False])
def test_randaugment(num_ops, magnitude, fill, grayscale): def test_randaugment(num_ops, magnitude, fill, grayscale):
random.seed(42) random.seed(42)
img = Image.open(GRACE_HOPPER) img = Image.open(GRACE_HOPPER)
...@@ -1515,9 +1684,9 @@ def test_randaugment(num_ops, magnitude, fill, grayscale): ...@@ -1515,9 +1684,9 @@ def test_randaugment(num_ops, magnitude, fill, grayscale):
transform.__repr__() transform.__repr__()
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) @pytest.mark.parametrize("num_magnitude_bins", [10, 13, 30])
@pytest.mark.parametrize('grayscale', [True, False]) @pytest.mark.parametrize("grayscale", [True, False])
def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale): def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
random.seed(42) random.seed(42)
img = Image.open(GRACE_HOPPER) img = Image.open(GRACE_HOPPER)
...@@ -1535,37 +1704,41 @@ def test_random_crop(): ...@@ -1535,37 +1704,41 @@ def test_random_crop():
oheight = random.randint(5, (height - 2) / 2) * 2 oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.RandomCrop((oheight, owidth)), transforms.ToPILImage(),
transforms.ToTensor(), transforms.RandomCrop((oheight, owidth)),
])(img) transforms.ToTensor(),
]
)(img)
assert result.size(1) == oheight assert result.size(1) == oheight
assert result.size(2) == owidth assert result.size(2) == owidth
padding = random.randint(1, 20) padding = random.randint(1, 20)
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.RandomCrop((oheight, owidth), padding=padding), transforms.ToPILImage(),
transforms.ToTensor(), transforms.RandomCrop((oheight, owidth), padding=padding),
])(img) transforms.ToTensor(),
]
)(img)
assert result.size(1) == oheight assert result.size(1) == oheight
assert result.size(2) == owidth assert result.size(2) == owidth
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.ToTensor()]
transforms.RandomCrop((height, width)), )(img)
transforms.ToTensor()
])(img)
assert result.size(1) == height assert result.size(1) == height
assert result.size(2) == width assert result.size(2) == width
torch.testing.assert_close(result, img) torch.testing.assert_close(result, img)
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True), transforms.ToPILImage(),
transforms.ToTensor(), transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
])(img) transforms.ToTensor(),
]
)(img)
assert result.size(1) == height + 1 assert result.size(1) == height + 1
assert result.size(2) == width + 1 assert result.size(2) == width + 1
...@@ -1584,41 +1757,47 @@ def test_center_crop(): ...@@ -1584,41 +1757,47 @@ def test_center_crop():
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
oh1 = (height - oheight) // 2 oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2 ow1 = (width - owidth) // 2
imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth] imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth]
imgnarrow.fill_(0) imgnarrow.fill_(0)
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.CenterCrop((oheight, owidth)), transforms.ToPILImage(),
transforms.ToTensor(), transforms.CenterCrop((oheight, owidth)),
])(img) transforms.ToTensor(),
]
)(img)
assert result.sum() == 0 assert result.sum() == 0
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.CenterCrop((oheight, owidth)), transforms.ToPILImage(),
transforms.ToTensor(), transforms.CenterCrop((oheight, owidth)),
])(img) transforms.ToTensor(),
]
)(img)
sum1 = result.sum() sum1 = result.sum()
assert sum1 > 1 assert sum1 > 1
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = transforms.Compose([ result = transforms.Compose(
transforms.ToPILImage(), [
transforms.CenterCrop((oheight, owidth)), transforms.ToPILImage(),
transforms.ToTensor(), transforms.CenterCrop((oheight, owidth)),
])(img) transforms.ToTensor(),
]
)(img)
sum2 = result.sum() sum2 = result.sum()
assert sum2 > 0 assert sum2 > 0
assert sum2 > sum1 assert sum2 > sum1
@pytest.mark.parametrize('odd_image_size', (True, False)) @pytest.mark.parametrize("odd_image_size", (True, False))
@pytest.mark.parametrize('delta', (1, 3, 5)) @pytest.mark.parametrize("delta", (1, 3, 5))
@pytest.mark.parametrize('delta_width', (-2, -1, 0, 1, 2)) @pytest.mark.parametrize("delta_width", (-2, -1, 0, 1, 2))
@pytest.mark.parametrize('delta_height', (-2, -1, 0, 1, 2)) @pytest.mark.parametrize("delta_height", (-2, -1, 0, 1, 2))
def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
""" Tests when center crop size is larger than image size, along any dimension""" """Tests when center crop size is larger than image size, along any dimension"""
# Since height is independent of width, we can ignore images with odd height and even width and vice-versa. # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2) input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
...@@ -1632,10 +1811,8 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): ...@@ -1632,10 +1811,8 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width) crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)
# Test both transforms, one with PIL input and one with tensor # Test both transforms, one with PIL input and one with tensor
output_pil = transforms.Compose([ output_pil = transforms.Compose(
transforms.ToPILImage(), [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.ToTensor()],
transforms.CenterCrop(crop_size),
transforms.ToTensor()],
)(img) )(img)
assert output_pil.size()[1:3] == crop_size assert output_pil.size()[1:3] == crop_size
...@@ -1660,14 +1837,14 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): ...@@ -1660,14 +1837,14 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
output_center = output_pil[ output_center = output_pil[
:, :,
crop_center_tl[0]:crop_center_tl[0] + center_size[0], crop_center_tl[0] : crop_center_tl[0] + center_size[0],
crop_center_tl[1]:crop_center_tl[1] + center_size[1] crop_center_tl[1] : crop_center_tl[1] + center_size[1],
] ]
img_center = img[ img_center = img[
:, :,
input_center_tl[0]:input_center_tl[0] + center_size[0], input_center_tl[0] : input_center_tl[0] + center_size[0],
input_center_tl[1]:input_center_tl[1] + center_size[1] input_center_tl[1] : input_center_tl[1] + center_size[1],
] ]
assert_equal(output_center, img_center) assert_equal(output_center, img_center)
...@@ -1679,8 +1856,8 @@ def test_color_jitter(): ...@@ -1679,8 +1856,8 @@ def test_color_jitter():
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB') x_pil = Image.fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert('L') x_pil_2 = x_pil.convert("L")
for _ in range(10): for _ in range(10):
y_pil = color_jitter(x_pil) y_pil = color_jitter(x_pil)
...@@ -1697,18 +1874,32 @@ def test_color_jitter(): ...@@ -1697,18 +1874,32 @@ def test_color_jitter():
def test_random_erasing(): def test_random_erasing():
img = torch.ones(3, 128, 128) img = torch.ones(3, 128, 128)
t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0))
y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) y, x, h, w, v = t.get_params(
img,
t.scale,
t.ratio,
[
t.value,
],
)
aspect_ratio = h / w aspect_ratio = h / w
# Add some tolerance due to the rounding and int conversion used in the transform # Add some tolerance due to the rounding and int conversion used in the transform
tol = 0.05 tol = 0.05
assert (1 / 3 - tol <= aspect_ratio <= 3 + tol) assert 1 / 3 - tol <= aspect_ratio <= 3 + tol
aspect_ratios = [] aspect_ratios = []
random.seed(42) random.seed(42)
trial = 1000 trial = 1000
for _ in range(trial): for _ in range(trial):
y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) y, x, h, w, v = t.get_params(
img,
t.scale,
t.ratio,
[
t.value,
],
)
aspect_ratios.append(h / w) aspect_ratios.append(h / w)
count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1])
...@@ -1735,11 +1926,11 @@ def test_random_rotation(): ...@@ -1735,11 +1926,11 @@ def test_random_rotation():
t = transforms.RandomRotation(10) t = transforms.RandomRotation(10)
angle = t.get_params(t.degrees) angle = t.get_params(t.degrees)
assert (angle > -10 and angle < 10) assert angle > -10 and angle < 10
t = transforms.RandomRotation((-10, 10)) t = transforms.RandomRotation((-10, 10))
angle = t.get_params(t.degrees) angle = t.get_params(t.degrees)
assert (-10 < angle < 10) assert -10 < angle < 10
# Checking if RandomRotation can be printed as string # Checking if RandomRotation can be printed as string
t.__repr__() t.__repr__()
...@@ -1775,11 +1966,12 @@ def test_randomperspective(): ...@@ -1775,11 +1966,12 @@ def test_randomperspective():
tr_img = F.to_tensor(tr_img) tr_img = F.to_tensor(tr_img)
assert img.size[0] == width assert img.size[0] == width
assert img.size[1] == height assert img.size[1] == height
assert (torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > assert torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > torch.nn.functional.mse_loss(
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))) tr_img2, F.to_tensor(img)
)
@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) @pytest.mark.parametrize("mode", ["L", "RGB", "F"])
def test_randomperspective_fill(mode): def test_randomperspective_fill(mode):
# assert fill being either a Sequence or a Number # assert fill being either a Sequence or a Number
...@@ -1819,7 +2011,7 @@ def test_randomperspective_fill(mode): ...@@ -1819,7 +2011,7 @@ def test_randomperspective_fill(mode):
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands)) F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_vertical_flip(): def test_random_vertical_flip():
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -1852,7 +2044,7 @@ def test_random_vertical_flip(): ...@@ -1852,7 +2044,7 @@ def test_random_vertical_flip():
transforms.RandomVerticalFlip().__repr__() transforms.RandomVerticalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_horizontal_flip(): def test_random_horizontal_flip():
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -1885,10 +2077,10 @@ def test_random_horizontal_flip(): ...@@ -1885,10 +2077,10 @@ def test_random_horizontal_flip():
transforms.RandomHorizontalFlip().__repr__() transforms.RandomHorizontalFlip().__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_normalize(): def test_normalize():
def samples_from_standard_normal(tensor): def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
return p_value > 0.0001 return p_value > 0.0001
random_state = random.getstate() random_state = random.getstate()
...@@ -1910,8 +2102,8 @@ def test_normalize(): ...@@ -1910,8 +2102,8 @@ def test_normalize():
assert_equal(tensor, tensor_inplace) assert_equal(tensor, tensor_inplace)
@pytest.mark.parametrize('dtype1', [torch.float32, torch.float64]) @pytest.mark.parametrize("dtype1", [torch.float32, torch.float64])
@pytest.mark.parametrize('dtype2', [torch.int64, torch.float32, torch.float64]) @pytest.mark.parametrize("dtype2", [torch.int64, torch.float32, torch.float64])
def test_normalize_different_dtype(dtype1, dtype2): def test_normalize_different_dtype(dtype1, dtype2):
img = torch.rand(3, 10, 10, dtype=dtype1) img = torch.rand(3, 10, 10, dtype=dtype1)
mean = torch.tensor([1, 2, 3], dtype=dtype2) mean = torch.tensor([1, 2, 3], dtype=dtype2)
...@@ -1932,15 +2124,15 @@ def test_normalize_3d_tensor(): ...@@ -1932,15 +2124,15 @@ def test_normalize_3d_tensor():
mean_unsqueezed = mean.view(-1, 1, 1) mean_unsqueezed = mean.view(-1, 1, 1)
std_unsqueezed = std.view(-1, 1, 1) std_unsqueezed = std.view(-1, 1, 1)
result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed) result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
result2 = F.normalize(img, mean_unsqueezed.repeat(1, img_size, img_size), result2 = F.normalize(
std_unsqueezed.repeat(1, img_size, img_size)) img, mean_unsqueezed.repeat(1, img_size, img_size), std_unsqueezed.repeat(1, img_size, img_size)
)
torch.testing.assert_close(target, result1) torch.testing.assert_close(target, result1)
torch.testing.assert_close(target, result2) torch.testing.assert_close(target, result2)
class TestAffine: class TestAffine:
@pytest.fixture(scope="class")
@pytest.fixture(scope='class')
def input_img(self): def input_img(self):
input_img = np.zeros((40, 40, 3), dtype=np.uint8) input_img = np.zeros((40, 40, 3), dtype=np.uint8)
for pt in [(16, 16), (20, 16), (20, 20)]: for pt in [(16, 16), (20, 16), (20, 20)]:
...@@ -1953,7 +2145,7 @@ class TestAffine: ...@@ -1953,7 +2145,7 @@ class TestAffine:
with pytest.raises(TypeError, match=r"Argument translate should be a sequence"): with pytest.raises(TypeError, match=r"Argument translate should be a sequence"):
F.affine(input_img, 10, translate=0, scale=1, shear=1) F.affine(input_img, 10, translate=0, scale=1, shear=1)
@pytest.fixture(scope='class') @pytest.fixture(scope="class")
def pil_image(self, input_img): def pil_image(self, input_img):
return F.to_pil_image(input_img) return F.to_pil_image(input_img)
...@@ -1974,33 +2166,29 @@ class TestAffine: ...@@ -1974,33 +2166,29 @@ class TestAffine:
rot = a_rad rot = a_rad
# 1) Check transformation matrix: # 1) Check transformation matrix:
C = np.array([[1, 0, cx], C = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
[0, 1, cy], T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
[0, 0, 1]])
T = np.array([[1, 0, tx],
[0, 1, ty],
[0, 0, 1]])
Cinv = np.linalg.inv(C) Cinv = np.linalg.inv(C)
RS = np.array( RS = np.array(
[[scale * math.cos(rot), -scale * math.sin(rot), 0], [
[scale * math.sin(rot), scale * math.cos(rot), 0], [scale * math.cos(rot), -scale * math.sin(rot), 0],
[0, 0, 1]]) [scale * math.sin(rot), scale * math.cos(rot), 0],
[0, 0, 1],
]
)
SHx = np.array([[1, -math.tan(sx), 0], SHx = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
[0, 1, 0],
[0, 0, 1]])
SHy = np.array([[1, 0, 0], SHy = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
[-math.tan(sy), 1, 0],
[0, 0, 1]])
RSS = np.matmul(RS, np.matmul(SHy, SHx)) RSS = np.matmul(RS, np.matmul(SHy, SHx))
true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv))) true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))
result_matrix = self._to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=angle, result_matrix = self._to_3x3_inv(
translate=translate, scale=scale, shear=shear)) F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear)
)
assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10 assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
# 2) Perform inverse mapping: # 2) Perform inverse mapping:
true_result = np.zeros((40, 40, 3), dtype=np.uint8) true_result = np.zeros((40, 40, 3), dtype=np.uint8)
...@@ -2022,38 +2210,49 @@ class TestAffine: ...@@ -2022,38 +2210,49 @@ class TestAffine:
np_result = np.array(result) np_result = np.array(result)
n_diff_pixels = np.sum(np_result != true_result) / 3 n_diff_pixels = np.sum(np_result != true_result) / 3
# Accept 3 wrong pixels # Accept 3 wrong pixels
error_msg = ("angle={}, translate={}, scale={}, shear={}\n".format(angle, translate, scale, shear) + error_msg = "angle={}, translate={}, scale={}, shear={}\n".format(
"n diff pixels={}\n".format(n_diff_pixels)) angle, translate, scale, shear
) + "n diff pixels={}\n".format(n_diff_pixels)
assert n_diff_pixels < 3, error_msg assert n_diff_pixels < 3, error_msg
def test_transformation_discrete(self, pil_image, input_img): def test_transformation_discrete(self, pil_image, input_img):
# Test rotation # Test rotation
angle = 45 angle = 45
self._test_transformation(angle=angle, translate=(0, 0), scale=1.0, self._test_transformation(
shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
)
# Test translation # Test translation
translate = [10, 15] translate = [10, 15]
self._test_transformation(angle=0.0, translate=translate, scale=1.0, self._test_transformation(
shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) angle=0.0, translate=translate, scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
)
# Test scale # Test scale
scale = 1.2 scale = 1.2
self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=scale, self._test_transformation(
shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) angle=0.0, translate=(0.0, 0.0), scale=scale, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
)
# Test shear # Test shear
shear = [45.0, 25.0] shear = [45.0, 25.0]
self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=1.0, self._test_transformation(
shear=shear, pil_image=pil_image, input_img=input_img) angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
)
@pytest.mark.parametrize("angle", range(-90, 90, 36)) @pytest.mark.parametrize("angle", range(-90, 90, 36))
@pytest.mark.parametrize("translate", range(-10, 10, 5)) @pytest.mark.parametrize("translate", range(-10, 10, 5))
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
@pytest.mark.parametrize("shear", range(-15, 15, 5)) @pytest.mark.parametrize("shear", range(-15, 15, 5))
def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img): def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img):
self._test_transformation(angle=angle, translate=(translate, translate), scale=scale, self._test_transformation(
shear=(shear, shear), pil_image=pil_image, input_img=input_img) angle=angle,
translate=(translate, translate),
scale=scale,
shear=(shear, shear),
pil_image=pil_image,
input_img=input_img,
)
def test_random_affine(): def test_random_affine():
...@@ -2101,13 +2300,14 @@ def test_random_affine(): ...@@ -2101,13 +2300,14 @@ def test_random_affine():
t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40]) t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
for _ in range(100): for _ in range(100):
angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size)
img_size=img.size)
assert -10 < angle < 10 assert -10 < angle < 10
assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}" assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, "{} vs {}".format(
.format(translations[0], img.size[0] * 0.5)) translations[0], img.size[0] * 0.5
assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, ("{} vs {}" )
.format(translations[1], img.size[1] * 0.5)) assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, "{} vs {}".format(
translations[1], img.size[1] * 0.5
)
assert 0.7 < scale < 1.3 assert 0.7 < scale < 1.3
assert -10 < shear[0] < 10 assert -10 < shear[0] < 10
assert -20 < shear[1] < 40 assert -20 < shear[1] < 40
...@@ -2133,5 +2333,5 @@ def test_random_affine(): ...@@ -2133,5 +2333,5 @@ def test_random_affine():
assert t.interpolation == transforms.InterpolationMode.BILINEAR assert t.interpolation == transforms.InterpolationMode.BILINEAR
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import os import os
import torch from typing import Sequence
from torchvision import transforms as T
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
import numpy as np import numpy as np
import pytest import pytest
import torch
from typing import Sequence
from common_utils import ( from common_utils import (
get_tmp_dir, get_tmp_dir,
int_dtypes, int_dtypes,
...@@ -20,6 +15,9 @@ from common_utils import ( ...@@ -20,6 +15,9 @@ from common_utils import (
cpu_and_gpu, cpu_and_gpu,
assert_equal, assert_equal,
) )
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
...@@ -94,110 +92,137 @@ def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, ...@@ -94,110 +92,137 @@ def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None,
_test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs) _test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
'func,method,fn_kwargs,match_kwargs', [ "func,method,fn_kwargs,match_kwargs",
[
(F.hflip, T.RandomHorizontalFlip, None, {}), (F.hflip, T.RandomHorizontalFlip, None, {}),
(F.vflip, T.RandomVerticalFlip, None, {}), (F.vflip, T.RandomVerticalFlip, None, {}),
(F.invert, T.RandomInvert, None, {}), (F.invert, T.RandomInvert, None, {}),
(F.posterize, T.RandomPosterize, {"bits": 4}, {}), (F.posterize, T.RandomPosterize, {"bits": 4}, {}),
(F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}), (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}),
(F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}), (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}),
(F.autocontrast, T.RandomAutocontrast, None, {'test_exact_match': False, (
'agg_method': 'max', 'tol': (1 + 1e-5), F.autocontrast,
'allowed_percentage_diff': .05}), T.RandomAutocontrast,
(F.equalize, T.RandomEqualize, None, {}) None,
] {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05},
),
(F.equalize, T.RandomEqualize, None, {}),
],
) )
@pytest.mark.parametrize('channels', [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_random(func, method, device, channels, fn_kwargs, match_kwargs): def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
_test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs) _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('channels', [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
class TestColorJitter: class TestColorJitter:
@pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
@pytest.mark.parametrize('brightness', [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
def test_color_jitter_brightness(self, brightness, device, channels): def test_color_jitter_brightness(self, brightness, device, channels):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
meth_kwargs = {"brightness": brightness} meth_kwargs = {"brightness": brightness}
_test_class_op( _test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, T.ColorJitter,
tol=tol, agg_method="max", channels=channels, meth_kwargs=meth_kwargs,
test_exact_match=False,
device=device,
tol=tol,
agg_method="max",
channels=channels,
) )
@pytest.mark.parametrize('contrast', [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]) @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
def test_color_jitter_contrast(self, contrast, device, channels): def test_color_jitter_contrast(self, contrast, device, channels):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
meth_kwargs = {"contrast": contrast} meth_kwargs = {"contrast": contrast}
_test_class_op( _test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, T.ColorJitter,
tol=tol, agg_method="max", channels=channels meth_kwargs=meth_kwargs,
test_exact_match=False,
device=device,
tol=tol,
agg_method="max",
channels=channels,
) )
@pytest.mark.parametrize('saturation', [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]) @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
def test_color_jitter_saturation(self, saturation, device, channels): def test_color_jitter_saturation(self, saturation, device, channels):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
meth_kwargs = {"saturation": saturation} meth_kwargs = {"saturation": saturation}
_test_class_op( _test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, T.ColorJitter,
tol=tol, agg_method="max", channels=channels meth_kwargs=meth_kwargs,
test_exact_match=False,
device=device,
tol=tol,
agg_method="max",
channels=channels,
) )
@pytest.mark.parametrize('hue', [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]) @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
def test_color_jitter_hue(self, hue, device, channels): def test_color_jitter_hue(self, hue, device, channels):
meth_kwargs = {"hue": hue} meth_kwargs = {"hue": hue}
_test_class_op( _test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, T.ColorJitter,
tol=16.1, agg_method="max", channels=channels meth_kwargs=meth_kwargs,
test_exact_match=False,
device=device,
tol=16.1,
agg_method="max",
channels=channels,
) )
def test_color_jitter_all(self, device, channels): def test_color_jitter_all(self, device, channels):
# All 4 parameters together # All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
_test_class_op( _test_class_op(
T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, T.ColorJitter,
tol=12.1, agg_method="max", channels=channels meth_kwargs=meth_kwargs,
test_exact_match=False,
device=device,
tol=12.1,
agg_method="max",
channels=channels,
) )
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('m', ["constant", "edge", "reflect", "symmetric"]) @pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize('mul', [1, -1]) @pytest.mark.parametrize("mul", [1, -1])
def test_pad(m, mul, device): def test_pad(m, mul, device):
fill = 127 if m == "constant" else 0 fill = 127 if m == "constant" else 0
# Test functional.pad (PIL and Tensor) with padding as single int # Test functional.pad (PIL and Tensor) with padding as single int
_test_functional_op( _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m},
device=device
)
# Test functional.pad and transforms.Pad with padding as [int, ] # Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {
_test_op( "padding": [
F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs mul * 2,
) ],
"fill": fill,
"padding_mode": m,
}
_test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
# Test functional.pad and transforms.Pad with padding as list # Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
_test_op( _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple # Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m} fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
_test_op( _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_crop(device): def test_crop(device):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple # Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } meth_kwargs = {
_test_op( "size": (4, 5),
F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs "padding": (4, 4),
) "pad_if_needed": True,
}
_test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
# Test transforms.functional.crop including outside the image area # Test transforms.functional.crop including outside the image area
fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top
...@@ -216,35 +241,43 @@ def test_crop(device): ...@@ -216,35 +241,43 @@ def test_crop(device):
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device) _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('padding_config', [ @pytest.mark.parametrize(
{"padding_mode": "constant", "fill": 0}, "padding_config",
{"padding_mode": "constant", "fill": 10}, [
{"padding_mode": "constant", "fill": 20}, {"padding_mode": "constant", "fill": 0},
{"padding_mode": "edge"}, {"padding_mode": "constant", "fill": 10},
{"padding_mode": "reflect"} {"padding_mode": "constant", "fill": 20},
]) {"padding_mode": "edge"},
@pytest.mark.parametrize('size', [5, [5, ], [6, 6]]) {"padding_mode": "reflect"},
],
)
@pytest.mark.parametrize(
"size",
[
5,
[
5,
],
[6, 6],
],
)
def test_crop_pad(size, padding_config, device): def test_crop_pad(size, padding_config, device):
config = dict(padding_config) config = dict(padding_config)
config["size"] = size config["size"] = size
_test_class_op(T.RandomCrop, device, meth_kwargs=config) _test_class_op(T.RandomCrop, device, meth_kwargs=config)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_center_crop(device, tmpdir): def test_center_crop(device, tmpdir):
fn_kwargs = {"output_size": (4, 5)} fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), } meth_kwargs = {
_test_op( "size": (4, 5),
F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, }
meth_kwargs=meth_kwargs _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
)
fn_kwargs = {"output_size": (5,)} fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5,)} meth_kwargs = {"size": (5,)}
_test_op( _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs,
meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)
# Test torchscript of transforms.CenterCrop with size as int # Test torchscript of transforms.CenterCrop with size as int
f = T.CenterCrop(size=5) f = T.CenterCrop(size=5)
...@@ -252,7 +285,11 @@ def test_center_crop(device, tmpdir): ...@@ -252,7 +285,11 @@ def test_center_crop(device, tmpdir):
scripted_fn(tensor) scripted_fn(tensor)
# Test torchscript of transforms.CenterCrop with size as [int, ] # Test torchscript of transforms.CenterCrop with size as [int, ]
f = T.CenterCrop(size=[5, ]) f = T.CenterCrop(
size=[
5,
]
)
scripted_fn = torch.jit.script(f) scripted_fn = torch.jit.script(f)
scripted_fn(tensor) scripted_fn(tensor)
...@@ -264,16 +301,29 @@ def test_center_crop(device, tmpdir): ...@@ -264,16 +301,29 @@ def test_center_crop(device, tmpdir):
scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt")) scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('fn, method, out_length', [ @pytest.mark.parametrize(
# test_five_crop "fn, method, out_length",
(F.five_crop, T.FiveCrop, 5), [
# test_ten_crop # test_five_crop
(F.ten_crop, T.TenCrop, 10) (F.five_crop, T.FiveCrop, 5),
]) # test_ten_crop
@pytest.mark.parametrize('size', [(5,), [5, ], (4, 5), [4, 5]]) (F.ten_crop, T.TenCrop, 10),
],
)
@pytest.mark.parametrize(
"size",
[
(5,),
[
5,
],
(4, 5),
[4, 5],
],
)
def test_x_crop(fn, method, out_length, size, device): def test_x_crop(fn, method, out_length, size, device):
meth_kwargs = fn_kwargs = {'size': size} meth_kwargs = fn_kwargs = {"size": size}
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
tensor, pil_img = _create_data(height=20, width=20, device=device) tensor, pil_img = _create_data(height=20, width=20, device=device)
...@@ -309,15 +359,19 @@ def test_x_crop(fn, method, out_length, size, device): ...@@ -309,15 +359,19 @@ def test_x_crop(fn, method, out_length, size, device):
assert_equal(transformed_img, transformed_batch[i, ...]) assert_equal(transformed_img, transformed_batch[i, ...])
@pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"]) @pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"])
def test_x_crop_save(method, tmpdir): def test_x_crop_save(method, tmpdir):
fn = getattr(T, method)(size=[5, ]) fn = getattr(T, method)(
size=[
5,
]
)
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method))) scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))
class TestResize: class TestResize:
@pytest.mark.parametrize('size', [32, 34, 35, 36, 38]) @pytest.mark.parametrize("size", [32, 34, 35, 36, 38])
def test_resize_int(self, size): def test_resize_int(self, size):
# TODO: Minimal check for bug-fix, improve this later # TODO: Minimal check for bug-fix, improve this later
x = torch.rand(3, 32, 46) x = torch.rand(3, 32, 46)
...@@ -329,11 +383,21 @@ class TestResize: ...@@ -329,11 +383,21 @@ class TestResize:
assert y.shape[1] == size assert y.shape[1] == size
assert y.shape[2] == int(size * 46 / 32) assert y.shape[2] == int(size * 46 / 32)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
@pytest.mark.parametrize('size', [[32, ], [32, 32], (32, 32), [34, 35]]) @pytest.mark.parametrize(
@pytest.mark.parametrize('max_size', [None, 35, 1000]) "size",
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) [
[
32,
],
[32, 32],
(32, 32),
[34, 35],
],
)
@pytest.mark.parametrize("max_size", [None, 35, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST])
def test_resize_scripted(self, dt, size, max_size, interpolation, device): def test_resize_scripted(self, dt, size, max_size, interpolation, device):
tensor, _ = _create_data(height=34, width=36, device=device) tensor, _ = _create_data(height=34, width=36, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -350,15 +414,33 @@ class TestResize: ...@@ -350,15 +414,33 @@ class TestResize:
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resize_save(self, tmpdir): def test_resize_save(self, tmpdir):
transform = T.Resize(size=[32, ]) transform = T.Resize(
size=[
32,
]
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_resize.pt")) s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
@pytest.mark.parametrize('ratio', [(0.75, 1.333), [0.75, 1.333]]) @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
@pytest.mark.parametrize('size', [(32,), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]) @pytest.mark.parametrize(
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR, BICUBIC]) "size",
[
(32,),
[
44,
],
[
32,
],
[32, 32],
(32, 32),
[44, 55],
],
)
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
def test_resized_crop(self, scale, ratio, size, interpolation, device): def test_resized_crop(self, scale, ratio, size, interpolation, device):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -368,7 +450,11 @@ class TestResize: ...@@ -368,7 +450,11 @@ class TestResize:
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resized_crop_save(self, tmpdir): def test_resized_crop_save(self, tmpdir):
transform = T.RandomResizedCrop(size=[32, ]) transform = T.RandomResizedCrop(
size=[
32,
]
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt")) s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt"))
...@@ -383,61 +469,83 @@ def _test_random_affine_helper(device, **kwargs): ...@@ -383,61 +469,83 @@ def _test_random_affine_helper(device, **kwargs):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_random_affine(device, tmpdir): def test_random_affine(device, tmpdir):
transform = T.RandomAffine(degrees=45.0) transform = T.RandomAffine(degrees=45.0)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_random_affine.pt")) s_transform.save(os.path.join(tmpdir, "t_random_affine.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('shear', [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) @pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
def test_random_affine_shear(device, interpolation, shear): def test_random_affine_shear(device, interpolation, shear):
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear) _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
def test_random_affine_scale(device, interpolation, scale): def test_random_affine_scale(device, interpolation, scale):
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale) _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('translate', [(0.1, 0.2), [0.2, 0.1]]) @pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]])
def test_random_affine_translate(device, interpolation, translate): def test_random_affine_translate(device, interpolation, translate):
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate) _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
def test_random_affine_degrees(device, interpolation, degrees): def test_random_affine_degrees(device, interpolation, degrees):
_test_random_affine_helper(device, degrees=degrees, interpolation=interpolation) _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_random_affine_fill(device, interpolation, fill): def test_random_affine_fill(device, interpolation, fill):
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill) _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('center', [(0, 0), [10, 10], None, (56, 44)]) @pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)])
@pytest.mark.parametrize('expand', [True, False]) @pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_random_rotate(device, center, expand, degrees, interpolation, fill): def test_random_rotate(device, center, expand, degrees, interpolation, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
transform = T.RandomRotation( transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
_test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
...@@ -450,19 +558,27 @@ def test_random_rotate_save(tmpdir): ...@@ -450,19 +558,27 @@ def test_random_rotate_save(tmpdir):
s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt")) s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('distortion_scale', np.linspace(0.1, 1.0, num=20)) @pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20))
@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_random_perspective(device, distortion_scale, interpolation, fill): def test_random_perspective(device, distortion_scale, interpolation, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
transform = T.RandomPerspective( transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill)
distortion_scale=distortion_scale,
interpolation=interpolation,
fill=fill
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
_test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
...@@ -475,23 +591,19 @@ def test_random_perspective_save(tmpdir): ...@@ -475,23 +591,19 @@ def test_random_perspective_save(tmpdir):
s_transform.save(os.path.join(tmpdir, "t_perspective.pt")) s_transform.save(os.path.join(tmpdir, "t_perspective.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('Klass, meth_kwargs', [ @pytest.mark.parametrize(
(T.Grayscale, {"num_output_channels": 1}), "Klass, meth_kwargs",
(T.Grayscale, {"num_output_channels": 3}), [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})],
(T.RandomGrayscale, {}) )
])
def test_to_grayscale(device, Klass, meth_kwargs): def test_to_grayscale(device, Klass, meth_kwargs):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
_test_class_op( _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max")
Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device,
tol=tol, agg_method="max"
)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('in_dtype', int_dtypes() + float_dtypes()) @pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes())
@pytest.mark.parametrize('out_dtype', int_dtypes() + float_dtypes()) @pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes())
def test_convert_image_dtype(device, in_dtype, out_dtype): def test_convert_image_dtype(device, in_dtype, out_dtype):
tensor, _ = _create_data(26, 34, device=device) tensor, _ = _create_data(26, 34, device=device)
batch_tensors = torch.rand(4, 3, 44, 56, device=device) batch_tensors = torch.rand(4, 3, 44, 56, device=device)
...@@ -502,8 +614,9 @@ def test_convert_image_dtype(device, in_dtype, out_dtype): ...@@ -502,8 +614,9 @@ def test_convert_image_dtype(device, in_dtype, out_dtype):
fn = T.ConvertImageDtype(dtype=out_dtype) fn = T.ConvertImageDtype(dtype=out_dtype)
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \ if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or (
(in_dtype == torch.float64 and out_dtype == torch.int64): in_dtype == torch.float64 and out_dtype == torch.int64
):
with pytest.raises(RuntimeError, match=r"cannot be performed safely"): with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
_test_transform_vs_scripted(fn, scripted_fn, in_tensor) _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
with pytest.raises(RuntimeError, match=r"cannot be performed safely"): with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
...@@ -520,9 +633,22 @@ def test_convert_image_dtype_save(tmpdir): ...@@ -520,9 +633,22 @@ def test_convert_image_dtype_save(tmpdir):
scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt")) scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('policy', [policy for policy in T.AutoAugmentPolicy]) @pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
None,
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_autoaugment(device, policy, fill): def test_autoaugment(device, policy, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -534,10 +660,23 @@ def test_autoaugment(device, policy, fill): ...@@ -534,10 +660,23 @@ def test_autoaugment(device, policy, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('num_ops', [1, 2, 3]) @pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11]) @pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
None,
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_randaugment(device, num_ops, magnitude, fill): def test_randaugment(device, num_ops, magnitude, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -549,8 +688,21 @@ def test_randaugment(device, num_ops, magnitude, fill): ...@@ -549,8 +688,21 @@ def test_randaugment(device, num_ops, magnitude, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @pytest.mark.parametrize(
"fill",
[
None,
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_trivialaugmentwide(device, fill): def test_trivialaugmentwide(device, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -562,21 +714,17 @@ def test_trivialaugmentwide(device, fill): ...@@ -562,21 +714,17 @@ def test_trivialaugmentwide(device, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) @pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir): def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation() transform = augmentation()
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
'config', [ "config",
{"value": 0.2}, [{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}],
{"value": "random"},
{"value": (0.2, 0.2, 0.2)},
{"value": "random", "ratio": (0.1, 0.2)}
]
) )
def test_random_erasing(device, config): def test_random_erasing(device, config):
tensor, _ = _create_data(24, 32, channels=3, device=device) tensor, _ = _create_data(24, 32, channels=3, device=device)
...@@ -602,7 +750,7 @@ def test_random_erasing_with_invalid_data(): ...@@ -602,7 +750,7 @@ def test_random_erasing_with_invalid_data():
random_erasing(img) random_erasing(img)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_normalize(device, tmpdir): def test_normalize(device, tmpdir):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = _create_data(26, 34, device=device) tensor, _ = _create_data(26, 34, device=device)
...@@ -621,7 +769,7 @@ def test_normalize(device, tmpdir): ...@@ -621,7 +769,7 @@ def test_normalize(device, tmpdir):
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_linear_transformation(device, tmpdir): def test_linear_transformation(device, tmpdir):
c, h, w = 3, 24, 32 c, h, w = 3, 24, 32
...@@ -647,14 +795,16 @@ def test_linear_transformation(device, tmpdir): ...@@ -647,14 +795,16 @@ def test_linear_transformation(device, tmpdir):
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_compose(device): def test_compose(device):
tensor, _ = _create_data(26, 34, device=device) tensor, _ = _create_data(26, 34, device=device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.Compose([ transforms = T.Compose(
T.CenterCrop(10), [
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), T.CenterCrop(10),
]) T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
)
s_transforms = torch.nn.Sequential(*transforms.transforms) s_transforms = torch.nn.Sequential(*transforms.transforms)
scripted_fn = torch.jit.script(s_transforms) scripted_fn = torch.jit.script(s_transforms)
...@@ -664,26 +814,36 @@ def test_compose(device): ...@@ -664,26 +814,36 @@ def test_compose(device):
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms)) assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
t = T.Compose([ t = T.Compose(
lambda x: x, [
]) lambda x: x,
]
)
with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"): with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"):
torch.jit.script(t) torch.jit.script(t)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_random_apply(device): def test_random_apply(device):
tensor, _ = _create_data(26, 34, device=device) tensor, _ = _create_data(26, 34, device=device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([ transforms = T.RandomApply(
T.RandomHorizontalFlip(), [
T.ColorJitter(), T.RandomHorizontalFlip(),
], p=0.4) T.ColorJitter(),
s_transforms = T.RandomApply(torch.nn.ModuleList([ ],
T.RandomHorizontalFlip(), p=0.4,
T.ColorJitter(), )
]), p=0.4) s_transforms = T.RandomApply(
torch.nn.ModuleList(
[
T.RandomHorizontalFlip(),
T.ColorJitter(),
]
),
p=0.4,
)
scripted_fn = torch.jit.script(s_transforms) scripted_fn = torch.jit.script(s_transforms)
torch.manual_seed(12) torch.manual_seed(12)
...@@ -695,27 +855,38 @@ def test_random_apply(device): ...@@ -695,27 +855,38 @@ def test_random_apply(device):
if device == "cpu": if device == "cpu":
# Can't check this twice, otherwise # Can't check this twice, otherwise
# "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply" # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
transforms = T.RandomApply([ transforms = T.RandomApply(
T.ColorJitter(), [
], p=0.3) T.ColorJitter(),
],
p=0.3,
)
with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"): with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"):
torch.jit.script(transforms) torch.jit.script(transforms)
@pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize('meth_kwargs', [ @pytest.mark.parametrize(
{"kernel_size": 3, "sigma": 0.75}, "meth_kwargs",
{"kernel_size": 23, "sigma": [0.1, 2.0]}, [
{"kernel_size": 23, "sigma": (0.1, 2.0)}, {"kernel_size": 3, "sigma": 0.75},
{"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, {"kernel_size": 23, "sigma": [0.1, 2.0]},
{"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, {"kernel_size": 23, "sigma": (0.1, 2.0)},
{"kernel_size": [23], "sigma": 0.75} {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
]) {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
@pytest.mark.parametrize('channels', [1, 3]) {"kernel_size": [23], "sigma": 0.75},
],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs): def test_gaussian_blur(device, channels, meth_kwargs):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
torch.manual_seed(12) torch.manual_seed(12)
_test_class_op( _test_class_op(
T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels, T.GaussianBlur,
test_exact_match=False, device=device, agg_method="max", tol=tol meth_kwargs=meth_kwargs,
channels=channels,
test_exact_match=False,
device=device,
agg_method="max",
tol=tol,
) )
import torch
from torchvision.transforms import Compose
import pytest
import random import random
import numpy as np
import warnings import warnings
import numpy as np
import pytest
import torch
from common_utils import assert_equal from common_utils import assert_equal
from torchvision.transforms import Compose
try: try:
from scipy import stats from scipy import stats
...@@ -17,8 +18,7 @@ with warnings.catch_warnings(record=True): ...@@ -17,8 +18,7 @@ with warnings.catch_warnings(record=True):
import torchvision.transforms._transforms_video as transforms import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms(): class TestVideoTransforms:
def test_random_crop_video(self): def test_random_crop_video(self):
numFrames = random.randint(4, 128) numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
...@@ -26,10 +26,12 @@ class TestVideoTransforms(): ...@@ -26,10 +26,12 @@ class TestVideoTransforms():
oheight = random.randint(5, (height - 2) / 2) * 2 oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.RandomCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.RandomCropVideo((oheight, owidth)),
]
)(clip)
assert result.size(2) == oheight assert result.size(2) == oheight
assert result.size(3) == owidth assert result.size(3) == owidth
...@@ -42,10 +44,12 @@ class TestVideoTransforms(): ...@@ -42,10 +44,12 @@ class TestVideoTransforms():
oheight = random.randint(5, (height - 2) / 2) * 2 oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.RandomResizedCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.RandomResizedCropVideo((oheight, owidth)),
]
)(clip)
assert result.size(2) == oheight assert result.size(2) == oheight
assert result.size(3) == owidth assert result.size(3) == owidth
...@@ -61,47 +65,56 @@ class TestVideoTransforms(): ...@@ -61,47 +65,56 @@ class TestVideoTransforms():
clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255 clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
oh1 = (height - oheight) // 2 oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2 ow1 = (width - owidth) // 2
clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :] clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :]
clipNarrow.fill_(0) clipNarrow.fill_(0)
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
msg = "height: " + str(height) + " width: " \ )(clip)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
msg = (
"height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert result.sum().item() == 0, msg assert result.sum().item() == 0, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum1 = result.sum() sum1 = result.sum()
msg = "height: " + str(height) + " width: " \ msg = (
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert sum1.item() > 1, msg assert sum1.item() > 1, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum2 = result.sum() sum2 = result.sum()
msg = "height: " + str(height) + " width: " \ msg = (
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert sum2.item() > 1, msg assert sum2.item() > 1, msg
assert sum2.item() > sum1.item(), msg assert sum2.item() > sum1.item(), msg
@pytest.mark.skipif(stats is None, reason='scipy.stats is not available') @pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
@pytest.mark.parametrize('channels', [1, 3]) @pytest.mark.parametrize("channels", [1, 3])
def test_normalize_video(self, channels): def test_normalize_video(self, channels):
def samples_from_standard_normal(tensor): def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
return p_value > 0.0001 return p_value > 0.0001
random_state = random.getstate() random_state = random.getstate()
...@@ -147,7 +160,7 @@ class TestVideoTransforms(): ...@@ -147,7 +160,7 @@ class TestVideoTransforms():
trans.__repr__() trans.__repr__()
@pytest.mark.skipif(stats is None, reason='scipy.stats not available') @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_horizontal_flip_video(self): def test_random_horizontal_flip_video(self):
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -179,5 +192,5 @@ class TestVideoTransforms(): ...@@ -179,5 +192,5 @@ class TestVideoTransforms():
transforms.RandomHorizontalFlipVideo().__repr__() transforms.RandomHorizontalFlipVideo().__repr__()
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import pytest
import numpy as np
import os import os
import sys import sys
import tempfile import tempfile
import torch
import torchvision.utils as utils
from io import BytesIO from io import BytesIO
import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor import torchvision.utils as utils
from common_utils import assert_equal from common_utils import assert_equal
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
def test_make_grid_not_inplace(): def test_make_grid_not_inplace():
...@@ -23,13 +22,13 @@ def test_make_grid_not_inplace(): ...@@ -23,13 +22,13 @@ def test_make_grid_not_inplace():
t_clone = t.clone() t_clone = t.clone()
utils.make_grid(t, normalize=False) utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=False) utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=True) utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
def test_normalize_in_make_grid(): def test_normalize_in_make_grid():
...@@ -46,48 +45,48 @@ def test_normalize_in_make_grid(): ...@@ -46,48 +45,48 @@ def test_normalize_in_make_grid():
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1') assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1")
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0') assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image(): def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64) t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The image is not present after save' assert os.path.exists(f.name), "The image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel(): def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1) t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save' assert os.path.exists(f.name), "The pixel image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object(): def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64) t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name) utils.save_image(t, f.name)
img_orig = Image.open(f.name) img_orig = Image.open(f.name)
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object(): def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1) t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name) utils.save_image(t, f.name)
img_orig = Image.open(f.name) img_orig = Image.open(f.name)
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
def test_draw_boxes(): def test_draw_boxes():
...@@ -113,13 +112,7 @@ def test_draw_boxes(): ...@@ -113,13 +112,7 @@ def test_draw_boxes():
assert_equal(img, img_cp) assert_equal(img, img_cp)
@pytest.mark.parametrize('colors', [ @pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
None,
['red', 'blue', '#FF00FF', (1, 34, 122)],
'red',
'#FF00FF',
(1, 34, 122)
])
def test_draw_boxes_colors(colors): def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors)
...@@ -154,8 +147,7 @@ def test_draw_invalid_boxes(): ...@@ -154,8 +147,7 @@ def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3)) img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
with pytest.raises(TypeError, match="Tensor expected"): with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes) utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"): with pytest.raises(ValueError, match="Tensor uint8 expected"):
...@@ -166,12 +158,15 @@ def test_draw_invalid_boxes(): ...@@ -166,12 +158,15 @@ def test_draw_invalid_boxes():
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
@pytest.mark.parametrize('colors', [ @pytest.mark.parametrize(
None, "colors",
['red', 'blue'], [
['#FF00FF', (1, 34, 122)], None,
]) ["red", "blue"],
@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) ["#FF00FF", (1, 34, 122)],
],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
def test_draw_segmentation_masks(colors, alpha): def test_draw_segmentation_masks(colors, alpha):
"""This test makes sure that masks draw their corresponding color where they should""" """This test makes sure that masks draw their corresponding color where they should"""
num_masks, h, w = 2, 100, 100 num_masks, h, w = 2, 100, 100
...@@ -241,10 +236,10 @@ def test_draw_segmentation_masks_errors(): ...@@ -241,10 +236,10 @@ def test_draw_segmentation_masks_errors():
with pytest.raises(ValueError, match="There are more masks"): with pytest.raises(ValueError, match="There are more masks"):
utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
bad_colors = np.array(['red', 'blue']) # should be a list bad_colors = np.array(["red", "blue"]) # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
bad_colors = ('red', 'blue') # should be a list bad_colors = ("red", "blue") # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
......
...@@ -2,17 +2,17 @@ import collections ...@@ -2,17 +2,17 @@ import collections
import itertools import itertools
import math import math
import os import os
import pytest
from pytest import approx
from fractions import Fraction from fractions import Fraction
import numpy as np import numpy as np
import pytest
import torch import torch
import torchvision.io as io import torchvision.io as io
from common_utils import assert_equal
from numpy.random import randint from numpy.random import randint
from pytest import approx
from torchvision import set_video_backend from torchvision import set_video_backend
from torchvision.io import _HAS_VIDEO_OPT from torchvision.io import _HAS_VIDEO_OPT
from common_utils import assert_equal
try: try:
...@@ -108,18 +108,14 @@ test_videos = { ...@@ -108,18 +108,14 @@ test_videos = {
} }
DecoderResult = collections.namedtuple( DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase")
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
# av_seek_frame is imprecise so seek to a timestamp earlier by a margin # av_seek_frame is imprecise so seek to a timestamp earlier by a margin
# The unit of margin is second # The unit of margin is second
seek_frame_margin = 0.25 seek_frame_margin = 0.25
def _read_from_stream( def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
""" """
Args: Args:
container: pyav container container: pyav container
...@@ -231,9 +227,7 @@ def _decode_frames_by_av_module( ...@@ -231,9 +227,7 @@ def _decode_frames_by_av_module(
else: else:
aframes = torch.empty((1, 0), dtype=torch.float32) aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor( aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult( return DecoderResult(
vframes=vframes, vframes=vframes,
...@@ -273,25 +267,28 @@ def _get_video_tensor(video_dir, video_file): ...@@ -273,25 +267,28 @@ def _get_video_tensor(video_dir, video_file):
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoReader: class TestVideoReader:
def check_separate_decoding_result(self, tv_result, config): def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder """check the decoding results from TorchVision decoder"""
""" (
vframes, vframe_pts, vtimebase, vfps, vduration, \ vframes,
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vframe_pts,
tv_result vtimebase,
) vfps,
vduration,
video_duration = vduration.item() * Fraction( aframes,
vtimebase[0].item(), vtimebase[1].item() aframe_pts,
) atimebase,
asample_rate,
aduration,
) = tv_result
video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
assert video_duration == approx(config.duration, abs=0.5) assert video_duration == approx(config.duration, abs=0.5)
assert vfps.item() == approx(config.video_fps, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5)
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
assert asample_rate.item() == config.audio_sample_rate assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction( audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
atimebase[0].item(), atimebase[1].item()
)
assert audio_duration == approx(config.duration, abs=0.5) assert audio_duration == approx(config.duration, abs=0.5)
# check if pts of video frames are sorted in ascending order # check if pts of video frames are sorted in ascending order
...@@ -305,16 +302,12 @@ class TestVideoReader: ...@@ -305,16 +302,12 @@ class TestVideoReader:
def check_probe_result(self, result, config): def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction( video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
vtimebase[0].item(), vtimebase[1].item()
)
assert video_duration == approx(config.duration, abs=0.5) assert video_duration == approx(config.duration, abs=0.5)
assert vfps.item() == approx(config.video_fps, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5)
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
assert asample_rate.item() == config.audio_sample_rate assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction( audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
atimebase[0].item(), atimebase[1].item()
)
assert audio_duration == approx(config.duration, abs=0.5) assert audio_duration == approx(config.duration, abs=0.5)
def check_meta_result(self, result, config): def check_meta_result(self, result, config):
...@@ -333,10 +326,18 @@ class TestVideoReader: ...@@ -333,10 +326,18 @@ class TestVideoReader:
decoder or TorchVision decoder with getPtsOnly = 1 decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker config: config of decoding results checker
""" """
vframes, vframe_pts, vtimebase, _vfps, _vduration, \ (
aframes, aframe_pts, atimebase, _asample_rate, _aduration = ( vframes,
tv_result vframe_pts,
) vtimebase,
_vfps,
_vduration,
aframes,
aframe_pts,
atimebase,
_asample_rate,
_aduration,
) = tv_result
if isinstance(ref_result, list): if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder # the ref_result is from new video_reader decoder
ref_result = DecoderResult( ref_result = DecoderResult(
...@@ -349,32 +350,20 @@ class TestVideoReader: ...@@ -349,32 +350,20 @@ class TestVideoReader:
) )
if vframes.numel() > 0 and ref_result.vframes.numel() > 0: if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean( mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
torch.abs(vframes.float() - ref_result.vframes.float())
)
assert mean_delta == approx(0.0, abs=8.0) assert mean_delta == approx(0.0, abs=8.0)
mean_delta = torch.mean( mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
)
assert mean_delta == approx(0.0, abs=1.0) assert mean_delta == approx(0.0, abs=1.0)
assert_equal(vtimebase, ref_result.vtimebase) assert_equal(vtimebase, ref_result.vtimebase)
if ( if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return """Audio stream is available and audio frame is required to return
from decoder""" from decoder"""
assert_equal(aframes, ref_result.aframes) assert_equal(aframes, ref_result.aframes)
if ( if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
"""Audio stream is available""" """Audio stream is available"""
assert_equal(aframe_pts, ref_result.aframe_pts) assert_equal(aframe_pts, ref_result.aframe_pts)
...@@ -508,19 +497,25 @@ class TestVideoReader: ...@@ -508,19 +497,25 @@ class TestVideoReader:
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, \ (
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vframes,
tv_result vframe_pts,
) vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
assert (vframes.numel() > 0) is bool(readVideoStream) assert (vframes.numel() > 0) is bool(readVideoStream)
assert (vframe_pts.numel() > 0) is bool(readVideoStream) assert (vframe_pts.numel() > 0) is bool(readVideoStream)
assert (vtimebase.numel() > 0) is bool(readVideoStream) assert (vtimebase.numel() > 0) is bool(readVideoStream)
assert (vfps.numel() > 0) is bool(readVideoStream) assert (vfps.numel() > 0) is bool(readVideoStream)
expect_audio_data = ( expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None
readAudioStream == 1 and config.audio_sample_rate is not None
)
assert (aframes.numel() > 0) is bool(expect_audio_data) assert (aframes.numel() > 0) is bool(expect_audio_data)
assert (aframe_pts.numel() > 0) is bool(expect_audio_data) assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
assert (atimebase.numel() > 0) is bool(expect_audio_data) assert (atimebase.numel() > 0) is bool(expect_audio_data)
...@@ -808,19 +803,23 @@ class TestVideoReader: ...@@ -808,19 +803,23 @@ class TestVideoReader:
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, \ (
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vframes,
tv_result vframe_pts,
) vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
if aframes.numel() > 0: if aframes.numel() > 0:
assert samples == asample_rate.item() assert samples == asample_rate.item()
assert 1 == aframes.size(1) assert 1 == aframes.size(1)
# when audio stream is found # when audio stream is found
duration = ( duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
float(aframe_pts[-1])
* float(atimebase[0])
/ float(atimebase[1])
)
assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item()) assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
def test_compare_read_video_from_memory_and_file(self): def test_compare_read_video_from_memory_and_file(self):
...@@ -1040,10 +1039,18 @@ class TestVideoReader: ...@@ -1040,10 +1039,18 @@ class TestVideoReader:
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, \ (
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vframes,
tv_result vframe_pts,
) vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
assert abs(config.video_fps - vfps.item()) < 0.01 assert abs(config.video_fps - vfps.item()) < 0.01
for num_frames in [4, 8, 16, 32, 64, 128]: for num_frames in [4, 8, 16, 32, 64, 128]:
...@@ -1097,41 +1104,31 @@ class TestVideoReader: ...@@ -1097,41 +1104,31 @@ class TestVideoReader:
) )
# pass 3: decode frames in range using PyAv # pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module( video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
full_path
)
video_start_pts_av = _pts_convert( video_start_pts_av = _pts_convert(
video_start_pts.item(), video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction( Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor, math.floor,
) )
video_end_pts_av = _pts_convert( video_end_pts_av = _pts_convert(
video_end_pts.item(), video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction( Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil, math.ceil,
) )
if audio_timebase_av: if audio_timebase_av:
audio_start_pts = _pts_convert( audio_start_pts = _pts_convert(
video_start_pts.item(), video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction( Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor, math.floor,
) )
audio_end_pts = _pts_convert( audio_end_pts = _pts_convert(
video_end_pts.item(), video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction( Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil, math.ceil,
) )
...@@ -1218,46 +1215,42 @@ class TestVideoReader: ...@@ -1218,46 +1215,42 @@ class TestVideoReader:
# FUTURE: check value of video / audio frames # FUTURE: check value of video / audio frames
def test_invalid_file(self): def test_invalid_file(self):
set_video_backend('video_reader') set_video_backend("video_reader")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
io.read_video('foo.mp4') io.read_video("foo.mp4")
set_video_backend('pyav') set_video_backend("pyav")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
io.read_video('foo.mp4') io.read_video("foo.mp4")
def test_audio_present_pts(self): def test_audio_present_pts(self):
"""Test if audio frames are returned with pts unit.""" """Test if audio frames are returned with pts unit."""
backends = ['video_reader', 'pyav'] backends = ["video_reader", "pyav"]
start_offsets = [0, 1000] start_offsets = [0, 1000]
end_offsets = [3000, None] end_offsets = [3000, None]
for test_video, _ in test_videos.items(): for test_video, _ in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path) container = av.open(full_path)
if container.streams.audio: if container.streams.audio:
for backend, start_offset, end_offset in itertools.product( for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
backends, start_offsets, end_offsets):
set_video_backend(backend) set_video_backend(backend)
_, audio, _ = io.read_video( _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
full_path, start_offset, end_offset, pts_unit='pts')
assert all([dimension > 0 for dimension in audio.shape[:2]]) assert all([dimension > 0 for dimension in audio.shape[:2]])
def test_audio_present_sec(self): def test_audio_present_sec(self):
"""Test if audio frames are returned with sec unit.""" """Test if audio frames are returned with sec unit."""
backends = ['video_reader', 'pyav'] backends = ["video_reader", "pyav"]
start_offsets = [0, 0.1] start_offsets = [0, 0.1]
end_offsets = [0.3, None] end_offsets = [0.3, None]
for test_video, _ in test_videos.items(): for test_video, _ in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path) container = av.open(full_path)
if container.streams.audio: if container.streams.audio:
for backend, start_offset, end_offset in itertools.product( for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
backends, start_offsets, end_offsets):
set_video_backend(backend) set_video_backend(backend)
_, audio, _ = io.read_video( _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
full_path, start_offset, end_offset, pts_unit='sec')
assert all([dimension > 0 for dimension in audio.shape[:2]]) assert all([dimension > 0 for dimension in audio.shape[:2]])
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import collections import collections
import os import os
import pytest
from pytest import approx
import urllib import urllib
import pytest
import torch import torch
import torchvision import torchvision
from torchvision.io import _HAS_VIDEO_OPT, VideoReader from pytest import approx
from torchvision.datasets.utils import download_url from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
try: try:
...@@ -36,30 +36,16 @@ def fate(name, path="."): ...@@ -36,30 +36,16 @@ def fate(name, path="."):
test_videos = { test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
duration=2.0, video_fps=30.0, audio_sample_rate=None
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0, video_fps=30.0, audio_sample_rate=None duration=2.0, video_fps=30.0, audio_sample_rate=None
), ),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
duration=2.0, video_fps=30.0, audio_sample_rate=None "v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
), "v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
"v_SoccerJuggling_g23_c01.avi": GroundTruth( "R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100),
duration=8.0, video_fps=29.97, audio_sample_rate=None "SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
), "WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0, video_fps=29.97, audio_sample_rate=None
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0, video_fps=30.0, audio_sample_rate=44100
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
} }
...@@ -79,13 +65,9 @@ class TestVideoApi: ...@@ -79,13 +65,9 @@ class TestVideoApi:
assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1) assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1)
av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute( av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1)
2, 0, 1
)
vr_array = vr_frame["data"] vr_array = vr_frame["data"]
mean_delta = torch.mean( mean_delta = torch.mean(torch.abs(av_array.float() - vr_array.float()))
torch.abs(av_array.float() - vr_array.float())
)
# on average the difference is very small and caused # on average the difference is very small and caused
# by decoding (around 1%) # by decoding (around 1%)
# TODO: asses empirically how to set this? atm it's 1% # TODO: asses empirically how to set this? atm it's 1%
...@@ -102,9 +84,7 @@ class TestVideoApi: ...@@ -102,9 +84,7 @@ class TestVideoApi:
av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0) av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0)
vr_array = vr_frame["data"] vr_array = vr_frame["data"]
max_delta = torch.max( max_delta = torch.max(torch.abs(av_array.float() - vr_array.float()))
torch.abs(av_array.float() - vr_array.float())
)
# we assure that there is never more than 1% difference in signal # we assure that there is never more than 1% difference in signal
assert max_delta.item() < 0.001 assert max_delta.item() < 0.001
...@@ -188,5 +168,5 @@ class TestVideoApi: ...@@ -188,5 +168,5 @@ class TestVideoApi:
os.remove(video_path) os.remove(video_path)
if __name__ == '__main__': if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import os.path as osp import os.path as osp
import torch import torch
......
import warnings
import os import os
import warnings
from .extension import _HAS_OPS import torch
from torchvision import models
from torchvision import datasets from torchvision import datasets
from torchvision import io
from torchvision import models
from torchvision import ops from torchvision import ops
from torchvision import transforms from torchvision import transforms
from torchvision import utils from torchvision import utils
from torchvision import io
import torch from .extension import _HAS_OPS
try: try:
from .version import __version__ # noqa: F401 from .version import __version__ # noqa: F401
...@@ -18,14 +17,17 @@ except ImportError: ...@@ -18,14 +17,17 @@ except ImportError:
pass pass
# Check if torchvision is being imported within the root folder # Check if torchvision is being imported within the root folder
if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.join(os.path.realpath(os.getcwd()), 'torchvision')): os.path.realpath(os.getcwd()), "torchvision"
message = ('You are importing torchvision within its own root folder ({}). ' ):
'This is not expected to work and may give errors. Please exit the ' message = (
'torchvision project source and relaunch your python interpreter.') "You are importing torchvision within its own root folder ({}). "
"This is not expected to work and may give errors. Please exit the "
"torchvision project source and relaunch your python interpreter."
)
warnings.warn(message.format(os.getcwd())) warnings.warn(message.format(os.getcwd()))
_image_backend = 'PIL' _image_backend = "PIL"
_video_backend = "pyav" _video_backend = "pyav"
...@@ -40,9 +42,8 @@ def set_image_backend(backend): ...@@ -40,9 +42,8 @@ def set_image_backend(backend):
generally faster than PIL, but does not support as many operations. generally faster than PIL, but does not support as many operations.
""" """
global _image_backend global _image_backend
if backend not in ['PIL', 'accimage']: if backend not in ["PIL", "accimage"]:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'" raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend))
.format(backend))
_image_backend = backend _image_backend = backend
...@@ -71,14 +72,9 @@ def set_video_backend(backend): ...@@ -71,14 +72,9 @@ def set_video_backend(backend):
""" """
global _video_backend global _video_backend
if backend not in ["pyav", "video_reader"]: if backend not in ["pyav", "video_reader"]:
raise ValueError( raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
if backend == "video_reader" and not io._HAS_VIDEO_OPT: if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = ( message = "video_reader video backend is not available." " Please compile torchvision from source and try again"
"video_reader video backend is not available."
" Please compile torchvision from source and try again"
)
warnings.warn(message) warnings.warn(message)
else: else:
_video_backend = backend _video_backend = backend
......
import os
import importlib.machinery import importlib.machinery
import os
def _download_file_from_remote_location(fpath: str, url: str) -> None: def _download_file_from_remote_location(fpath: str, url: str) -> None:
...@@ -19,13 +19,13 @@ except ImportError: ...@@ -19,13 +19,13 @@ except ImportError:
def _get_extension_path(lib_name): def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__) lib_dir = os.path.dirname(__file__)
if os.name == 'nt': if os.name == "nt":
# Register the main torchvision library location on the default DLL path # Register the main torchvision library location on the default DLL path
import ctypes import ctypes
import sys import sys
kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001) prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags: if with_load_library_flags:
...@@ -42,10 +42,7 @@ def _get_extension_path(lib_name): ...@@ -42,10 +42,7 @@ def _get_extension_path(lib_name):
kernel32.SetErrorMode(prev_error_mode) kernel32.SetErrorMode(prev_error_mode)
loader_details = ( loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES
)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name) ext_specs = extfinder.find_spec(lib_name)
......
from .lsun import LSUN, LSUNClass from .caltech import Caltech101, Caltech256
from .folder import ImageFolder, DatasetFolder from .celeba import CelebA
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10 from .cityscapes import Cityscapes
from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST from .coco import CocoCaptions, CocoDetection
from .svhn import SVHN
from .phototour import PhotoTour
from .fakedata import FakeData from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
from .sbu import SBU
from .flickr import Flickr8k, Flickr30k from .flickr import Flickr8k, Flickr30k
from .voc import VOCSegmentation, VOCDetection from .folder import ImageFolder, DatasetFolder
from .cityscapes import Cityscapes from .hmdb51 import HMDB51
from .imagenet import ImageNet from .imagenet import ImageNet
from .caltech import Caltech101, Caltech256 from .inaturalist import INaturalist
from .celeba import CelebA
from .widerface import WIDERFace
from .sbd import SBDataset
from .vision import VisionDataset
from .usps import USPS
from .kinetics import Kinetics400, Kinetics from .kinetics import Kinetics400, Kinetics
from .hmdb51 import HMDB51
from .ucf101 import UCF101
from .places365 import Places365
from .kitti import Kitti from .kitti import Kitti
from .inaturalist import INaturalist
from .lfw import LFWPeople, LFWPairs from .lfw import LFWPeople, LFWPairs
from .lsun import LSUN, LSUNClass
from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
from .omniglot import Omniglot
from .phototour import PhotoTour
from .places365 import Places365
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stl10 import STL10
from .svhn import SVHN
from .ucf101 import UCF101
from .usps import USPS
from .vision import VisionDataset
from .voc import VOCSegmentation, VOCDetection
from .widerface import WIDERFace
__all__ = ('LSUN', 'LSUNClass', __all__ = (
'ImageFolder', 'DatasetFolder', 'FakeData', "LSUN",
'CocoCaptions', 'CocoDetection', "LSUNClass",
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST', "ImageFolder",
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', "DatasetFolder",
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', "FakeData",
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', "CocoCaptions",
'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset', "CocoDetection",
'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101', "CIFAR10",
'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs" "CIFAR100",
) "EMNIST",
"FashionMNIST",
"QMNIST",
"MNIST",
"KMNIST",
"STL10",
"SVHN",
"PhotoTour",
"SEMEION",
"Omniglot",
"SBU",
"Flickr8k",
"Flickr30k",
"VOCSegmentation",
"VOCDetection",
"Cityscapes",
"ImageNet",
"Caltech101",
"Caltech256",
"CelebA",
"WIDERFace",
"SBDataset",
"VisionDataset",
"USPS",
"Kinetics400",
"Kinetics",
"HMDB51",
"UCF101",
"Places365",
"Kitti",
"INaturalist",
"LFWPeople",
"LFWPairs",
)
from PIL import Image
import os import os
import os.path import os.path
from typing import Any, Callable, List, Optional, Union, Tuple from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset from PIL import Image
from .utils import download_and_extract_archive, verify_str_arg from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class Caltech101(VisionDataset): class Caltech101(VisionDataset):
...@@ -32,28 +33,26 @@ class Caltech101(VisionDataset): ...@@ -32,28 +33,26 @@ class Caltech101(VisionDataset):
""" """
def __init__( def __init__(
self, self,
root: str, root: str,
target_type: Union[List[str], str] = "category", target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'), super(Caltech101, self).__init__(
transform=transform, os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform
target_transform=target_transform) )
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
if not isinstance(target_type, list): if not isinstance(target_type, list):
target_type = [target_type] target_type = [target_type]
self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
for t in target_type]
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class self.categories.remove("BACKGROUND_Google") # this is not a real class
...@@ -61,10 +60,12 @@ class Caltech101(VisionDataset): ...@@ -61,10 +60,12 @@ class Caltech101(VisionDataset):
# For some reason, the category names in "101_ObjectCategories" and # For some reason, the category names in "101_ObjectCategories" and
# "Annotations" do not always match. This is a manual map between the # "Annotations" do not always match. This is a manual map between the
# two. Defaults to using same name, since most names are fine. # two. Defaults to using same name, since most names are fine.
name_map = {"Faces": "Faces_2", name_map = {
"Faces_easy": "Faces_3", "Faces": "Faces_2",
"Motorbikes": "Motorbikes_16", "Faces_easy": "Faces_3",
"airplanes": "Airplanes_Side_2"} "Motorbikes": "Motorbikes_16",
"airplanes": "Airplanes_Side_2",
}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
self.index: List[int] = [] self.index: List[int] = []
...@@ -84,20 +85,28 @@ class Caltech101(VisionDataset): ...@@ -84,20 +85,28 @@ class Caltech101(VisionDataset):
""" """
import scipy.io import scipy.io
img = Image.open(os.path.join(self.root, img = Image.open(
"101_ObjectCategories", os.path.join(
self.categories[self.y[index]], self.root,
"image_{:04d}.jpg".format(self.index[index]))) "101_ObjectCategories",
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index]),
)
)
target: Any = [] target: Any = []
for t in self.target_type: for t in self.target_type:
if t == "category": if t == "category":
target.append(self.y[index]) target.append(self.y[index])
elif t == "annotation": elif t == "annotation":
data = scipy.io.loadmat(os.path.join(self.root, data = scipy.io.loadmat(
"Annotations", os.path.join(
self.annotation_categories[self.y[index]], self.root,
"annotation_{:04d}.mat".format(self.index[index]))) "Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index]),
)
)
target.append(data["obj_contour"]) target.append(data["obj_contour"])
target = tuple(target) if len(target) > 1 else target[0] target = tuple(target) if len(target) > 1 else target[0]
...@@ -118,19 +127,21 @@ class Caltech101(VisionDataset): ...@@ -118,19 +127,21 @@ class Caltech101(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
download_and_extract_archive( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root, self.root,
filename="101_ObjectCategories.tar.gz", filename="101_ObjectCategories.tar.gz",
md5="b224c7392d521a49829488ab0f1120d9") md5="b224c7392d521a49829488ab0f1120d9",
)
download_and_extract_archive( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root, self.root,
filename="101_Annotations.tar", filename="101_Annotations.tar",
md5="6f83eeb1f24d99cab4eb377263132c91") md5="6f83eeb1f24d99cab4eb377263132c91",
)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Target type: {target_type}".format(**self.__dict__) return "Target type: {target_type}".format(**self.__dict__)
...@@ -152,23 +163,22 @@ class Caltech256(VisionDataset): ...@@ -152,23 +163,22 @@ class Caltech256(VisionDataset):
""" """
def __init__( def __init__(
self, self,
root: str, root: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'), super(Caltech256, self).__init__(
transform=transform, os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform
target_transform=target_transform) )
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index: List[int] = [] self.index: List[int] = []
...@@ -186,10 +196,14 @@ class Caltech256(VisionDataset): ...@@ -186,10 +196,14 @@ class Caltech256(VisionDataset):
Returns: Returns:
tuple: (image, target) where target is index of the target class. tuple: (image, target) where target is index of the target class.
""" """
img = Image.open(os.path.join(self.root, img = Image.open(
"256_ObjectCategories", os.path.join(
self.categories[self.y[index]], self.root,
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) "256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]),
)
)
target = self.y[index] target = self.y[index]
...@@ -210,11 +224,12 @@ class Caltech256(VisionDataset): ...@@ -210,11 +224,12 @@ class Caltech256(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
download_and_extract_archive( download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root, self.root,
filename="256_ObjectCategories.tar", filename="256_ObjectCategories.tar",
md5="67b4f42ca05d46448c6bb8ecd2220f6d") md5="67b4f42ca05d46448c6bb8ecd2220f6d",
)
from collections import namedtuple
import csv import csv
from functools import partial
import torch
import os import os
import PIL from collections import namedtuple
from functools import partial
from typing import Any, Callable, List, Optional, Union, Tuple from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset
import PIL
import torch
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
from .vision import VisionDataset
CSV = namedtuple("CSV", ["header", "index", "data"]) CSV = namedtuple("CSV", ["header", "index", "data"])
...@@ -57,16 +59,15 @@ class CelebA(VisionDataset): ...@@ -57,16 +59,15 @@ class CelebA(VisionDataset):
] ]
def __init__( def __init__(
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
target_type: Union[List[str], str] = "attr", target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(CelebA, self).__init__(root, transform=transform, super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.split = split self.split = split
if isinstance(target_type, list): if isinstance(target_type, list):
self.target_type = target_type self.target_type = target_type
...@@ -74,14 +75,13 @@ class CelebA(VisionDataset): ...@@ -74,14 +75,13 @@ class CelebA(VisionDataset):
self.target_type = [target_type] self.target_type = [target_type]
if not self.target_type and self.target_transform is not None: if not self.target_type and self.target_transform is not None:
raise RuntimeError('target_transform is specified but target_type is empty') raise RuntimeError("target_transform is specified but target_type is empty")
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
split_map = { split_map = {
"train": 0, "train": 0,
...@@ -89,8 +89,7 @@ class CelebA(VisionDataset): ...@@ -89,8 +89,7 @@ class CelebA(VisionDataset):
"test": 2, "test": 2,
"all": None, "all": None,
} }
split_ = split_map[verify_str_arg(split.lower(), "split", split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
("train", "valid", "test", "all"))]
splits = self._load_csv("list_eval_partition.txt") splits = self._load_csv("list_eval_partition.txt")
identity = self._load_csv("identity_CelebA.txt") identity = self._load_csv("identity_CelebA.txt")
bbox = self._load_csv("list_bbox_celeba.txt", header=1) bbox = self._load_csv("list_bbox_celeba.txt", header=1)
...@@ -108,7 +107,7 @@ class CelebA(VisionDataset): ...@@ -108,7 +107,7 @@ class CelebA(VisionDataset):
self.landmarks_align = landmarks_align.data[mask] self.landmarks_align = landmarks_align.data[mask]
self.attr = attr.data[mask] self.attr = attr.data[mask]
# map from {-1, 1} to {0, 1} # map from {-1, 1} to {0, 1}
self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor') self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
self.attr_names = attr.header self.attr_names = attr.header
def _load_csv( def _load_csv(
...@@ -120,11 +119,11 @@ class CelebA(VisionDataset): ...@@ -120,11 +119,11 @@ class CelebA(VisionDataset):
fn = partial(os.path.join, self.root, self.base_folder) fn = partial(os.path.join, self.root, self.base_folder)
with open(fn(filename)) as csv_file: with open(fn(filename)) as csv_file:
data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True)) data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
if header is not None: if header is not None:
headers = data[header] headers = data[header]
data = data[header + 1:] data = data[header + 1 :]
indices = [row[0] for row in data] indices = [row[0] for row in data]
data = [row[1:] for row in data] data = [row[1:] for row in data]
...@@ -148,7 +147,7 @@ class CelebA(VisionDataset): ...@@ -148,7 +147,7 @@ class CelebA(VisionDataset):
import zipfile import zipfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
for (file_id, md5, filename) in self.file_list: for (file_id, md5, filename) in self.file_list:
...@@ -172,7 +171,7 @@ class CelebA(VisionDataset): ...@@ -172,7 +171,7 @@ class CelebA(VisionDataset):
target.append(self.landmarks_align[index, :]) target.append(self.landmarks_align[index, :])
else: else:
# TODO: refactor with utils.verify_str_arg # TODO: refactor with utils.verify_str_arg
raise ValueError("Target type \"{}\" is not recognized.".format(t)) raise ValueError('Target type "{}" is not recognized.'.format(t))
if self.transform is not None: if self.transform is not None:
X = self.transform(X) X = self.transform(X)
...@@ -192,4 +191,4 @@ class CelebA(VisionDataset): ...@@ -192,4 +191,4 @@ class CelebA(VisionDataset):
def extra_repr(self) -> str: def extra_repr(self) -> str:
lines = ["Target type: {target_type}", "Split: {split}"] lines = ["Target type: {target_type}", "Split: {split}"]
return '\n'.join(lines).format(**self.__dict__) return "\n".join(lines).format(**self.__dict__)
from PIL import Image
import os import os
import os.path import os.path
import numpy as np
import pickle import pickle
import torch
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset import numpy as np
import torch
from PIL import Image
from .utils import check_integrity, download_and_extract_archive from .utils import check_integrity, download_and_extract_archive
from .vision import VisionDataset
class CIFAR10(VisionDataset): class CIFAR10(VisionDataset):
...@@ -27,38 +28,38 @@ class CIFAR10(VisionDataset): ...@@ -27,38 +28,38 @@ class CIFAR10(VisionDataset):
downloaded again. downloaded again.
""" """
base_folder = 'cifar-10-batches-py'
base_folder = "cifar-10-batches-py"
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz" filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a' tgz_md5 = "c58f30108f718f92721af3b95e74349a"
train_list = [ train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'], ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
['data_batch_4', '634d18415352ddfa80567beed471001a'], ["data_batch_4", "634d18415352ddfa80567beed471001a"],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
] ]
test_list = [ test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'], ["test_batch", "40351d587109b95175f43aff81a1287e"],
] ]
meta = { meta = {
'filename': 'batches.meta', "filename": "batches.meta",
'key': 'label_names', "key": "label_names",
'md5': '5ff9c542aee3614f3951f8cda6e48888', "md5": "5ff9c542aee3614f3951f8cda6e48888",
} }
def __init__( def __init__(
self, self,
root: str, root: str,
train: bool = True, train: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(CIFAR10, self).__init__(root, transform=transform, super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
...@@ -66,8 +67,7 @@ class CIFAR10(VisionDataset): ...@@ -66,8 +67,7 @@ class CIFAR10(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
if self.train: if self.train:
downloaded_list = self.train_list downloaded_list = self.train_list
...@@ -80,13 +80,13 @@ class CIFAR10(VisionDataset): ...@@ -80,13 +80,13 @@ class CIFAR10(VisionDataset):
# now load the picked numpy arrays # now load the picked numpy arrays
for file_name, checksum in downloaded_list: for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name) file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
entry = pickle.load(f, encoding='latin1') entry = pickle.load(f, encoding="latin1")
self.data.append(entry['data']) self.data.append(entry["data"])
if 'labels' in entry: if "labels" in entry:
self.targets.extend(entry['labels']) self.targets.extend(entry["labels"])
else: else:
self.targets.extend(entry['fine_labels']) self.targets.extend(entry["fine_labels"])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
...@@ -94,13 +94,14 @@ class CIFAR10(VisionDataset): ...@@ -94,13 +94,14 @@ class CIFAR10(VisionDataset):
self._load_meta() self._load_meta()
def _load_meta(self) -> None: def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename']) path = os.path.join(self.root, self.base_folder, self.meta["filename"])
if not check_integrity(path, self.meta['md5']): if not check_integrity(path, self.meta["md5"]):
raise RuntimeError('Dataset metadata file not found or corrupted.' + raise RuntimeError(
' You can use download=True to download it') "Dataset metadata file not found or corrupted." + " You can use download=True to download it"
with open(path, 'rb') as infile: )
data = pickle.load(infile, encoding='latin1') with open(path, "rb") as infile:
self.classes = data[self.meta['key']] data = pickle.load(infile, encoding="latin1")
self.classes = data[self.meta["key"]]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
...@@ -130,7 +131,7 @@ class CIFAR10(VisionDataset): ...@@ -130,7 +131,7 @@ class CIFAR10(VisionDataset):
def _check_integrity(self) -> bool: def _check_integrity(self) -> bool:
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in self.train_list + self.test_list:
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename) fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5): if not check_integrity(fpath, md5):
...@@ -139,7 +140,7 @@ class CIFAR10(VisionDataset): ...@@ -139,7 +140,7 @@ class CIFAR10(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
...@@ -152,19 +153,20 @@ class CIFAR100(CIFAR10): ...@@ -152,19 +153,20 @@ class CIFAR100(CIFAR10):
This is a subclass of the `CIFAR10` Dataset. This is a subclass of the `CIFAR10` Dataset.
""" """
base_folder = 'cifar-100-python'
base_folder = "cifar-100-python"
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz" filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
train_list = [ train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'], ["train", "16019d7e3df5f24257cddd939b257f8d"],
] ]
test_list = [ test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
] ]
meta = { meta = {
'filename': 'meta', "filename": "meta",
'key': 'fine_label_names', "key": "fine_label_names",
'md5': '7973b15100ade9c7d40fb424638fde48', "md5": "7973b15100ade9c7d40fb424638fde48",
} }
...@@ -3,9 +3,10 @@ import os ...@@ -3,9 +3,10 @@ import os
from collections import namedtuple from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from PIL import Image
from .utils import extract_archive, verify_str_arg, iterable_to_str from .utils import extract_archive, verify_str_arg, iterable_to_str
from .vision import VisionDataset from .vision import VisionDataset
from PIL import Image
class Cityscapes(VisionDataset): class Cityscapes(VisionDataset):
...@@ -57,60 +58,62 @@ class Cityscapes(VisionDataset): ...@@ -57,60 +58,62 @@ class Cityscapes(VisionDataset):
""" """
# Based on https://github.com/mcordts/cityscapesScripts # Based on https://github.com/mcordts/cityscapesScripts
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', CityscapesClass = namedtuple(
'has_instances', 'ignore_in_eval', 'color']) "CityscapesClass",
["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
)
classes = [ classes = [
CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
] ]
def __init__( def __init__(
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
mode: str = "fine", mode: str = "fine",
target_type: Union[List[str], str] = "instance", target_type: Union[List[str], str] = "instance",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
) -> None: ) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform) super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' self.mode = "gtFine" if mode == "fine" else "gtCoarse"
self.images_dir = os.path.join(self.root, 'leftImg8bit', split) self.images_dir = os.path.join(self.root, "leftImg8bit", split)
self.targets_dir = os.path.join(self.root, self.mode, split) self.targets_dir = os.path.join(self.root, self.mode, split)
self.target_type = target_type self.target_type = target_type
self.split = split self.split = split
...@@ -122,35 +125,37 @@ class Cityscapes(VisionDataset): ...@@ -122,35 +125,37 @@ class Cityscapes(VisionDataset):
valid_modes = ("train", "test", "val") valid_modes = ("train", "test", "val")
else: else:
valid_modes = ("train", "train_extra", "val") valid_modes = ("train", "train_extra", "val")
msg = ("Unknown value '{}' for argument split if mode is '{}'. " msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}."
"Valid values are {{{}}}.")
msg = msg.format(split, mode, iterable_to_str(valid_modes)) msg = msg.format(split, mode, iterable_to_str(valid_modes))
verify_str_arg(split, "split", valid_modes, msg) verify_str_arg(split, "split", valid_modes, msg)
if not isinstance(target_type, list): if not isinstance(target_type, list):
self.target_type = [target_type] self.target_type = [target_type]
[verify_str_arg(value, "target_type", [
("instance", "semantic", "polygon", "color")) verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
for value in self.target_type] for value in self.target_type
]
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
if split == 'train_extra': if split == "train_extra":
image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip"))
else: else:
image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip"))
if self.mode == 'gtFine': if self.mode == "gtFine":
target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip"))
elif self.mode == 'gtCoarse': elif self.mode == "gtCoarse":
target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip"))
if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
extract_archive(from_path=image_dir_zip, to_path=self.root) extract_archive(from_path=image_dir_zip, to_path=self.root)
extract_archive(from_path=target_dir_zip, to_path=self.root) extract_archive(from_path=target_dir_zip, to_path=self.root)
else: else:
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' raise RuntimeError(
' specified "split" and "mode" are inside the "root" directory') "Dataset not found or incomplete. Please make sure all required folders for the"
' specified "split" and "mode" are inside the "root" directory'
)
for city in os.listdir(self.images_dir): for city in os.listdir(self.images_dir):
img_dir = os.path.join(self.images_dir, city) img_dir = os.path.join(self.images_dir, city)
...@@ -158,8 +163,9 @@ class Cityscapes(VisionDataset): ...@@ -158,8 +163,9 @@ class Cityscapes(VisionDataset):
for file_name in os.listdir(img_dir): for file_name in os.listdir(img_dir):
target_types = [] target_types = []
for t in self.target_type: for t in self.target_type:
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], target_name = "{}_{}".format(
self._get_target_suffix(self.mode, t)) file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
)
target_types.append(os.path.join(target_dir, target_name)) target_types.append(os.path.join(target_dir, target_name))
self.images.append(os.path.join(img_dir, file_name)) self.images.append(os.path.join(img_dir, file_name))
...@@ -174,11 +180,11 @@ class Cityscapes(VisionDataset): ...@@ -174,11 +180,11 @@ class Cityscapes(VisionDataset):
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
""" """
image = Image.open(self.images[index]).convert('RGB') image = Image.open(self.images[index]).convert("RGB")
targets: Any = [] targets: Any = []
for i, t in enumerate(self.target_type): for i, t in enumerate(self.target_type):
if t == 'polygon': if t == "polygon":
target = self._load_json(self.targets[index][i]) target = self._load_json(self.targets[index][i])
else: else:
target = Image.open(self.targets[index][i]) target = Image.open(self.targets[index][i])
...@@ -197,19 +203,19 @@ class Cityscapes(VisionDataset): ...@@ -197,19 +203,19 @@ class Cityscapes(VisionDataset):
def extra_repr(self) -> str: def extra_repr(self) -> str:
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
return '\n'.join(lines).format(**self.__dict__) return "\n".join(lines).format(**self.__dict__)
def _load_json(self, path: str) -> Dict[str, Any]: def _load_json(self, path: str) -> Dict[str, Any]:
with open(path, 'r') as file: with open(path, "r") as file:
data = json.load(file) data = json.load(file)
return data return data
def _get_target_suffix(self, mode: str, target_type: str) -> str: def _get_target_suffix(self, mode: str, target_type: str) -> str:
if target_type == 'instance': if target_type == "instance":
return '{}_instanceIds.png'.format(mode) return "{}_instanceIds.png".format(mode)
elif target_type == 'semantic': elif target_type == "semantic":
return '{}_labelIds.png'.format(mode) return "{}_labelIds.png".format(mode)
elif target_type == 'color': elif target_type == "color":
return '{}_color.png'.format(mode) return "{}_color.png".format(mode)
else: else:
return '{}_polygons.json'.format(mode) return "{}_polygons.json".format(mode)
from .vision import VisionDataset
from PIL import Image
import os import os
import os.path import os.path
from typing import Any, Callable, Optional, Tuple, List from typing import Any, Callable, Optional, Tuple, List
from PIL import Image
from .vision import VisionDataset
class CocoDetection(VisionDataset): class CocoDetection(VisionDataset):
"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset. """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
......
import torch
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
import torch
from .. import transforms from .. import transforms
from .vision import VisionDataset
class FakeData(VisionDataset): class FakeData(VisionDataset):
...@@ -21,16 +23,17 @@ class FakeData(VisionDataset): ...@@ -21,16 +23,17 @@ class FakeData(VisionDataset):
""" """
def __init__( def __init__(
self, self,
size: int = 1000, size: int = 1000,
image_size: Tuple[int, int, int] = (3, 224, 224), image_size: Tuple[int, int, int] = (3, 224, 224),
num_classes: int = 10, num_classes: int = 10,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
random_offset: int = 0, random_offset: int = 0,
) -> None: ) -> None:
super(FakeData, self).__init__(None, transform=transform, # type: ignore[arg-type] super(FakeData, self).__init__(
target_transform=target_transform) None, transform=transform, target_transform=target_transform # type: ignore[arg-type]
)
self.size = size self.size = size
self.num_classes = num_classes self.num_classes = num_classes
self.image_size = image_size self.image_size = image_size
......
import glob
import os
from collections import defaultdict from collections import defaultdict
from PIL import Image
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import glob from PIL import Image
import os
from .vision import VisionDataset from .vision import VisionDataset
...@@ -27,26 +28,26 @@ class Flickr8kParser(HTMLParser): ...@@ -27,26 +28,26 @@ class Flickr8kParser(HTMLParser):
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None: def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
self.current_tag = tag self.current_tag = tag
if tag == 'table': if tag == "table":
self.in_table = True self.in_table = True
def handle_endtag(self, tag: str) -> None: def handle_endtag(self, tag: str) -> None:
self.current_tag = None self.current_tag = None
if tag == 'table': if tag == "table":
self.in_table = False self.in_table = False
def handle_data(self, data: str) -> None: def handle_data(self, data: str) -> None:
if self.in_table: if self.in_table:
if data == 'Image Not Found': if data == "Image Not Found":
self.current_img = None self.current_img = None
elif self.current_tag == 'a': elif self.current_tag == "a":
img_id = data.split('/')[-2] img_id = data.split("/")[-2]
img_id = os.path.join(self.root, img_id + '_*.jpg') img_id = os.path.join(self.root, img_id + "_*.jpg")
img_id = glob.glob(img_id)[0] img_id = glob.glob(img_id)[0]
self.current_img = img_id self.current_img = img_id
self.annotations[img_id] = [] self.annotations[img_id] = []
elif self.current_tag == 'li' and self.current_img: elif self.current_tag == "li" and self.current_img:
img_id = self.current_img img_id = self.current_img
self.annotations[img_id].append(data.strip()) self.annotations[img_id].append(data.strip())
...@@ -64,14 +65,13 @@ class Flickr8k(VisionDataset): ...@@ -64,14 +65,13 @@ class Flickr8k(VisionDataset):
""" """
def __init__( def __init__(
self, self,
root: str, root: str,
ann_file: str, ann_file: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
super(Flickr8k, self).__init__(root, transform=transform, super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
...@@ -93,7 +93,7 @@ class Flickr8k(VisionDataset): ...@@ -93,7 +93,7 @@ class Flickr8k(VisionDataset):
img_id = self.ids[index] img_id = self.ids[index]
# Image # Image
img = Image.open(img_id).convert('RGB') img = Image.open(img_id).convert("RGB")
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
...@@ -121,21 +121,20 @@ class Flickr30k(VisionDataset): ...@@ -121,21 +121,20 @@ class Flickr30k(VisionDataset):
""" """
def __init__( def __init__(
self, self,
root: str, root: str,
ann_file: str, ann_file: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
super(Flickr30k, self).__init__(root, transform=transform, super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
self.annotations = defaultdict(list) self.annotations = defaultdict(list)
with open(self.ann_file) as fh: with open(self.ann_file) as fh:
for line in fh: for line in fh:
img_id, caption = line.strip().split('\t') img_id, caption = line.strip().split("\t")
self.annotations[img_id[:-2]].append(caption) self.annotations[img_id[:-2]].append(caption)
self.ids = list(sorted(self.annotations.keys())) self.ids = list(sorted(self.annotations.keys()))
...@@ -152,7 +151,7 @@ class Flickr30k(VisionDataset): ...@@ -152,7 +151,7 @@ class Flickr30k(VisionDataset):
# Image # Image
filename = os.path.join(self.root, img_id) filename = os.path.join(self.root, img_id)
img = Image.open(filename).convert('RGB') img = Image.open(filename).convert("RGB")
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment