Commit 0fc002df authored by huchen

init the dlexamples new

parent 0e04b692
import os
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np
import unittest
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
class Tester(TransformsTester):
def setUp(self):
self.device = "cpu"
def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
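# Re-seed before each call so that random transforms draw identical parameters for the
# eager and the scripted version; only then is an exact equality check meaningful.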
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2), msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
transformed_batch = transform(batch_tensors)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)
torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = self._create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)
def test_random_horizontal_flip(self):
self._test_op('hflip', 'RandomHorizontalFlip')
def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip')
def test_color_jitter(self):
tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
meth_kwargs = {"contrast": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"saturation": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
meth_kwargs = {"hue": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
)
def test_pad(self):
for m in ["constant", "edge", "reflect", "symmetric"]:
fill = 127 if m == "constant" else 0
for mul in [1, -1]:
# Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_op(
"pad", fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
sizes = [5, [5, ], [6, 6]]
padding_configs = [
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
]
for size in sizes:
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_op("RandomCrop", config)
def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8, device=self.device)
# Test torchscript of transforms.CenterCrop with size as int
f = T.CenterCrop(size=5)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
# Test torchscript of transforms.CenterCrop with size as [int, ]
f = T.CenterCrop(size=[5, ])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
# Test torchscript of transforms.CenterCrop with size as tuple
f = T.CenterCrop(size=(6, 6))
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
meth_kwargs = {}
fn = getattr(F, func)
scripted_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
transformed_t_list = fn(tensor, **fn_kwargs)
transformed_p_list = fn(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
# test for class interface
fn = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(fn)
output = scripted_fn(tensor)
self.assertEqual(len(output), len(transformed_t_list_script))
# test on batch of tensors
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
torch.manual_seed(12)
transformed_batch_list = fn(batch_tensors)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_resize(self):
# TODO: Minimal check for bug-fix, improve this later
x = torch.rand(3, 32, 46)
t = T.Resize(size=38)
y = t(x)
# If size is an int, smaller edge of the image will be matched to this number.
# i.e., if height > width, then the image will be rescaled to (size * height / width, size).
self.assertTrue(isinstance(y, torch.Tensor))
self.assertEqual(y.shape[1], 38)
self.assertEqual(y.shape[2], int(38 * 46 / 32))
tensor, _ = self._create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
script_fn = torch.jit.script(F.resize)
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# A trivial cast of the uint8 data to float, to cover the floating-point dtype cases as well
tensor = tensor.to(dt)
for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
if isinstance(size, int):
script_size = [size, ]
else:
script_size = size
s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(s_resized_tensor.equal(resized_tensor))
transform = T.Resize(size=script_size, interpolation=interpolation)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
def test_resized_crop(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for scale in [(0.7, 1.2), [0.7, 1.2]]:
for ratio in [(0.75, 1.333), [0.75, 1.333]]:
for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
transform = T.RandomResizedCrop(
size=size, scale=scale, ratio=ratio, interpolation=interpolation
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
def test_random_affine(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
for scale in [(0.7, 1.2), [0.7, 1.2]]:
for translate in [(0.1, 0.2), [0.2, 0.1]]:
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomAffine(
degrees=degrees, translate=translate,
scale=scale, shear=shear, resample=interpolation
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
def test_random_rotate(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for center in [(0, 0), [10, 10], None, (56, 44)]:
for expand in [True, False]:
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomRotation(
degrees=degrees, resample=interpolation, expand=expand, center=center
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
def test_random_perspective(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for distortion_scale in np.linspace(0.1, 1.0, num=20):
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomPerspective(
distortion_scale=distortion_scale,
interpolation=interpolation
)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
def test_to_grayscale(self):
meth_kwargs = {"num_output_channels": 1}
tol = 1.0 + 1e-10
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
meth_kwargs = {"num_output_channels": 3}
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
meth_kwargs = {}
self._test_class_op(
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
def test_normalize(self):
tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
# test for class interface
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_linear_transformation(self):
c, h, w = 3, 24, 32
tensor, _ = self._create_data(h, w, channels=c, device=self.device)
matrix = torch.rand(c * h * w, c * h * w, device=self.device)
mean_vector = torch.rand(c * h * w, device=self.device)
fn = T.LinearTransformation(matrix, mean_vector)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
batch_tensors = torch.rand(4, c, h, w, device=self.device)
# We skip some tests from _test_transform_vs_scripted_on_batch as
# results for scripted and non-scripted transformations are not exactly the same
torch.manual_seed(12)
transformed_batch = fn(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.Compose([
T.CenterCrop(10),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
s_transforms = torch.nn.Sequential(*transforms.transforms)
scripted_fn = torch.jit.script(s_transforms)
torch.manual_seed(12)
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
t = T.Compose([
lambda x: x,
])
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)
def test_random_apply(self):
tensor, _ = self._create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0
transforms = T.RandomApply([
T.RandomHorizontalFlip(),
T.ColorJitter(),
], p=0.4)
s_transforms = T.RandomApply(torch.nn.ModuleList([
T.RandomHorizontalFlip(),
T.ColorJitter(),
]), p=0.4)
scripted_fn = torch.jit.script(s_transforms)
torch.manual_seed(12)
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise
# "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
transforms = T.RandomApply([
T.ColorJitter(),
], p=0.3)
with self.assertRaisesRegex(RuntimeError, r"Module 'RandomApply' has no attribute 'transforms'"):
torch.jit.script(transforms)
def test_gaussian_blur(self):
tol = 1.0 + 1e-10
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 3, "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
test_exact_match=False, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
test_exact_match=False, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
test_exact_match=False, agg_method="max", tol=tol
)
self._test_class_op(
"GaussianBlur", meth_kwargs={"kernel_size": [23], "sigma": 0.75},
test_exact_match=False, agg_method="max", tol=tol
)
def test_random_erasing(self):
img = torch.rand(3, 60, 60)
# Test Set 0: invalid value
random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img)
tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
test_configs = [
{"value": 0.2},
{"value": "random"},
{"value": (0.2, 0.2, 0.2)},
{"value": "random", "ratio": (0.1, 0.2)},
]
for config in test_configs:
fn = T.RandomErasing(**config)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
for in_dtype in int_dtypes() + float_dtypes():
in_tensor = tensor.to(in_dtype)
in_batch_tensors = batch_tensors.to(in_dtype)
for out_dtype in int_dtypes() + float_dtypes():
fn = T.ConvertImageDtype(dtype=out_dtype)
scripted_fn = torch.jit.script(fn)
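# float32 -> int32/int64 and float64 -> int64 are rejected by convert_image_dtype because
# the floating-point type cannot represent the integer range exactly, so the scaling step
# could overflow; the op raises instead of casting unsafely.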
if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
(in_dtype == torch.float64 and out_dtype == torch.int64):
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
continue
self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
def setUp(self):
self.device = "cuda"
if __name__ == '__main__':
unittest.main()
import torch
import torchvision.transforms._transforms_video as transforms
from torchvision.transforms import Compose
import unittest
import random
import numpy as np
try:
from scipy import stats
except ImportError:
stats = None
class TestVideoTransforms(unittest.TestCase):
def test_random_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([
transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
transforms.RandomCropVideo((oheight, owidth)).__repr__()
def test_random_resized_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([
transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
def test_center_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2
clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
clipNarrow.fill_(0)
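# The centered (oheight, owidth) region of the all-255 clip is zeroed out, so a center crop
# of exactly that size should sum to 0, while the slightly larger crops below must pick up
# non-zero border pixels.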
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg)
oheight += 1
owidth += 1
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
sum1 = result.sum()
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg)
oheight += 1
owidth += 1
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
sum2 = result.sum()
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg)
self.assertTrue(sum2.item() > sum1.item(), msg)
@unittest.skipIf(stats is None, 'scipy.stats is not available')
def test_normalize_video(self):
def samples_from_standard_normal(tensor):
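# Kolmogorov-Smirnov test against the standard normal distribution; a p-value above the
# (loose) threshold means we cannot reject that the samples come from N(0, 1).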
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
self.assertTrue(samples_from_standard_normal(normalized))
random.setstate(random_state)
# Checking the optional in-place behaviour
tensor = torch.rand((3, 128, 16, 16))
tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
self.assertTrue(torch.equal(tensor, tensor_inplace))
transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()
def test_to_tensor_video(self):
numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo()
with self.assertRaises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist())
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
with self.assertRaises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
trans(torch.ones((height, width, 3), dtype=torch.uint8))
trans(torch.ones((width, 3), dtype=torch.uint8))
trans(torch.ones((3), dtype=torch.uint8))
trans.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip_video(self):
random_state = random.getstate()
random.seed(42)
clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
hclip = clip.flip((-1))
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlipVideo()(clip)
if torch.all(torch.eq(out, hclip)):
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
if torch.all(torch.eq(out, hclip)):
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
transforms.RandomHorizontalFlipVideo().__repr__()
if __name__ == '__main__':
unittest.main()
import os
import sys
import tempfile
import torch
import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
class Tester(unittest.TestCase):
def test_make_grid_not_inplace(self):
t = torch.rand(5, 3, 10, 10)
t_clone = t.clone()
utils.make_grid(t, normalize=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=False)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
utils.make_grid(t, normalize=True, scale_each=True)
self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255
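# With normalize=True (and no explicit range), make_grid should linearly rescale the grid
# so its minimum maps to 0 and its maximum maps to 1; the rounded extrema are checked below.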
norm_max = torch.tensor(1.0)
norm_min = torch.tensor(0.0)
grid = utils.make_grid(t, normalize=True)
grid_max = torch.max(grid)
grid_min = torch.min(grid)
# Rounding the result to one decimal for comparison
n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Image not stored in file object')
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_save_image_single_pixel_file_object(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
img_bytes = Image.open(fp)
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')
if __name__ == '__main__':
unittest.main()
import os
import collections
import contextlib
import tempfile
import unittest
import random
import itertools
import numpy as np
import torch
import torchvision
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
try:
import av
# Do a version test too
torchvision.io.video._check_av_available()
except ImportError:
av = None
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find that for some videos (e.g. HMDB51 videos), the decoded audio frames and pts
# differ slightly between the TorchVision and PyAV decoders, so we omit them during the check
"check_aframes",
"check_aframe_pts",
]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
check_aframe_pts=True,
)
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
# Last three test segfault on video reader (see issues)
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
}
DecoderResult = collections.namedtuple(
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
def _read_from_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
"""
Args:
container: pyav container
start_pts/end_pts: the starting/ending Presentation TimeStamp where
frames are read
stream: pyav stream
stream_name: a dictionary of streams. For example, {"video": 0} means
video stream at stream index 0
buffer_size: the pts of frames decoded by PyAV are not guaranteed to be in
ascending order, so we keep decoding a few more frames even after we pass
end_pts
"""
# seeking in the stream is imprecise, so seek to an earlier PTS by a margin
margin = 1
seek_offset = max(start_pts - margin, 0)
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
frames = {}
buffer_count = 0
for frame in container.decode(**stream_name):
if frame.pts < start_pts:
continue
if frame.pts <= end_pts:
frames[frame.pts] = frame
else:
buffer_count += 1
if buffer_count >= buffer_size:
break
result = [frames[pts] for pts in sorted(frames)]
return result
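# Illustrative example (made-up pts values, not from a real stream): with end_pts=300 and
# frames decoded in the order 0, 256, 320, 288, the in-range frame at pts=288 arrives *after*
# the first out-of-range frame (320). Because we only stop once buffer_size out-of-range
# frames have been seen, pts=288 is still collected, and sorting the dict keys above restores
# ascending pts order.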
def _fraction_to_tensor(fraction):
ret = torch.zeros([2], dtype=torch.int32)
ret[0] = fraction.numerator
ret[1] = fraction.denominator
return ret
def _decode_frames_by_av_module(
full_path,
video_start_pts=0,
video_end_pts=None,
audio_start_pts=0,
audio_end_pts=None,
):
"""
Use PyAv to decode video frames. This provides a reference for our decoder
to compare the decoding results.
Input arguments:
full_path: video file path
video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
frames are read
"""
if video_end_pts is None:
video_end_pts = float("inf")
if audio_end_pts is None:
audio_end_pts = float("inf")
container = av.open(full_path)
video_frames = []
vtimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.video:
video_frames = _read_from_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
# container.streams.video[0].average_rate is not a reliable estimator of the
# frame rate: it can be wrong for certain codecs, such as VP80,
# so we do not return the video fps here
vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
audio_frames = []
atimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
audio_start_pts,
audio_end_pts,
container.streams.audio[0],
{"audio": 0},
)
atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
container.close()
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
vframes = torch.as_tensor(np.stack(vframes))
vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
aframes = [frame.to_ndarray() for frame in audio_frames]
if aframes:
aframes = np.transpose(np.concatenate(aframes, axis=1))
aframes = torch.as_tensor(aframes)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor(
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult(
vframes=vframes.permute(0, 3, 1, 2),
vframe_pts=vframe_pts,
vtimebase=vtimebase,
aframes=aframes,
aframe_pts=aframe_pts,
atimebase=atimebase,
)
def _template_read_video(video_object, s=0, e=None):
if e is None:
e = float("inf")
if e < s:
raise ValueError(
"end time should be larger than start time, got "
"start time={} and end time={}".format(s, e)
)
video_object.set_current_stream("video")
video_object.seek(s)
video_frames = torch.empty(0)
frames = []
video_pts = []
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
if frame['pts'] < s:
continue
frames.append(frame['data'])
video_pts.append(frame['pts'])
if len(frames) > 0:
video_frames = torch.stack(frames, 0)
video_object.set_current_stream("audio")
video_object.seek(s)
audio_frames = torch.empty(0)
frames = []
audio_pts = []
for frame in itertools.takewhile(lambda x: x['pts'] <= e, video_object):
if frame['pts'] < s:
continue
frames.append(frame['data'])
audio_pts.append(frame['pts'])
if len(frames) > 0:
audio_frames = torch.stack(frames, 0)
return DecoderResult(
vframes=video_frames,
vframe_pts=video_pts,
vtimebase=None,
aframes=audio_frames,
aframe_pts=audio_pts,
atimebase=None,
)
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
class TestVideo(unittest.TestCase):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_video_tensor(self):
"""
Check if reading the video using the `next` based API yields the
same sized tensors as the pyav alternative.
"""
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using existing TV decoder
tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
tv_result = tv_result.permute(0, 3, 1, 2)
# pass 2: decode all frames using new api
reader = VideoReader(full_path, "video")
frames = []
for frame in reader:
frames.append(frame['data'])
new_api = torch.stack(frames, 0)
self.assertEqual(tv_result.size(), new_api.size())
# def test_partial_video_reading_fn(self):
# torchvision.set_video_backend("video_reader")
# for test_video, config in test_videos.items():
# full_path = os.path.join(VIDEO_DIR, test_video)
# # select two random points between 0 and duration
# r = []
# r.append(random.uniform(0, config.duration))
# r.append(random.uniform(0, config.duration))
# s = min(r)
# e = max(r)
# reader = VideoReader(full_path, "video")
# results = _template_read_video(reader, s, e)
# tv_video, tv_audio, info = torchvision.io.read_video(
# full_path, start_pts=s, end_pts=e, pts_unit="sec"
# )
# self.assertAlmostEqual(tv_video.size(0), results.vframes.size(0), delta=2.0)
# def test_pts(self):
# """
# Check that the pts of every frame read from the new API matches read_video_timestamps
# """
# torchvision.set_video_backend("video_reader")
# for test_video, config in test_videos.items():
# full_path = os.path.join(VIDEO_DIR, test_video)
# tv_timestamps, _ = torchvision.io.read_video_timestamps(
# full_path, pts_unit="sec"
# )
# # pass 2: decode all frames using new api
# reader = VideoReader(full_path, "video")
# pts = []
# t, p = next(reader)
# while t.numel() > 0: # THIS NEEDS TO BE FIXED
# pts.append(p)
# t, p = next(reader)
# tv_timestamps = [float(p) for p in tv_timestamps]
# napi_pts = [float(p) for p in pts]
# for i in range(len(napi_pts)):
# self.assertAlmostEqual(napi_pts[i], tv_timestamps[i], delta=0.001)
# # check if pts of video frames are sorted in ascending order
# for i in range(len(napi_pts) - 1):
# self.assertEqual(napi_pts[i] < napi_pts[i + 1], True)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_metadata(self):
"""
Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API
"""
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
self.assertAlmostEqual(
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
)
self.assertAlmostEqual(
config.duration, reader_md["video"]["duration"][0], delta=0.5
)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_video_reading_fn(self):
"""
Test that the pyav and ffmpeg decoder outputs are mostly the same
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
ref_result = _decode_frames_by_av_module(full_path)
reader = VideoReader(full_path, "video")
newapi_result = _template_read_video(reader)
# First we check if the frames are approximately the same
# (note that every codec context has signature artefacts which
# make a direct comparison not feasible)
if newapi_result.vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(
torch.abs(
newapi_result.vframes.float() - ref_result.vframes.float()
)
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
# Just a sanity check: are the two of the correct size?
self.assertEqual(newapi_result.vframes.size(), ref_result.vframes.size())
# Lastly, we compare the resulting audio streams
if (
config.check_aframes
and newapi_result.aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return
from decoder"""
is_same = torch.all(
torch.eq(newapi_result.aframes, ref_result.aframes)
).item()
self.assertEqual(is_same, True)
if __name__ == "__main__":
unittest.main()
import collections
import math
import os
import time
import unittest
from fractions import Fraction
import numpy as np
import torch
import torchvision.io as io
from numpy.random import randint
from torchvision.io import _HAS_VIDEO_OPT
try:
import av
# Do a version test too
io.video._check_av_available()
except ImportError:
av = None
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find that for some videos (e.g. HMDB51 videos), the decoded audio frames and pts
# differ slightly between the TorchVision and PyAV decoders, so we omit them during the check
"check_aframes",
"check_aframe_pts",
]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
check_aframe_pts=True,
)
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
}
DecoderResult = collections.namedtuple(
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
"""av_seek_frame is imprecise so seek to a timestamp earlier by a margin
The unit of margin is second"""
seek_frame_margin = 0.25
def _read_from_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
"""
Args:
container: pyav container
start_pts/end_pts: the starting/ending Presentation TimeStamp where
frames are read
stream: pyav stream
stream_name: a dictionary of streams. For example, {"video": 0} means
video stream at stream index 0
buffer_size: the pts of frames decoded by PyAV are not guaranteed to be in
ascending order, so we keep decoding a few more frames even after we pass
end_pts
"""
# seeking in the stream is imprecise, so seek to an earlier PTS by a margin
margin = 1
seek_offset = max(start_pts - margin, 0)
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
frames = {}
buffer_count = 0
for frame in container.decode(**stream_name):
if frame.pts < start_pts:
continue
if frame.pts <= end_pts:
frames[frame.pts] = frame
else:
buffer_count += 1
if buffer_count >= buffer_size:
break
result = [frames[pts] for pts in sorted(frames)]
return result
def _get_timebase_by_av_module(full_path):
container = av.open(full_path)
video_time_base = container.streams.video[0].time_base
if container.streams.audio:
audio_time_base = container.streams.audio[0].time_base
else:
audio_time_base = None
return video_time_base, audio_time_base
def _fraction_to_tensor(fraction):
ret = torch.zeros([2], dtype=torch.int32)
ret[0] = fraction.numerator
ret[1] = fraction.denominator
return ret
def _decode_frames_by_av_module(
full_path,
video_start_pts=0,
video_end_pts=None,
audio_start_pts=0,
audio_end_pts=None,
):
"""
Use PyAv to decode video frames. This provides a reference for our decoder
to compare the decoding results.
Input arguments:
full_path: video file path
video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
frames are read
"""
if video_end_pts is None:
video_end_pts = float("inf")
if audio_end_pts is None:
audio_end_pts = float("inf")
container = av.open(full_path)
video_frames = []
vtimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.video:
video_frames = _read_from_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
# container.streams.video[0].average_rate is not a reliable estimator of the
# frame rate: it can be wrong for certain codecs, such as VP80,
# so we do not return the video fps here
vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
audio_frames = []
atimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
audio_start_pts,
audio_end_pts,
container.streams.audio[0],
{"audio": 0},
)
atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
container.close()
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
vframes = torch.as_tensor(np.stack(vframes))
vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
aframes = [frame.to_ndarray() for frame in audio_frames]
if aframes:
aframes = np.transpose(np.concatenate(aframes, axis=1))
aframes = torch.as_tensor(aframes)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor(
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult(
vframes=vframes,
vframe_pts=vframe_pts,
vtimebase=vtimebase,
aframes=aframes,
aframe_pts=aframe_pts,
atimebase=atimebase,
)
def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
"""convert pts between different time bases
Args:
pts: presentation timestamp, float
timebase_from: original timebase. Fraction
timebase_to: new timebase. Fraction
round_func: rounding function.
"""
new_pts = Fraction(pts, 1) * timebase_from / timebase_to
return int(round_func(new_pts))
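# Illustrative example (made-up values): converting pts=90000 from a 1/90000 time base
# (i.e. 1 second) into a 1/1000 time base yields 1000 milliseconds:
#   _pts_convert(90000, Fraction(1, 90000), Fraction(1, 1000)) == 1000
# because Fraction(90000, 1) * Fraction(1, 90000) / Fraction(1, 1000) == Fraction(1000, 1).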
def _get_video_tensor(video_dir, video_file):
"""open a video file, and represent the video data by a PT tensor"""
full_path = os.path.join(video_dir, video_file)
assert os.path.exists(full_path), "File not found: %s" % full_path
with open(full_path, "rb") as fp:
video_tensor = torch.from_numpy(np.frombuffer(fp.read(), dtype=np.uint8))
return full_path, video_tensor
@unittest.skipIf(av is None, "PyAV unavailable")
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
# check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
if len(aframe_pts) > 1:
# check if pts of audio frames are sorted in ascending order
for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
def check_meta_result(self, result, config):
self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5)
if result.has_audio > 0:
self.assertEqual(result.audio_sample_rate, config.audio_sample_rate)
self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
Args:
tv_result: decoding results from TorchVision decoder
ref_result: reference decoding results which can be from either PyAv
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, _vduration, \
aframes, aframe_pts, atimebase, _asample_rate, _aduration = (
tv_result
)
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
vframes=ref_result[0],
vframe_pts=ref_result[1],
vtimebase=ref_result[2],
aframes=ref_result[5],
aframe_pts=ref_result[6],
atimebase=ref_result[7],
)
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(
torch.abs(vframes.float() - ref_result.vframes.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean(
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=1.0)
is_same = torch.all(torch.eq(vtimebase, ref_result.vtimebase)).item()
self.assertEqual(is_same, True)
if (
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return
from decoder"""
is_same = torch.all(torch.eq(aframes, ref_result.aframes)).item()
self.assertEqual(is_same, True)
if (
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
"""Audio stream is available"""
is_same = torch.all(torch.eq(aframe_pts, ref_result.aframe_pts)).item()
self.assertEqual(is_same, True)
is_same = torch.all(torch.eq(atimebase, ref_result.atimebase)).item()
self.assertEqual(is_same, True)
@unittest.skip(
"This stress test will iteratively decode the same set of videos."
"It helps to detect memory leak but it takes lots of time to run."
"By default, it is disabled"
)
def test_stress_test_read_video_from_file(self):
num_iter = 10000
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
def test_read_video_from_file(self):
"""
Test the case when decoder starts with a video file to decode frames.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_file_read_single_stream_only(self):
"""
Test the case when decoder starts with a video file to decode frames, and
only reads video stream and ignores audio stream
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
for readVideoStream, readAudioStream in [(1, 0), (0, 1)]:
# decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
readVideoStream,
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
readAudioStream,
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
self.assertEqual(vtimebase.numel() > 0, readVideoStream)
self.assertEqual(vfps.numel() > 0, readVideoStream)
expect_audio_data = (
readAudioStream == 1 and config.audio_sample_rate is not None
)
self.assertEqual(aframes.numel() > 0, expect_audio_data)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
self.assertEqual(atimebase.numel() > 0, expect_audio_data)
self.assertEqual(asample_rate.numel() > 0, expect_audio_data)
def test_read_video_from_file_rescale_min_dimension(self):
"""
Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 128, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_max_dimension(self):
"""
Test the case when decoder starts with a video file to decode frames, and
video max dimension between height and width is set.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 85
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_both_min_max_dimension(self):
"""
Test the case when decoder starts with a video file to decode frames, and
both video min and max dimensions between height and width are set.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 64, 85
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_width(self):
"""
Test the case when decoder starts with a video file to decode frames, and
video width is set.
"""
# video related
width, height, min_dimension, max_dimension = 256, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(2), width)
def test_read_video_from_file_rescale_height(self):
"""
Test the case when decoder starts with a video file to decode frames, and
video height is set.
"""
# video related
width, height, min_dimension, max_dimension = 0, 224, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
def test_read_video_from_file_rescale_width_and_height(self):
"""
Test the case when decoder starts with a video file to decode frames, and
both video height and width are set.
"""
# video related
width, height, min_dimension, max_dimension = 320, 240, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
self.assertEqual(tv_result[0].size(2), width)
def test_read_video_from_file_audio_resampling(self):
"""
Test the case when decoder starts with a video file to decode frames, and
the audio waveform is resampled.
"""
for samples in [9600, 96000]:  # 9600: downsampling, 96000: upsampling
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
channels = 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
if aframes.numel() > 0:
# audio stream was found and decoded
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# audio duration in seconds, derived from the last frame pts and the timebase
duration = (
float(aframe_pts[-1])
* float(atimebase[0])
/ float(atimebase[1])
)
self.assertAlmostEqual(
aframes.size(0),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)
def test_compare_read_video_from_memory_and_file(self):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_memory, config)
# pass 2: decode all frames from file
tv_result_file = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_file, config)
# finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file)
def test_read_video_from_memory(self):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_memory_get_pts_only(self):
"""
Test the case when video is already in memory, and decoder reads data in memory.
Compare frame pts between decoding for pts only and full decoding
for both pts and frame data
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertAlmostEqual(config.video_fps, tv_result[3].item(), delta=0.01)
# pass 2: decode all frames to get PTS only using cpp decoder
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
1, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)
def test_read_video_in_range_from_memory(self):
"""
Test the case when video is already in memory, and decoder reads data in memory.
In addition, the decoder takes meaningful start and end PTS as input and decodes
the frames within that interval.
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]:
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
continue
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
video_start_pts = vframe_pts[start_pts_ind]
video_end_pts = vframe_pts[end_pts_ind]
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
if len(atimebase) > 0:
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(
full_path
)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil,
)
pyav_result = _decode_frames_by_av_module(
full_path,
video_start_pts_av,
video_end_pts_av,
audio_start_pts,
audio_end_pts,
)
self.assertEqual(tv_result[0].size(0), num_frames)
# Only compare the decoding results of the torchvision video reader
# and PyAV when PyAV decodes the same number of video frames;
# otherwise, skip the comparison.
if pyav_result.vframes.size(0) == num_frames:
self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
"""
Test the case when decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory(self):
"""
Test the case when decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory_script(self):
scripted_fun = torch.jit.script(io._probe_video_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
def test_read_video_from_memory_scripted(self):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
scripted_fun = torch.jit.script(io._read_video_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, _config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
seek_frame_margin,
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
[video_start_pts, video_end_pts],
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
[audio_start_pts, audio_end_pts],
audio_timebase_num,
audio_timebase_den,
)
# FUTURE: check value of video / audio frames
if __name__ == "__main__":
unittest.main()
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(test_frcnn_tracing)
find_package(Torch REQUIRED)
find_package(TorchVision REQUIRED)
# This is needed because some headers import Python.h
find_package(Python3 COMPONENTS Development)
add_executable(test_frcnn_tracing test_frcnn_tracing.cpp)
target_compile_features(test_frcnn_tracing PUBLIC cxx_range_for)
target_link_libraries(test_frcnn_tracing ${TORCH_LIBRARIES} TorchVision::TorchVision Python3::Python)
set_property(TARGET test_frcnn_tracing PROPERTY CXX_STANDARD 14)
#include <ATen/ATen.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/ROIAlign.h>
#include <torchvision/cpu/vision_cpu.h>
#include <torchvision/nms.h>
#ifdef _WIN32
// Windows only
// This is necessary until operators are automatically registered on include
static auto _nms = &nms_cpu;
#endif
int main() {
torch::DeviceType device_type;
device_type = torch::kCPU;
torch::jit::script::Module module;
try {
std::cout << "Loading model\n";
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load("fasterrcnn_resnet50_fpn.pt");
std::cout << "Model loaded\n";
} catch (const torch::Error& e) {
std::cout << "error loading the model\n";
return -1;
} catch (const std::exception& e) {
std::cout << "Other error: " << e.what() << "\n";
return -1;
}
// TorchScript models require a List[IValue] as input
std::vector<torch::jit::IValue> inputs;
// Faster RCNN accepts a List[Tensor] as main input
std::vector<torch::Tensor> images;
images.push_back(torch::rand({3, 256, 275}));
images.push_back(torch::rand({3, 256, 275}));
inputs.push_back(images);
auto output = module.forward(inputs);
std::cout << "ok\n";
std::cout << "output" << output << "\n";
if (torch::cuda::is_available()) {
// Move traced model to GPU
module.to(torch::kCUDA);
// Add GPU inputs
images.clear();
inputs.clear();
torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
images.push_back(torch::rand({3, 256, 275}, options));
images.push_back(torch::rand({3, 256, 275}, options));
inputs.push_back(images);
auto output = module.forward(inputs);
std::cout << "ok\n";
std::cout << "output" << output << "\n";
}
return 0;
}
import os.path as osp
import torch
import torchvision
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
model.eval()
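# torch.jit.script is used (rather than torch.jit.trace) so that the detection
# model's data-dependent control flow is preserved in the serialized module.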
traced_model = torch.jit.script(model)
traced_model.save("fasterrcnn_resnet50_fpn.pt")
import warnings
import os
from .extension import _HAS_OPS
from torchvision import models
from torchvision import datasets
from torchvision import ops
from torchvision import transforms
from torchvision import utils
from torchvision import io
import torch
try:
from .version import __version__ # noqa: F401
except ImportError:
pass
# Check if torchvision is being imported within the root folder
if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) ==
os.path.join(os.path.realpath(os.getcwd()), 'torchvision')):
message = ('You are importing torchvision within its own root folder ({}). '
'This is not expected to work and may give errors. Please exit the '
'torchvision project source and relaunch your python interpreter.')
warnings.warn(message.format(os.getcwd()))
_image_backend = 'PIL'
_video_backend = "pyav"
def set_image_backend(backend):
"""
Specifies the package used to load images.
Args:
backend (string): Name of the image backend. One of {'PIL', 'accimage'}.
The :mod:`accimage` package uses the Intel IPP library. It is
generally faster than PIL, but does not support as many operations.
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
.format(backend))
_image_backend = backend
def get_image_backend():
"""
Gets the name of the package used to load images
"""
return _image_backend
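# Minimal usage sketch (assuming the optional accimage package is installed):
#
#   import torchvision
#   torchvision.set_image_backend('accimage')
#   assert torchvision.get_image_backend() == 'accimage'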
def set_video_backend(backend):
"""
Specifies the package used to decode videos.
Args:
backend (string): Name of the video backend. One of {'pyav', 'video_reader'}.
The :mod:`pyav` package uses the 3rd-party PyAV library. It is a Pythonic
binding for the FFmpeg libraries.
The :mod:`video_reader` package includes a native C++ implementation on
top of the FFmpeg libraries and a Python API of TorchScript custom operators.
It generally decodes faster than :mod:`pyav`, but may be less robust.
"""
global _video_backend
if backend not in ["pyav", "video_reader"]:
raise ValueError(
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = (
"video_reader video backend is not available."
" Please compile torchvision from source and try again"
)
warnings.warn(message)
else:
_video_backend = backend
def get_video_backend():
return _video_backend
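# Minimal usage sketch (assuming torchvision was compiled with the video_reader
# extension; otherwise set_video_backend only emits a warning and the backend
# stays on 'pyav'):
#
#   import torchvision
#   torchvision.set_video_backend("video_reader")
#   print(torchvision.get_video_backend())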
def _is_tracing():
return torch._C._get_tracing_state()
#pragma once
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include "autocast.h"
#endif
// TODO: put this stuff in torchvision namespace
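// Dispatch nexus: looks up the operator registered under the
// "torchvision::deform_conv2d" schema and forwards the call, letting the
// dispatcher pick the CPU / CUDA / autograd / autocast kernel for the inputs.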
at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>();
return op.call(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
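// Autocast wrapper: excludes the Autocast dispatch key, runs the op with all
// tensor arguments cast to float32, and casts the result back to the input's
// original dtype.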
at::Tensor DeformConv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups)
.to(input.scalar_type());
}
#endif
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward)>();
return op.call(
grad,
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
}
class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable weight,
torch::autograd::Variable offset,
torch::autograd::Variable bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
auto output = deform_conv2d(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
ctx->save_for_backward({input, weight, offset, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
return {
output,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto bias = saved[3];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto grads = _deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_bias = std::get<3>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};
// TODO: There should be an easier way to do this
class DeformConv2dBackwardFunction
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable input,
torch::autograd::Variable weight,
torch::autograd::Variable offset,
torch::autograd::Variable bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_bias = std::get<3>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_bias,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};
at::Tensor DeformConv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
return std::make_tuple(result[0], result[1], result[2], result[3]);
}
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif
#include <iostream>
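// Forward entry point: routes to the CUDA/HIP implementation when the input is
// on the GPU (if compiled with GPU support), otherwise uses the CPU kernel.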
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
at::Tensor PSROIAlign_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIAlign_backward_cpu(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
}
class PSROIAlignFunction
: public torch::autograd::Function<PSROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIAlign_forward(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIAlign_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
}
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIPool_forward_cuda(
input, rois, spatial_scale, pooled_height, pooled_width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIPool_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width);
}
at::Tensor PSROIPool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIPool_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIPool_backward_cpu(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
}
class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIPool_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
auto result = PSROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
}
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif
// TODO: put this stuff in torchvision namespace
// roi_align dispatch nexus
at::Tensor roi_align(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
const double spatial_scale, // The scale of the image features. ROIs will be
// scaled to this.
const int64_t pooled_height, // The height of the pooled feature map.
const int64_t pooled_width, // The width of the pooled feature
const int64_t sampling_ratio, // The number of points to sample in each bin
const bool aligned) // The flag for pixel shift
// along each axis.
{
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::roi_align", "")
.typed<decltype(roi_align)>();
return op.call(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor ROIAlign_autocast(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return roi_align(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned)
.to(input.scalar_type());
}
#endif
at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
.typed<decltype(_roi_align_backward)>();
return op.call(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
}
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
at::AutoNonVariableTypeMode g;
auto result = roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = _roi_align_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class ROIAlignBackwardFunction
: public torch::autograd::Function<ROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
at::AutoNonVariableTypeMode g;
auto result = _roi_align_backward(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(0, "double backwards on roi_align not supported");
}
};
at::Tensor ROIAlign_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
return ROIAlignFunction::apply(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned)[0];
}
at::Tensor ROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
return ROIAlignBackwardFunction::apply(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned)[0];
}
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIPool_forward_cuda(
input, rois, spatial_scale, pooled_height, pooled_width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return ROIPool_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width);
}
at::Tensor ROIPool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIPool_backward_cuda(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return ROIPool_backward_cpu(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
}
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
auto result = ROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIPool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
std::tuple<at::Tensor, at::Tensor> roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
}
#pragma once
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <cmath>
#include <iostream>
#include <tuple>
const int kMaxParallelImgs = 32;
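// Bilinearly interpolates the 2D map `in` (height x width) at the fractional
// location (h, w). Samples falling entirely outside the map return 0; corner
// values are gathered with boundary checks and blended with weights given by
// the fractional distances.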
template <typename scalar_t>
static scalar_t bilinear_interpolate(
const scalar_t* in,
const int height,
const int width,
scalar_t h,
scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = in[h_low * width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = in[h_low * width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = in[h_high * width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = in[h_high * width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
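// im2col for deformable convolution: for each output location and input
// channel, gathers the input values at kernel sampling points shifted by the
// learned (offset_h, offset_w) offsets, using bilinear interpolation, and
// writes them into the column buffer consumed by the grouped matrix multiply.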
template <typename scalar_t>
static void deformable_im2col_kernel(
const int n,
const scalar_t* input,
const scalar_t* offset,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dil_h,
const int dil_w,
const int batch_sz,
const int n_in_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* columns) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int out_b = (index / (out_w * out_h)) % batch_sz;
const int in_c = index / (out_w * out_h * batch_sz);
const int out_c = in_c * weight_h * weight_w;
int c_per_offset_grp = n_in_channels / n_offset_grps;
const int grp_idx = in_c / c_per_offset_grp;
auto columns_ptr = columns +
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
out_y * out_w + out_x);
auto input_ptr = input +
(out_b * (n_in_channels * height * width) + in_c * (height * width));
auto offset_ptr = offset +
(out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
out_w;
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
}
}
static void deformable_im2col(
const at::Tensor input,
const at::Tensor data_offset,
int n_in_channels,
int height,
int width,
int weight_h,
int weight_w,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dil_h,
int dil_w,
int out_h,
int out_w,
int parallel_imgs,
int deformable_group,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deformable_im2col", ([&] {
deformable_im2col_kernel(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
parallel_imgs,
n_in_channels,
deformable_group,
out_h,
out_w,
data_col.data_ptr<scalar_t>());
}));
}
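// Returns the largest divisor of n that does not exceed bound; used to split
// the batch into blocks of at most kMaxParallelImgs images processed together.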
static int get_greatest_divisor_below_bound(int n, int bound) {
for (int k = bound; k > 1; --k) {
if (n % k == 0) {
return k;
}
}
return 1;
}
at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
// Unpack shapes and args
int out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1;
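// Standard convolution output size:
// out = (in + 2 * pad - effective_kernel) / stride + 1,
// where effective_kernel = dilation * (kernel - 1) + 1.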
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
TORCH_CHECK(
weight_h > 0 && weight_w > 0,
"weight_h: ",
weight_h,
" weight_w: ",
weight_w);
TORCH_CHECK(
stride_h > 0 && stride_w > 0,
"stride_h: ",
stride_h,
" stride_w: ",
stride_w);
TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w);
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"offset.shape[1] is not valid: got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
offset.size(2),
", ",
offset.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
out_h,
" out_w: ",
out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
if (batch_sz == 0) {
return out;
}
// Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
out_channels,
out_h,
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs * out_h,
out_w},
out.options());
// Separate channels into convolution groups
out_buf = out_buf.view({out_buf.size(0),
n_weight_grps,
out_buf.size(1) / n_weight_grps,
out_buf.size(2),
out_buf.size(3)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
// Sample points and perform convolution
auto columns = at::zeros(
{n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
input[b],
offset[b],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
out_buf[b][g] = out_buf[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
out_buf = out_buf.view({batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs,
out_h,
out_w});
out_buf.transpose_(1, 2);
out.copy_(out_buf);
out = out.view({batch_sz, out_channels, out_h, out_w});
return out + bias.view({1, out_channels, 1, 1});
}
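// col2im for deformable convolution: scatters the gradient held in the column
// buffer back onto the input gradient, distributing each value over the (up
// to) four integer pixels surrounding its fractional sampling location with
// the matching bilinear weights.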
template <typename scalar_t>
static void deformable_col2im_kernel(
const int n,
const scalar_t* col,
const scalar_t* offset,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_im) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int b = (index / (out_w * out_h)) % batch_sz;
const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
int c_per_offset_grp = channels / n_offset_grps;
const int offset_grp = c / c_per_offset_grp;
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h *
out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
int yp = int(y) + dy;
int xp = int(x) + dx;
if (0 <= yp && yp < height && 0 <= xp && xp < width &&
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
grad_im[grad_pos] += weight * col[index];
}
}
}
}
}
static void compute_grad_input(
const at::Tensor columns,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im", ([&] {
deformable_col2im_kernel(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
n_offset_grps,
out_h,
out_w,
grad_im.data_ptr<scalar_t>());
}));
}
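// Partial derivative of the bilinear sample at (y, x) with respect to the y
// coordinate (is_y_direction == true) or the x coordinate, computed from the
// four neighboring pixel values.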
template <typename scalar_t>
static scalar_t get_coordinate_weight(
const scalar_t* im_data,
const int height,
const int width,
scalar_t y,
scalar_t x,
bool is_y_direction) {
int y_l = floor(y);
int x_l = floor(x);
int y_h = y_l + 1;
int x_h = x_l + 1;
bool valid_y_l = 0 <= y_l && y_l < height;
bool valid_y_h = 0 <= y_h && y_h < height;
bool valid_x_l = 0 <= x_l && x_l < width;
bool valid_x_h = 0 <= x_h && x_h < width;
scalar_t zero = 0;
scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;
if (is_y_direction) {
scalar_t dx = x - x_l;
return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
} else {
scalar_t dy = y - y_l;
return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
}
}
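// Accumulates the gradient w.r.t. the sampling offsets: for every offset
// channel, sums the column gradients weighted by the derivative of the
// bilinear interpolation w.r.t. that coordinate.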
template <typename scalar_t>
static void deformable_col2im_coord_kernel(
const int n,
const scalar_t* col,
const scalar_t* im,
const scalar_t* offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int offset_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_offset) {
for (int index = 0; index != n; ++index) {
scalar_t val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
const int offset_grp = c / (2 * weight_h * weight_w);
const int col_step = weight_h * weight_w;
int c_per_offset_grp = channels / n_offset_grps;
auto col_ptr = col +
offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w *
out_h;
auto im_ptr = im +
(b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h *
out_w;
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const bool is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
int out_x = col_pos % out_w;
int out_y = (col_pos / out_w) % out_h;
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int offset_h_idx =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
const int offset_w_idx =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_idx];
const scalar_t offset_w = offset_ptr[offset_w_idx];
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
im_ptr += height * width;
}
grad_offset[index] = val;
}
}
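// Dispatches deformable_col2im_coord_kernel over the supported floating point
// types to fill grad_offset.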
static void compute_grad_offset(
const at::Tensor columns,
const at::Tensor input,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_offset) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_coord", ([&] {
deformable_col2im_coord_kernel(
num_kernels,
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
2 * weight_h * weight_w * n_offset_grps,
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
}));
}
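// Backward pass w.r.t. the input and the offsets: rebuilds the column buffer
// as weight^T @ grad_out for each weight group, then runs the col2im kernels
// above to produce grad_input and grad_offset.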
static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
}
auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
// Separate into blocks
grad_input = grad_input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_out = grad_out
.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w})
.permute({0, 2, 3, 1, 4, 5});
weight = weight.reshape({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
columns.zero_();
// Separate into weight groups
for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
compute_grad_offset(
columns,
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
compute_grad_input(
columns,
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_input[elt]);
}
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
}
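// Backward pass w.r.t. the convolution weights: recomputes the deformable
// im2col columns for each block of images and accumulates
// grad_out @ columns^T into grad_weight, one weight group at a time.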
static at::Tensor deform_conv2d_backward_parameters_cpu(
at::Tensor input,
const at::Tensor& weight,
at::Tensor offset,
const at::Tensor& grad_out,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
long out_h = grad_out.size(2);
long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight);
if (batch_sz == 0) {
return grad_weight;
}
at::Tensor grad_out_buf = grad_out
.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w})
.permute({0, 2, 3, 1, 4, 5})
.contiguous();
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
auto columns = at::empty(
{n_weight_grps,
n_in_channels * weight_w * weight_h / n_weight_grps,
n_parallel_imgs * out_h * out_w},
input.options());
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col(
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]);
}
}
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
grad_weight.size(4)});
return grad_weight;
}
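// CPU entry point for the deformable convolution backward pass. Returns
// gradients for input, weight, offset and bias (the bias gradient is simply
// grad_out summed over the batch and spatial dimensions).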
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv2d_backward_input_cpu(
input,
weight,
offset,
grad_out,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_weight = deform_conv2d_backward_parameters_cpu(
input,
weight,
offset,
grad_out,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3});
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
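// Bilinear interpolation of a single (y, x) location in an H x W feature map;
// out-of-range coordinates return 0.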
template <typename T>
T bilinear_interpolate(
const T* input,
const int height,
const int width,
T y,
T x,
const int index /* index for debug only*/) {
// deal with cases where the sampled point is outside the feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
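// Position-sensitive ROI Align forward pass: each pooled bin reads from its
// own input channel (recorded in channel_mapping) and averages bilinearly
// sampled values over a roi_bin_grid_h x roi_bin_grid_w grid.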
template <typename T>
void PSROIAlignForwardCPU(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* rois,
const int channels_out,
T* output,
int* channel_mapping) {
int num_rois = nthreads / channels_out / pooled_width / pooled_height;
for (int n = 0; n < num_rois; n++) {
// [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int c_in = 0;
for (int c_out = 0; c_out < channels_out; ++c_out) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(
offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
c_in++;
}
}
}
}
}
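// Computes the four bilinear weights and corner indices needed to distribute
// a gradient at (y, x); corners are set to -1 when the point is out of range.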
template <typename T>
void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
// deal with cases where the sampled point is outside the feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
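// Position-sensitive ROI Align backward pass: routes each output gradient back
// to the input channel recorded in channel_mapping and splats it onto the four
// bilinear corners of every sampling point.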
template <typename T>
void PSROIAlignBackwardCPU(
const int nthreads,
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int channels_out,
T* grad_input,
const T* rois) {
for (int index = 0; index < nthreads; index++) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / channels_out;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
// ROI width/height are used as-is here (no clamping), matching the forward pass
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
add(grad_input_offset + y_low * width + x_low, g1);
add(grad_input_offset + y_low * width + x_high, g2);
add(grad_input_offset + y_high * width + x_low, g3);
add(grad_input_offset + y_high * width + x_high, g4);
} // if
} // ix
} // iy
}
}
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
// Check if input tensors are CPU tensors
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(
rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIAlign_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
TORCH_CHECK(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIAlign_forward", [&] {
PSROIAlignForwardCPU<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois_.data_ptr<scalar_t>(),
channels_out,
output.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>());
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(
channel_mapping.device().is_cpu(),
"channel_mapping must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIAlign_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIAlign_backward", [&] {
PSROIAlignBackwardCPU<scalar_t>(
grad.numel(),
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
channels_out,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>());
});
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <algorithm>
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
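// Position-sensitive ROI Pool forward pass: averages the input over each
// pooling bin, reading from a different input channel per bin and recording
// that channel in channel_mapping.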
template <typename T>
void PSROIPoolForward(
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
const int channels_out,
const int num_rois,
T* output,
int* channel_mapping) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w, 1);
int roi_height = std::max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int c_in = 0;
for (int c_out = 0; c_out < channels_out; ++c_out) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend =
static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend =
static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1);
hend = std::min(std::max(hend + roi_start_h, 0), height - 1);
wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1);
wend = std::min(std::max(wend + roi_start_w, 0), width - 1);
bool is_empty = (hend <= hstart) || (wend <= wstart);
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
c_in++;
}
}
}
}
}
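// Position-sensitive ROI Pool backward pass: spreads each output gradient
// uniformly over its pooling bin in the input channel given by channel_mapping.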
template <typename T>
void PSROIPoolBackward(
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int channels_out,
T* grad_input,
const T* rois) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = roundf(offset_rois[1] * spatial_scale);
int roi_start_h = roundf(offset_rois[2] * spatial_scale);
int roi_end_w = roundf(offset_rois[3] * spatial_scale);
int roi_end_h = roundf(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w, 1);
int roi_height = std::max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int c_out = 0; c_out < channels_out; ++c_out) {
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val =
is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
add(grad_input_offset + grad_input_index, diff_val);
}
}
}
}
}
}
}
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(
rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIPool_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
TORCH_CHECK(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIPool_forward", [&] {
PSROIPoolForward<scalar_t>(
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois_.data_ptr<scalar_t>(),
channels_out,
num_rois,
output.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>());
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor PSROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(
channel_mapping.device().is_cpu(),
"channel_mapping must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIPool_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIPool_backward", [&] {
PSROIPoolBackward<scalar_t>(
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
channels_out,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>());
});
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <algorithm>
#include <vector>
#include "vision_cpu.h"
// implementation taken from Caffe2
template <typename T>
struct PreCalc {
int pos1;
int pos2;
int pos3;
int pos4;
T w1;
T w2;
T w3;
T w4;
};
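// Precomputes bilinear interpolation weights and corner indices for every
// sampling point of an ROI; the results are shared across all channels in
// ROIAlignForward.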
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
const T yy = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < ix_upper; ix++) {
const T xx = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T x = xx;
T y = yy;
// deal with sampling points that fall outside the feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
PreCalc<T> pc;
pc.pos1 = 0;
pc.pos2 = 0;
pc.pos3 = 0;
pc.pos4 = 0;
pc.w1 = 0;
pc.w2 = 0;
pc.w3 = 0;
pc.w4 = 0;
pre_calc[pre_calc_index] = pc;
pre_calc_index += 1;
continue;
}
if (y <= 0) {
y = 0;
}
if (x <= 0) {
x = 0;
}
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
// save weights and indices
PreCalc<T> pc;
pc.pos1 = y_low * width + x_low;
pc.pos2 = y_low * width + x_high;
pc.pos3 = y_high * width + x_low;
pc.pos4 = y_high * width + x_high;
pc.w1 = w1;
pc.w2 = w2;
pc.w3 = w3;
pc.w4 = w4;
pre_calc[pre_calc_index] = pc;
pre_calc_index += 1;
}
}
}
}
}
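// ROI Align forward pass: average pooling over bilinearly sampled points,
// reusing the precomputed weights above instead of recomputing them for every
// channel.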
template <typename T>
void ROIAlignForward(
const int nthreads,
const T* input,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp
// #pragma omp parallel for num_threads(32)
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = std::max(roi_width, (T)1.);
roi_height = std::max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
// we want to precalculate indices and weights shared by all channels;
// this is the key point of the optimization
std::vector<PreCalc<T>> pre_calc(
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
pre_calc_for_bilinear_interpolate(
height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_start_h,
roi_start_w,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
pre_calc);
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_input[pc.pos1] +
pc.w2 * offset_input[pc.pos2] +
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
pre_calc_index += 1;
}
}
output_val /= count;
output[index] = output_val;
} // for pw
} // for ph
} // for c
} // for n
}
template <typename T>
void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
// deal with cases where the sampled point is outside the feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
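// ROI Align backward pass: recomputes the sampling grid for each output
// element and distributes its gradient to the four bilinear corners in
// grad_input.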
template <typename T>
void ROIAlignBackward(
const int nthreads,
const T* grad_output,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
for (int index = 0; index < nthreads; index++) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = std::max(roi_width, (T)1.);
roi_height = std::max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
int output_offset = n * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
// atomic add is not needed for now since it is single threaded
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // for
} // ROIAlignBackward
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0)
return output;
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ROIAlign_forward", [&] {
ROIAlignForward<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
aligned,
rois_.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
return output;
}
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
auto rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ROIAlign_forward", [&] {
ROIAlignBackward<scalar_t>(
grad.numel(),
grad.data_ptr<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
aligned,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <algorithm>
#include <cfloat> // FLT_MAX
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
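// ROI (max) Pool forward pass: records the argmax index of every pooling bin
// so the backward pass can route gradients.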
template <typename T>
void RoIPoolForward(
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
const int num_rois,
T* output,
int* argmax_data) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int c = 0; c < channels; ++c) {
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
int maxidx = -1;
const T* input_offset =
input + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
if (input_offset[input_index] > maxval) {
maxval = input_offset[input_index];
maxidx = input_index;
}
}
}
int index =
((n * channels + c) * pooled_height + ph) * pooled_width + pw;
output[index] = maxval;
argmax_data[index] = maxidx;
} // channels
} // pooled_width
} // pooled_height
} // num_rois
}
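// ROI (max) Pool backward pass: adds each output gradient to the input element
// selected by argmax_data (no gradient is propagated where argmax == -1).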
template <typename T>
void RoIPoolBackward(
const T* grad_output,
const int* argmax_data,
const int num_rois,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
for (int c = 0; c < channels; ++c) {
T* grad_input_offset =
grad_input + ((roi_batch_ind * channels + c) * height * width);
const int* argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int output_offset = n * n_stride + c * c_stride;
int argmax = argmax_data_offset[ph * pooled_width + pw];
if (argmax != -1) {
add(grad_input_offset + argmax,
static_cast<T>(
grad_output
[output_offset + ph * h_stride + pw * w_stride]));
}
} // pooled_width
} // pooled_height
} // channels
} // num_rois
}
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIPool_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros(
{num_rois, channels, pooled_height, pooled_width},
input.options().dtype(at::kInt));
if (output.numel() == 0) {
return std::make_tuple(output, argmax);
}
auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ROIPool_forward", [&] {
RoIPoolForward<scalar_t>(
input_.data_ptr<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois_.data_ptr<scalar_t>(),
num_rois,
output.data_ptr<scalar_t>(),
argmax.data_ptr<int>());
});
return std::make_tuple(output, argmax);
}
at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor");
TORCH_CHECK(
rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIPool_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
auto rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ROIPool_backward", [&] {
RoIPoolBackward<scalar_t>(
grad.data_ptr<scalar_t>(),
argmax.data_ptr<int>(),
num_rois,
channels,
height,
width,
pooled_height,
pooled_width,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
return grad_input;
}