Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
import torch
from torchvision.transforms import Compose
import unittest
import random
import numpy as np
import warnings
from _assert_utils import assert_equal
import numpy as np
import pytest
import torch
from common_utils import assert_equal
from torchvision.transforms import Compose
try:
from scipy import stats
@@ -17,21 +18,22 @@ with warnings.catch_warnings(record=True):
import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms(unittest.TestCase):
class TestVideoTransforms:
def test_random_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([
transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
result = Compose(
[
transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)),
]
)(clip)
assert result.size(2) == oheight
assert result.size(3) == owidth
transforms.RandomCropVideo((oheight, owidth)).__repr__()
@@ -39,15 +41,17 @@ class TestVideoTransforms(unittest.TestCase):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([
transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)),
])(clip)
self.assertEqual(result.size(2), oheight)
self.assertEqual(result.size(3), owidth)
result = Compose(
[
transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)),
]
)(clip)
assert result.size(2) == oheight
assert result.size(3) == owidth
transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
@@ -55,67 +59,77 @@ class TestVideoTransforms(unittest.TestCase):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2
clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :]
clipNarrow.fill_(0)
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg)
result = Compose(
[
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
msg = (
"height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert result.sum().item() == 0, msg
oheight += 1
owidth += 1
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
result = Compose(
[
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum1 = result.sum()
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg)
msg = (
"height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert sum1.item() > 1, msg
oheight += 1
owidth += 1
result = Compose([
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
])(clip)
result = Compose(
[
transforms.ToTensorVideo(),
transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum2 = result.sum()
msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg)
self.assertTrue(sum2.item() > sum1.item(), msg)
msg = (
"height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert sum2.item() > 1, msg
assert sum2.item() > sum1.item(), msg
@unittest.skipIf(stats is None, 'scipy.stats is not available')
def test_normalize_video(self):
@pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
@pytest.mark.parametrize("channels", [1, 3])
def test_normalize_video(self, channels):
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
self.assertTrue(samples_from_standard_normal(normalized))
numFrames = random.randint(4, 128)
height = random.randint(32, 256)
width = random.randint(32, 256)
mean = random.random()
std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip)
assert samples_from_standard_normal(normalized)
random.setstate(random_state)
# Checking the optional in-place behaviour
@@ -129,49 +143,36 @@ class TestVideoTransforms(unittest.TestCase):
numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo()
with self.assertRaises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist())
with pytest.raises(TypeError):
np_rng = np.random.RandomState(0)
trans(np_rng.rand(numFrames, height, width, 1).tolist())
with pytest.raises(TypeError):
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((height, width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((3), dtype=torch.uint8))
trans.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip_video(self):
random_state = random.getstate()
random.seed(42)
@pytest.mark.parametrize("p", (0, 1))
def test_random_horizontal_flip_video(self, p):
clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
hclip = clip.flip((-1))
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlipVideo()(clip)
if torch.all(torch.eq(out, hclip)):
num_horizontal += 1
hclip = clip.flip(-1)
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
num_samples = 250
num_horizontal = 0
for _ in range(num_samples):
out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
if torch.all(torch.eq(out, hclip)):
num_horizontal += 1
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
if p == 0:
torch.testing.assert_close(out, clip)
elif p == 1:
torch.testing.assert_close(out, hclip)
transforms.RandomHorizontalFlipVideo().__repr__()
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
pytest.main([__file__])
from copy import deepcopy
import pytest
import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image
from torchvision import tv_tensors
@pytest.fixture(autouse=True)
def restore_tensor_return_type():
# This is for security, as we should already be restoring the default manually in each test anyway
# (at least at the time of writing...)
yield
tv_tensors.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
image = tv_tensors.Image(data)
assert isinstance(image, torch.Tensor)
assert image.ndim == 3 and image.shape[0] == 3
@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
mask = tv_tensors.Mask(data)
assert isinstance(mask, torch.Tensor)
assert mask.ndim == 3 and mask.shape[0] == 1
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format
def test_bbox_dim_error():
data_3d = [[[1, 2, 3, 4]]]
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
([[[0.0, 1.0], [0.0, 1.0]]], None, False),
([[[0.0, 1.0], [0.0, 1.0]]], False, False),
([[[0.0, 1.0], [0.0, 1.0]]], True, True),
(torch.rand(3, 16, 16, requires_grad=False), None, False),
(torch.rand(3, 16, 16, requires_grad=False), False, False),
(torch.rand(3, 16, 16, requires_grad=False), True, True),
(torch.rand(3, 16, 16, requires_grad=True), None, True),
(torch.rand(3, 16, 16, requires_grad=True), False, False),
(torch.rand(3, 16, 16, requires_grad=True), True, True),
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
assert tv_tensor.requires_grad is expected_requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_isinstance(make_input):
assert isinstance(make_input(), torch.Tensor)
def test_wrapping_no_copy():
tensor = torch.rand(3, 16, 16)
image = tv_tensors.Image(tensor)
assert image.data_ptr() == tensor.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
dp = make_input()
dp_to = dp.to(torch.float64)
assert type(dp_to) is type(dp)
assert dp_to.dtype is torch.float64
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_to_tv_tensor_reference(make_input, return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
dp = make_input()
with tv_tensors.set_return_type(return_type):
tensor_to = tensor.to(dp)
assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_clone_wrapping(make_input, return_type):
dp = make_input()
with tv_tensors.set_return_type(return_type):
dp_clone = dp.clone()
assert type(dp_clone) is type(dp)
assert dp_clone.data_ptr() != dp.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_requires_grad__wrapping(make_input, return_type):
dp = make_input(dtype=torch.float)
assert not dp.requires_grad
with tv_tensors.set_return_type(return_type):
dp_requires_grad = dp.requires_grad_(True)
assert type(dp_requires_grad) is type(dp)
assert dp.requires_grad
assert dp_requires_grad.requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_detach_wrapping(make_input, return_type):
dp = make_input(dtype=torch.float).requires_grad_(True)
with tv_tensors.set_return_type(return_type):
dp_detached = dp.detach()
assert type(dp_detached) is type(dp)
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
# Largely the same as above, we additionally check that the metadata is preserved
format, canvas_size = "XYXY", (32, 32)
bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
tv_tensors.set_return_type(return_type)
bbox = bbox.clone()
if return_type == "TVTensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.to(torch.float64)
if return_type == "TVTensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.detach()
if return_type == "TVTensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert not bbox.requires_grad
bbox.requires_grad_(True)
if return_type == "TVTensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad
tv_tensors.set_return_type("tensor")
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_other_op_no_wrapping(make_input, return_type):
dp = make_input()
with tv_tensors.set_return_type(return_type):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
"op",
[
lambda t: t.numpy(),
lambda t: t.tolist(),
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
dp = make_input()
output = op(dp)
assert type(output) is not type(dp)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
dp = make_input()
original_type = type(dp)
with tv_tensors.set_return_type(return_type):
output = dp.add_(0)
assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
assert type(dp) is original_type
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_wrap(make_input):
dp = make_input()
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
dp_new = tv_tensors.wrap(output, like=dp)
assert type(dp_new) is type(dp)
assert dp_new.data_ptr() == output.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
dp = make_input(dtype=torch.float)
dp.requires_grad_(requires_grad)
dp_deepcopied = deepcopy(dp)
assert dp_deepcopied is not dp
assert dp_deepcopied.data_ptr() != dp.data_ptr()
assert_equal(dp_deepcopied, dp)
assert type(dp_deepcopied) is type(dp)
assert dp_deepcopied.requires_grad is requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
@pytest.mark.parametrize(
"op",
(
lambda dp: dp + torch.rand(*dp.shape),
lambda dp: torch.rand(*dp.shape) + dp,
lambda dp: dp * torch.rand(*dp.shape),
lambda dp: torch.rand(*dp.shape) * dp,
lambda dp: dp + 3,
lambda dp: 3 + dp,
lambda dp: dp + dp,
lambda dp: dp.sum(),
lambda dp: dp.reshape(-1),
lambda dp: dp.int(),
lambda dp: torch.stack([dp, dp]),
lambda dp: torch.chunk(dp, 2)[0],
lambda dp: torch.unbind(dp)[0],
),
)
def test_usual_operations(make_input, return_type, op):
dp = make_input()
with tv_tensors.set_return_type(return_type):
out = op(dp)
assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")
def test_subclasses():
img = make_image()
masks = make_segmentation_mask()
with pytest.raises(TypeError, match="unsupported operand"):
img + masks
def test_set_return_type():
img = make_image()
assert type(img + 3) is torch.Tensor
with tv_tensors.set_return_type("TVTensor"):
assert type(img + 3) is tv_tensors.Image
assert type(img + 3) is torch.Tensor
tv_tensors.set_return_type("TVTensor")
assert type(img + 3) is tv_tensors.Image
with tv_tensors.set_return_type("tensor"):
assert type(img + 3) is torch.Tensor
with tv_tensors.set_return_type("TVTensor"):
assert type(img + 3) is tv_tensors.Image
tv_tensors.set_return_type("tensor")
assert type(img + 3) is torch.Tensor
assert type(img + 3) is torch.Tensor
# Exiting a context manager will restore the return type as it was prior to entering it,
# regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
assert type(img + 3) is tv_tensors.Image
tv_tensors.set_return_type("tensor")
def test_return_type_input():
img = make_image()
# Case-insensitive
with tv_tensors.set_return_type("tvtensor"):
assert type(img + 3) is tv_tensors.Image
with pytest.raises(ValueError, match="return_type must be"):
tv_tensors.set_return_type("typo")
tv_tensors.set_return_type("tensor")
import pytest
import numpy as np
import os
import re
import sys
import tempfile
from io import BytesIO
import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
from _assert_utils import assert_equal
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
def test_make_grid_not_inplace():
@@ -23,13 +26,13 @@ def test_make_grid_not_inplace():
t_clone = t.clone()
utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
def test_normalize_in_make_grid():
@@ -43,51 +46,51 @@ def test_normalize_in_make_grid():
# Rounding the result to one decimal for comparison
n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_max = torch.round(grid_max * 10**n_digits) / (10**n_digits)
rounded_grid_min = torch.round(grid_min * 10**n_digits) / (10**n_digits)
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1")
assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The image is not present after save'
assert os.path.exists(f.name), "The image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save'
assert os.path.exists(f.name), "The pixel image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f:
with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name)
img_orig = Image.open(f.name)
fp = BytesIO()
utils.save_image(t, fp, format='png')
utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
def test_draw_boxes():
@@ -103,7 +106,7 @@ def test_draw_boxes():
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
if PILLOW_VERSION >= (8, 2):
if PILLOW_VERSION >= (10, 1):
# The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
@@ -113,11 +116,37 @@ def test_draw_boxes():
assert_equal(img, img_cp)
@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes, fill=fill)
assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8
img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_bounding_boxes(img_float, boxes, fill=fill)
assert img_float is not out_float
assert out_float.is_floating_point()
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors)
with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
utils.draw_bounding_boxes(image=img, boxes=boxes, colors=[])
def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
@@ -131,39 +160,75 @@ def test_draw_boxes_vanilla():
assert_equal(img, img_cp)
def test_draw_boxes_grayscale():
img = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
bboxed_img = utils.draw_bounding_boxes(image=img, boxes=boxes, colors=["#1BBC9B"])
assert bboxed_img.size(0) == 3
def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
boxes_wrong = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)
labels_wrong = ["one", "two"]
colors_wrong = ["pink", "blue"]
with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
utils.draw_bounding_boxes(img_wrong1, boxes)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
with pytest.raises(ValueError, match="Number of boxes"):
utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
with pytest.raises(ValueError, match="Number of colors"):
utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
with pytest.raises(ValueError, match="Boxes need to be in"):
utils.draw_bounding_boxes(img_correct, boxes_wrong)
def test_draw_boxes_warning():
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
with pytest.warns(UserWarning, match=re.escape("Argument 'font_size' will be ignored since 'font' is not set.")):
utils.draw_bounding_boxes(img, boxes, font_size=11)
@pytest.mark.parametrize('colors', [
None,
['red', 'blue'],
['#FF00FF', (1, 34, 122)],
])
@pytest.mark.parametrize('alpha', (0, .5, .7, 1))
def test_draw_segmentation_masks(colors, alpha):
def test_draw_no_boxes():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
boxes = torch.full((0, 4), 0, dtype=torch.float)
with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")):
res = utils.draw_bounding_boxes(img, boxes)
# Check that the function didn't change the image
assert res.eq(img).all()
@pytest.mark.parametrize(
"colors",
[
None,
"blue",
"#FF00FF",
(1, 34, 122),
["red", "blue"],
["#FF00FF", (1, 34, 122)],
],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
"""This test makes sure that masks draw their corresponding color where they should"""
num_masks, h, w = 2, 100, 100
dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
masks[0, 10:20, 10:20] = True
masks[1, 15:25, 15:25] = True
# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when
# masks overlap, but this makes testing slightly harder so we don't really
# care
overlap = masks[0] & masks[1]
masks[:, overlap] = False
out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype
@@ -175,27 +240,53 @@ def test_draw_segmentation_masks(colors, alpha):
if colors is None:
colors = utils._generate_color_palette(num_masks)
elif isinstance(colors, str) or isinstance(colors, tuple):
colors = [colors]
# Make sure each mask draws with its own color
for mask, color in zip(masks, colors):
if isinstance(color, str):
color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=dtype)
color = torch.tensor(color, dtype=dtype, device=device)
if alpha == 1:
assert (out[:, mask] == color[:, None]).all()
assert (out[:, mask & ~overlap] == color[:, None]).all()
elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all()
assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)
interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_dtypes():
num_masks, h, w = 2, 100, 100
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)
assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8
img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_segmentation_masks(img_float, masks)
assert img_float is not out_float
assert out_float.is_floating_point()
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
def test_draw_segmentation_masks_errors():
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
h, w = 10, 10
masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)
with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
@@ -217,15 +308,236 @@ def test_draw_segmentation_masks_errors():
with pytest.raises(ValueError, match="must have the same height and width"):
masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
with pytest.raises(ValueError, match="There are more masks"):
with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
bad_colors = np.array(['red', 'blue']) # should be a list
with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
bad_colors = np.array(["red", "blue"]) # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
bad_colors = ('red', 'blue') # should be a list
with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):
bad_colors = ("red", "blue") # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_no_segmention_mask(device):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8, device=device)
masks = torch.full((0, 100, 100), 0, dtype=torch.bool, device=device)
with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
res = utils.draw_segmentation_masks(img, masks)
# Check that the function didn't change the image
assert res.eq(img).all()
def test_draw_keypoints_vanilla():
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(
img,
keypoints,
colors="red",
connectivity=[
(0, 1),
],
)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check that keypoints are not modified inplace
assert_equal(keypoints, keypoints_cp)
# Check that image is not modified in place
assert_equal(img, img_cp)
def test_draw_keypoins_K_equals_one():
# Non-regression test for https://github.com/pytorch/vision/pull/8439
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
keypoints = torch.tensor([[[10, 10]]], dtype=torch.float)
utils.draw_keypoints(img, keypoints)
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(
img,
keypoints,
colors=colors,
connectivity=[
(0, 1),
],
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
"vis",
[
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
],
)
def test_draw_keypoints_visibility(connectivity, vis):
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
vis_cp = vis if vis is None else vis.clone()
result = utils.draw_keypoints(
image=img,
keypoints=keypoints,
connectivity=connectivity,
colors="red",
visibility=vis,
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
# compare with a fakedata image
# connect the key points 0 to 1 for both skeletons and do not show the other key points
path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png"
)
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
if vis_cp is None:
assert vis is None
else:
assert_equal(vis, vis_cp)
assert vis.dtype == vis_cp.dtype
def test_draw_keypoints_visibility_default():
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(
image=img,
keypoints=keypoints,
connectivity=[(0, 1)],
colors="red",
visibility=None,
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
# compare against fakedata image, which connects 0->1 for both key-point skeletons
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
def test_draw_keypoints_dtypes():
image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
image_float = to_dtype(image_uint8, torch.float, scale=True)
out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
out_float = utils.draw_keypoints(image_float, keypoints)
assert out_uint8.dtype == torch.uint8
assert out_uint8 is not image_uint8
assert out_float.is_floating_point()
assert out_float is not image_float
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)
with pytest.raises(ValueError, match="The image dtype must be"):
img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64)
utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=batch, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass an RGB image"):
one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=one_channel, keypoints=keypoints)
with pytest.raises(ValueError, match="keypoints must be of shape"):
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
one_dim_visibility = torch.tensor([True, True, True], dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=one_dim_visibility)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
three_dim_visibility = torch.ones((2, 3, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=three_dim_visibility)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_n = torch.ones((3, 3), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_n)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_k = torch.ones((2, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_k)
@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
h, w = 100, 100
flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
flow = torch.stack(flow[::-1], dim=0).float()
flow[0] -= h / 2
flow[1] -= w / 2
if batch:
flow = torch.stack([flow, flow])
img = utils.flow_to_image(flow)
assert img.shape == (2, 3, h, w) if batch else (3, h, w)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
expected_img = torch.load(path, map_location="cpu", weights_only=True)
if batch:
expected_img = torch.stack([expected_img, expected_img])
assert_equal(expected_img, img)
@pytest.mark.parametrize(
"input_flow, match",
(
(torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
(torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
),
)
def test_flow_to_image_errors(input_flow, match):
with pytest.raises(ValueError, match=match):
utils.flow_to_image(flow=input_flow)
if __name__ == "__main__":
pytest.main([__file__])
import math
import os
import pytest
import torch
import torchvision
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
try:
import av
except ImportError:
av = None
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
@pytest.mark.skipif(_HAS_GPU_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize(
"video_file",
[
"RATRACE_wave_f_nm_np1_fr_goo_37.avi",
"TrumanShow_wave_f_nm_np1_fr_med_26.avi",
"v_SoccerJuggling_g23_c01.avi",
"v_SoccerJuggling_g24_c01.avi",
"R6llTwEh07w.mp4",
"SOX5yA1l24A.mp4",
"WUzgd7C1pWA.mp4",
],
)
def test_frame_reading(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
vision_frames = next(decoder)["data"]
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("keyframes", [True, False])
@pytest.mark.parametrize(
"full_path, duration",
[
(os.path.join(VIDEO_DIR, x), y)
for x, y in [
("v_SoccerJuggling_g23_c01.avi", 8.0),
("v_SoccerJuggling_g24_c01.avi", 8.0),
("R6llTwEh07w.mp4", 10.0),
("SOX5yA1l24A.mp4", 11.0),
("WUzgd7C1pWA.mp4", 11.0),
]
],
)
def test_seek_reading(self, keyframes, full_path, duration):
torchvision.set_video_backend("cuda")
decoder = VideoReader(full_path)
time = duration / 2
decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container:
container.seek(int(time * 1000000), any_frame=not keyframes, backward=False)
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
vision_frames = next(decoder)["data"]
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize(
"video_file",
[
"RATRACE_wave_f_nm_np1_fr_goo_37.avi",
"TrumanShow_wave_f_nm_np1_fr_med_26.avi",
"v_SoccerJuggling_g23_c01.avi",
"v_SoccerJuggling_g24_c01.avi",
"R6llTwEh07w.mp4",
"SOX5yA1l24A.mp4",
"WUzgd7C1pWA.mp4",
],
)
def test_metadata(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container:
video = container.streams.video[0]
av_duration = float(video.duration * video.time_base)
assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2)
assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2)
if __name__ == "__main__":
pytest.main([__file__])
import collections
import math
import os
import time
import unittest
from fractions import Fraction
import numpy as np
import pytest
import torch
import torchvision.io as io
from common_utils import assert_equal
from numpy.random import randint
from pytest import approx
from torchvision import set_video_backend
from torchvision.io import _HAS_VIDEO_OPT
from common_utils import PY39_SKIP
from _assert_utils import assert_equal
try:
@@ -23,9 +23,6 @@ except ImportError:
av = None
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
@@ -110,18 +107,14 @@ test_videos = {
}
DecoderResult = collections.namedtuple(
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase")
"""av_seek_frame is imprecise so seek to a timestamp earlier by a margin
The unit of margin is second"""
seek_frame_margin = 0.25
# av_seek_frame is imprecise so seek to a timestamp earlier by a margin
# The unit of margin is second
SEEK_FRAME_MARGIN = 0.25
def _read_from_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
"""
Args:
container: pyav container
@@ -134,7 +127,7 @@ def _read_from_stream(
ascending order. We need to decode more frames even when we meet end
pts
"""
# seeking in the stream is imprecise. Thus, seek to an ealier PTS by a margin
# seeking in the stream is imprecise. Thus, seek to an earlier PTS by a margin
margin = 1
seek_offset = max(start_pts - margin, 0)
@@ -233,9 +226,7 @@ def _decode_frames_by_av_module(
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor(
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)
return DecoderResult(
vframes=vframes,
@@ -266,64 +257,64 @@ def _get_video_tensor(video_dir, video_file):
assert os.path.exists(full_path), "File not found: %s" % full_path
with open(full_path, "rb") as fp:
video_tensor = torch.from_numpy(np.frombuffer(fp.read(), dtype=np.uint8))
video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8)
return full_path, video_tensor
@unittest.skipIf(av is None, "PyAV unavailable")
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
class TestVideoReader(unittest.TestCase):
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoReader:
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
"""check the decoding results from TorchVision decoder"""
(
vframes,
vframe_pts,
vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
assert video_duration == approx(config.duration, abs=0.5)
assert vfps.item() == approx(config.video_fps, abs=0.5)
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
assert audio_duration == approx(config.duration, abs=0.5)
# check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
assert vframe_pts[i] < vframe_pts[i + 1]
if len(aframe_pts) > 1:
# check if pts of audio frames are sorted in ascending order
for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
assert aframe_pts[i] < aframe_pts[i + 1]
def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
assert video_duration == approx(config.duration, abs=0.5)
assert vfps.item() == approx(config.video_fps, abs=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
audio_duration = aduration.item() * Fraction(
atimebase[0].item(), atimebase[1].item()
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
assert audio_duration == approx(config.duration, abs=0.5)
def check_meta_result(self, result, config):
self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5)
assert result.video_duration == approx(config.duration, abs=0.5)
assert result.video_fps == approx(config.video_fps, abs=0.5)
if result.has_audio > 0:
self.assertEqual(result.audio_sample_rate, config.audio_sample_rate)
self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5)
assert result.audio_sample_rate == config.audio_sample_rate
assert result.audio_duration == approx(config.duration, abs=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
@@ -334,10 +325,18 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, _vduration, \
aframes, aframe_pts, atimebase, _asample_rate, _aduration = (
tv_result
)
(
vframes,
vframe_pts,
vtimebase,
_vfps,
_vduration,
aframes,
aframe_pts,
atimebase,
_asample_rate,
_aduration,
) = tv_result
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
@@ -350,43 +349,32 @@ class TestVideoReader(unittest.TestCase):
)
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(
torch.abs(vframes.float() - ref_result.vframes.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
assert mean_delta == approx(0.0, abs=8.0)
mean_delta = torch.mean(
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=1.0)
mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
assert mean_delta == approx(0.0, abs=1.0)
assert_equal(vtimebase, ref_result.vtimebase)
if (
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
"""Audio stream is available and audio frame is required to return
from decoder"""
assert_equal(aframes, ref_result.aframes)
if (
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
"""Audio stream is available"""
assert_equal(aframe_pts, ref_result.aframe_pts)
assert_equal(atimebase, ref_result.atimebase)
@unittest.skip(
"This stress test will iteratively decode the same set of videos."
"It helps to detect memory leak but it takes lots of time to run."
"By default, it is disabled"
)
def test_stress_test_read_video_from_file(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_stress_test_read_video_from_file(self, test_video):
pytest.skip(
"This stress test will iteratively decode the same set of videos."
"It helps to detect memory leak but it takes lots of time to run."
"By default, it is disabled"
)
num_iter = 10000
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
@@ -398,53 +386,12 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_den = 0, 1
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
@PY39_SKIP
def test_read_video_from_file(self):
"""
Test the case when decoder starts with a video file to decode frames.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
@@ -463,15 +410,57 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
@PY39_SKIP
def test_read_video_from_file_read_single_stream_only(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_file(self, test_video, config):
"""
Test the case when decoder starts with a video file to decode frames.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
def test_read_video_from_file_read_single_stream_only(
self, test_video, config, read_video_stream, read_audio_stream
):
"""
Test the case when decoder starts with a video file to decode frames, and
only reads a single stream (video or audio) while ignoring the other
@@ -485,51 +474,56 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
for readVideoStream, readAudioStream in [(1, 0), (0, 1)]:
# decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
readVideoStream,
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
readAudioStream,
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
self.assertEqual(vtimebase.numel() > 0, readVideoStream)
self.assertEqual(vfps.numel() > 0, readVideoStream)
expect_audio_data = (
readAudioStream == 1 and config.audio_sample_rate is not None
)
self.assertEqual(aframes.numel() > 0, expect_audio_data)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
self.assertEqual(atimebase.numel() > 0, expect_audio_data)
self.assertEqual(asample_rate.numel() > 0, expect_audio_data)
def test_read_video_from_file_rescale_min_dimension(self):
full_path = os.path.join(VIDEO_DIR, test_video)
# decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
read_video_stream,
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
read_audio_stream,
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
(
vframes,
vframe_pts,
vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
assert (vframes.numel() > 0) is bool(read_video_stream)
assert (vframe_pts.numel() > 0) is bool(read_video_stream)
assert (vtimebase.numel() > 0) is bool(read_video_stream)
assert (vfps.numel() > 0) is bool(read_video_stream)
expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
assert (aframes.numel() > 0) is bool(expect_audio_data)
assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
assert (atimebase.numel() > 0) is bool(expect_audio_data)
assert (asample_rate.numel() > 0) is bool(expect_audio_data)
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_min_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set.
......@@ -543,35 +537,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_max_dimension(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_max_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
the video max dimension between height and width is set.
......@@ -585,35 +577,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_both_min_max_dimension(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
both the video min and max dimensions are set.
......@@ -627,38 +617,34 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_width(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video width is set.
......@@ -672,33 +658,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(2), width)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert tv_result[0].size(2) == width
def test_read_video_from_file_rescale_height(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_height(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video height is set.
......@@ -712,33 +698,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert tv_result[0].size(1) == height
def test_read_video_from_file_rescale_width_and_height(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width_and_height(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
both video height and width are set.
......@@ -752,95 +738,92 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
self.assertEqual(tv_result[0].size(2), width)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert tv_result[0].size(1) == height
assert tv_result[0].size(2) == width
@PY39_SKIP
def test_read_video_from_file_audio_resampling(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("samples", [9600, 96000])
def test_read_video_from_file_audio_resampling(self, test_video, samples):
"""
Test the case when decoder starts with a video file to decode frames, and
the audio waveform is resampled
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
channels = 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for samples in [9600, 96000]: # downsampling # upsampling
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
channels = 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
if aframes.numel() > 0:
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = (
float(aframe_pts[-1])
* float(atimebase[0])
/ float(atimebase[1])
)
self.assertAlmostEqual(
aframes.size(0),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)
@PY39_SKIP
def test_compare_read_video_from_memory_and_file(self):
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
(
vframes,
vframe_pts,
vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
if aframes.numel() > 0:
assert samples == asample_rate.item()
assert 1 == aframes.size(1)
# when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_compare_read_video_from_memory_and_file(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -853,61 +836,60 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_memory, config)
# pass 2: decode all frames from file
tv_result_file = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_memory, config)
# pass 2: decode all frames from file
tv_result_file = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_file, config)
# finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file)
self.check_separate_decoding_result(tv_result_file, config)
# finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file)
@PY39_SKIP
def test_read_video_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -920,39 +902,38 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config)
self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config)
@PY39_SKIP
def test_read_video_from_memory_get_pts_only(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory_get_pts_only(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory.
Compare frame pts between decoding for pts only and full decoding
......@@ -967,238 +948,234 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertAlmostEqual(config.video_fps, tv_result[3].item(), delta=0.01)
# pass 2: decode all frames to get PTS only using cpp decoder
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
1, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
assert abs(config.video_fps - tv_result[3].item()) < 0.01
# pass 2: decode all frames to get PTS only using cpp decoder
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
1, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[5].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)
assert not tv_result_pts_only[0].numel()
assert not tv_result_pts_only[5].numel()
self.compare_decoding_result(tv_result, tv_result_pts_only)
@PY39_SKIP
def test_read_video_in_range_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
"""
Test the case when video is already in memory, and decoder reads data in memory.
In addition, the decoder takes meaningful start and end PTS as input, and decodes
frames within that interval
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
(
vframes,
vframe_pts,
vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
assert abs(config.video_fps - vfps.item()) < 0.01
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
return
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
video_start_pts = vframe_pts[start_pts_ind]
video_end_pts = vframe_pts[end_pts_ind]
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
if len(atimebase) > 0:
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.ceil,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]:
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
continue
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
video_start_pts = vframe_pts[start_pts_ind]
video_end_pts = vframe_pts[end_pts_ind]
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
if len(atimebase) > 0:
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(
full_path
)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil,
)
pyav_result = _decode_frames_by_av_module(
full_path,
video_start_pts_av,
video_end_pts_av,
audio_start_pts,
audio_end_pts,
)
self.assertEqual(tv_result[0].size(0), num_frames)
if pyav_result.vframes.size(0) == num_frames:
# if PyAv decodes a different number of video frames, skip
# comparing the decoding results between Torchvision video reader
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
pyav_result = _decode_frames_by_av_module(
full_path,
video_start_pts_av,
video_end_pts_av,
audio_start_pts,
audio_end_pts,
)
assert tv_result[0].size(0) == num_frames
if pyav_result.vframes.size(0) == num_frames:
# compare the decoding results only when PyAV decodes the same number of
# video frames; otherwise skip the comparison between the TorchVision video
# reader and PyAV
self.compare_decoding_result(tv_result, pyav_result, config)
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_file(self, test_video, config):
"""
Test the case when decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory(self, test_video, config):
"""
Test the case when decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory_script(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory_script(self, test_video, config):
scripted_fun = torch.jit.script(io._probe_video_from_memory)
self.assertIsNotNone(scripted_fun)
assert scripted_fun is not None
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
@PY39_SKIP
def test_read_video_from_memory_scripted(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_memory_scripted(self, test_video):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -1212,71 +1189,66 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_den = 0, 1
scripted_fun = torch.jit.script(io._read_video_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, _config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
seek_frame_margin,
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
[video_start_pts, video_end_pts],
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
[audio_start_pts, audio_end_pts],
audio_timebase_num,
audio_timebase_den,
)
# FUTURE: check value of video / audio frames
def test_audio_video_sync(self):
"""Test if audio/video are synchronised with pyav output."""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if not container.streams.audio:
# Skip if no audio stream
continue
start_pts_val, cutoff = 0, 1
if container.streams.video:
video = container.streams.video[0]
arr = []
for index, frame in enumerate(container.decode(video)):
if index == cutoff:
start_pts_val = frame.pts
if index >= cutoff:
arr.append(frame.to_rgb().to_ndarray())
visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
self.assertAlmostEqual(
config.video_fps, info['video_fps'], delta=0.0001
)
arr = torch.Tensor(arr)
if arr.shape == visual.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99)
container = av.open(full_path)
if container.streams.audio:
audio = container.streams.audio[0]
arr = []
for index, frame in enumerate(container.decode(audio)):
if index >= cutoff:
arr.append(frame.to_ndarray())
_, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
arr = torch.as_tensor(np.concatenate(arr, axis=1))
if arr.shape == audio.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
assert scripted_fun is not None
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
SEEK_FRAME_MARGIN,
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
[video_start_pts, video_end_pts],
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
[audio_start_pts, audio_end_pts],
audio_timebase_num,
audio_timebase_den,
)
# FUTURE: check value of video / audio frames
def test_invalid_file(self):
set_video_backend("video_reader")
with pytest.raises(RuntimeError):
io.read_video("foo.mp4")
set_video_backend("pyav")
with pytest.raises(RuntimeError):
io.read_video("foo.mp4")
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 500])
@pytest.mark.parametrize("end_offset", [3000, None])
def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
"""Test if audio frames are returned with pts unit."""
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if container.streams.audio:
set_video_backend(backend)
_, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
assert all([dimension > 0 for dimension in audio.shape[:2]])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 0.1])
@pytest.mark.parametrize("end_offset", [0.3, None])
def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
"""Test if audio frames are returned with sec unit."""
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if container.streams.audio:
set_video_backend(backend)
_, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
assert all([dimension > 0 for dimension in audio.shape[:2]])
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
import collections
import os
import unittest
import urllib
import pytest
import torch
import torchvision
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
# WARNING: these tests have been skipped forever on the CI because the video ops
# are never properly available. This is bad, but things have been in a terrible
# state for a long time already as we write this comment, and we'll hopefully be
# able to get rid of this all soon.
from common_utils import PY39_SKIP
try:
import av
......@@ -24,6 +31,13 @@ CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
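# Helper for parametrizing over the available backends: the native video_reader
# backend is always included, and pyav is added only when PyAV imported successfully.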
def backends():
backends_ = ["video_reader"]
if av is not None:
backends_.append("pyav")
return backends_
def fate(name, path="."):
"""Download and return a path to a sample from the FFmpeg test suite.
See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_
......@@ -35,166 +49,264 @@ def fate(name, path="."):
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0, video_fps=30.0, audio_sample_rate=None
),
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0, video_fps=30.0, audio_sample_rate=None
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0, video_fps=30.0, audio_sample_rate=None
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0, video_fps=29.97, audio_sample_rate=None
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0, video_fps=29.97, audio_sample_rate=None
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0, video_fps=30.0, audio_sample_rate=44100
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
"R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100),
"SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
"WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
}
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
@PY39_SKIP
class TestVideoApi(unittest.TestCase):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_frame_reading(self):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
av_reader = av.open(full_path)
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", backends())
def test_frame_reading(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
with av.open(full_path) as av_reader:
if av_reader.streams.video:
video_reader = VideoReader(full_path, "video")
av_frames, vr_frames = [], []
av_pts, vr_pts = [], []
# get av frames
for av_frame in av_reader.decode(av_reader.streams.video[0]):
vr_frame = next(video_reader)
self.assertAlmostEqual(
float(av_frame.pts * av_frame.time_base),
vr_frame["pts"],
delta=0.1,
)
av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(
2, 0, 1
)
vr_array = vr_frame["data"]
mean_delta = torch.mean(
torch.abs(av_array.float() - vr_array.float())
)
av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
av_pts.append(av_frame.pts * av_frame.time_base)
# get vr frames
video_reader = VideoReader(full_path, "video")
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# same number of frames
assert len(vr_frames) == len(av_frames)
assert len(vr_pts) == len(av_pts)
# compare the frames and pts values
for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this; at the moment it's 1%
# averaged over all frames
self.assertTrue(mean_delta.item() < 2.5)
assert mean_delta.item() < 2.55
del vr_frames, av_frames, vr_pts, av_pts
av_reader = av.open(full_path)
# test audio reading compared to PYAV
with av.open(full_path) as av_reader:
if av_reader.streams.audio:
video_reader = VideoReader(full_path, "audio")
av_frames, vr_frames = [], []
av_pts, vr_pts = [], []
# get av frames
for av_frame in av_reader.decode(av_reader.streams.audio[0]):
vr_frame = next(video_reader)
self.assertAlmostEqual(
float(av_frame.pts * av_frame.time_base),
vr_frame["pts"],
delta=0.1,
)
av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0)
vr_array = vr_frame["data"]
max_delta = torch.max(
torch.abs(av_array.float() - vr_array.float())
)
# we assure that there is never more than 1% difference in signal
self.assertTrue(max_delta.item() < 0.001)
av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
av_pts.append(av_frame.pts * av_frame.time_base)
av_reader.close()
def test_metadata(self):
# get vr frames
video_reader = VideoReader(full_path, "audio")
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# same number of frames
assert len(vr_frames) == len(av_frames)
assert len(vr_pts) == len(av_pts)
# compare the frames and pts values
for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float()))
# we ensure that there is never more than 1% difference in signal
assert max_delta.item() < 0.001
@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", backends())
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path)
reader_md = reader.get_metadata()
if stream in reader_md:
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)
for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])
# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)
# compare the frames and pts values
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this; at the moment it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
else:
del reader, reader_md
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", backends())
def test_metadata(self, test_video, config, backend):
"""
Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
self.assertAlmostEqual(
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
)
self.assertAlmostEqual(
config.duration, reader_md["video"]["duration"][0], delta=0.5
)
def test_seek_start(self):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader = VideoReader(full_path, "video")
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", backends())
def test_seek_start(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader = VideoReader(full_path, "video")
num_frames = 0
for _ in video_reader:
num_frames += 1
# now seek the container to 0 and do it again
# The initial seek is often imprecise,
# so decoding may not actually start at 0
video_reader.seek(0)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
# now seek the container to < 0 to check for unexpected behaviour
video_reader.seek(-1)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader"])
def test_accurateseek_middle(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video"
video_reader = VideoReader(full_path, stream)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
num_frames = 0
for frame in video_reader:
for _ in video_reader:
num_frames += 1
# now seek the container to 0 and do it again
# It's often that starting seek can be inprecise
# this way and it doesn't start at 0
video_reader.seek(0)
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
self.assertEqual(start_num_frames, num_frames)
video_reader.seek(duration / 2)
middle_num_frames = 0
for _ in video_reader:
middle_num_frames += 1
# now seek the container to < 0 to check for unexpected behaviour
video_reader.seek(-1)
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
assert middle_num_frames < num_frames
assert middle_num_frames == approx(num_frames // 2, abs=1)
self.assertEqual(start_num_frames, num_frames)
video_reader.seek(duration / 2)
frame = next(video_reader)
lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
assert (lb <= frame["pts"]) and (ub >= frame["pts"])
def test_accurateseek_middle(self):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
def test_fate_suite(self):
# TODO: remove the try-except statement once the connectivity issues are resolved
try:
video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
except (urllib.error.URLError, ConnectionError) as error:
pytest.skip(f"Skipping due to connectivity issues: {error}")
vr = VideoReader(video_path)
metadata = vr.get_metadata()
stream = "video"
video_reader = VideoReader(full_path, stream)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
assert metadata["subtitles"]["duration"] is not None
os.remove(video_path)
num_frames = 0
for frame in video_reader:
num_frames += 1
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", backends())
def test_keyframe_reading(self, test_video, config, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader.seek(duration / 2)
middle_num_frames = 0
for frame in video_reader:
middle_num_frames += 1
av_reader = av.open(full_path)
# reduce streams to only keyframes
av_stream = av_reader.streams.video[0]
av_stream.codec_context.skip_frame = "NONKEY"
self.assertTrue(middle_num_frames < num_frames)
self.assertAlmostEqual(middle_num_frames, num_frames // 2, delta=1)
av_keyframes = []
vr_keyframes = []
if av_reader.streams.video:
video_reader.seek(duration / 2)
frame = next(video_reader)
lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
self.assertTrue((lb <= frame["pts"]) & (ub >= frame["pts"]))
# get all keyframes using pyav. Then, seek randomly into video reader
# and assert that all the returned values are in AV_KEYFRAMES
def test_fate_suite(self):
video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
vr = VideoReader(video_path)
metadata = vr.get_metadata()
for av_frame in av_reader.decode(av_stream):
av_keyframes.append(float(av_frame.pts * av_frame.time_base))
self.assertTrue(metadata["subtitles"]["duration"] is not None)
os.remove(video_path)
if len(av_keyframes) > 1:
video_reader = VideoReader(full_path, "video")
for i in range(1, len(av_keyframes)):
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
data = next(video_reader.seek(seek_val, True))
vr_keyframes.append(data["pts"])
data = next(video_reader.seek(config.duration, True))
vr_keyframes.append(data["pts"])
assert len(av_keyframes) == len(vr_keyframes)
# NOTE: this video gets a different keyframe with different
# loaders (0.333 pyav, 0.666 for us)
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
for i in range(len(av_keyframes)):
assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)
def test_src(self):
with pytest.raises(ValueError, match="src cannot be empty"):
VideoReader(src="")
with pytest.raises(ValueError, match="src must be either string"):
VideoReader(src=2)
with pytest.raises(TypeError, match="unexpected keyword argument"):
VideoReader(path="path")
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(test_frcnn_tracing)
find_package(Torch REQUIRED)
find_package(TorchVision REQUIRED)
# This due to some headers importing Python.h
find_package(Python3 COMPONENTS Development)
add_executable(test_frcnn_tracing test_frcnn_tracing.cpp)
target_compile_features(test_frcnn_tracing PUBLIC cxx_range_for)
target_link_libraries(test_frcnn_tracing ${TORCH_LIBRARIES} TorchVision::TorchVision Python3::Python)
set_property(TARGET test_frcnn_tracing PROPERTY CXX_STANDARD 14)
import os.path as osp
import torch
import torchvision
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
model.eval()
traced_model = torch.jit.script(model)
traced_model.save("fasterrcnn_resnet50_fpn.pt")
import warnings
import os
from .extension import _HAS_OPS
from torchvision import models
from torchvision import datasets
from torchvision import ops
from torchvision import transforms
from torchvision import utils
from torchvision import io
import warnings
from modulefinder import Module
import torch
# Don't re-order these, we need to load the _C extension (done when importing
# .extensions) before entering _meta_registrations.
from .extension import _HAS_OPS # usort:skip
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip
try:
from .version import __version__ # noqa: F401
except ImportError:
pass
# Check if torchvision is being imported within the root folder
if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) ==
os.path.join(os.path.realpath(os.getcwd()), 'torchvision')):
message = ('You are importing torchvision within its own root folder ({}). '
'This is not expected to work and may give errors. Please exit the '
'torchvision project source and relaunch your python interpreter.')
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.realpath(os.getcwd()), "torchvision"
):
message = (
"You are importing torchvision within its own root folder ({}). "
"This is not expected to work and may give errors. Please exit the "
"torchvision project source and relaunch your python interpreter."
)
warnings.warn(message.format(os.getcwd()))
_image_backend = 'PIL'
_image_backend = "PIL"
_video_backend = "pyav"
......@@ -40,9 +41,8 @@ def set_image_backend(backend):
generally faster than PIL, but does not support as many operations.
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
.format(backend))
if backend not in ["PIL", "accimage"]:
raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
_image_backend = backend
......@@ -63,23 +63,23 @@ def set_video_backend(backend):
binding for the FFmpeg libraries.
The :mod:`video_reader` package includes a native C++ implementation on
top of the FFMPEG libraries, and a Python API of TorchScript custom operators.
It is generally decoding faster than :mod:`pyav`, but perhaps is less robust.
It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
.. note::
Building with FFMPEG is disabled by default in the latest master. If you want to use the 'video_reader'
Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
backend, please compile torchvision from source.
"""
global _video_backend
if backend not in ["pyav", "video_reader"]:
raise ValueError(
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = (
"video_reader video backend is not available."
" Please compile torchvision from source and try again"
)
warnings.warn(message)
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
raise RuntimeError(message)
elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
# TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
else:
_video_backend = backend
......@@ -97,3 +97,9 @@ def get_video_backend():
def _is_tracing():
return torch._C._get_tracing_state()
def disable_beta_transforms_warning():
# Noop, only exists to avoid breaking existing code.
# See https://github.com/pytorch/vision/issues/7896
pass
import importlib.machinery
import os
from torch.hub import _get_torch_home
_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False
def _download_file_from_remote_location(fpath: str, url: str) -> None:
pass
def _is_remote_location_available() -> bool:
return False
try:
from torch.hub import load_state_dict_from_url # noqa: 401
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
os.add_dll_directory(lib_dir)
kernel32.SetErrorMode(prev_error_mode)
loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
return ext_specs.origin
import functools
import torch
import torch._custom_ops
import torch.library
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401
@functools.lru_cache(None)
def get_meta_lib():
return torch.library.Library("torchvision", "IMPL", "Meta")
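# register_meta registers a Python function as the Meta-dispatch (shape-only)
# implementation of a torchvision operator, so fake/meta tensors can trace
# through the op without running the real kernel.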
def register_meta(op_name, overload_name="default"):
def wrapper(fn):
if torchvision.extension._has_ops():
get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
return fn
return wrapper
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
channels = input.size(1)
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("ps_roi_align")
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)
num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
@register_meta("_ps_roi_align_backward")
def meta_ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width,
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("roi_pool")
def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
channels = input.size(1)
out_size = (num_rois, channels, pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
@register_meta("_roi_pool_backward")
def meta_roi_pool_backward(
grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("ps_roi_pool")
def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)
num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
@register_meta("_ps_roi_pool_backward")
def meta_ps_roi_pool_backward(
grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
torch._check(
dets.size(0) == scores.size(0),
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
)
ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long)
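# Sketch (assumes a recent PyTorch with dynamic shapes enabled): under
# FakeTensorMode the fake kernel above yields a tensor whose length is an
# unbacked symbolic int, which is what lets torch.compile trace through the
# data-dependent output size of nms without looking at real boxes.
def _example_fake_nms():
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    with FakeTensorMode(shape_env=ShapeEnv()):
        dets = torch.rand(100, 4)
        scores = torch.rand(100)
        keep = torch.ops.torchvision.nms(dets, scores, 0.5)
        # keep.shape[0] is symbolic (e.g. u0); no real data was inspected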
@register_meta("deform_conv2d")
def meta_deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
use_mask,
):
out_height, out_width = offset.shape[-2:]
out_channels = weight.shape[0]
batch_size = input.shape[0]
return input.new_empty((batch_size, out_channels, out_height, out_width))
@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask,
):
grad_input = input.new_empty(input.shape)
grad_weight = weight.new_empty(weight.shape)
grad_offset = offset.new_empty(offset.shape)
grad_mask = mask.new_empty(mask.shape)
grad_bias = bias.new_empty(bias.shape)
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
import enum
from typing import Sequence, Type, TypeVar
T = TypeVar("T", bound=enum.Enum)
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto
def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
try:
return self[member]
except KeyError:
# TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
# soon as it is migrated.
raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass
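# Illustrative subclass (the name is made up for this sketch): members are
# declared as usual, and from_str gives a friendlier error on unknown values.
class _ExampleColorSpace(StrEnum):
    RGB = StrEnum.auto()
    GRAY = StrEnum.auto()

# _ExampleColorSpace.from_str("RGB")  -> _ExampleColorSpace.RGB
# _ExampleColorSpace.from_str("BGR")  -> ValueError: Unknown value 'BGR' for _ExampleColorSpace.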
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
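# Quick sanity sketch of the helper above (not part of the library surface):
def _example_sequence_to_str():
    assert sequence_to_str([]) == ""
    assert sequence_to_str(["train"]) == "'train'"
    assert sequence_to_str(["train", "val"], separate_last="or ") == "'train' or 'val'"
    assert sequence_to_str(["a", "b", "c"], separate_last="and ") == "'a', 'b', and 'c'"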
......@@ -48,6 +48,23 @@ bool AudioSampler::init(const SamplerParameters& params) {
return false;
}
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
SwrContext* swrContext_ = NULL;
AVChannelLayout channel_out;
AVChannelLayout channel_in;
av_channel_layout_default(&channel_out, params.out.audio.channels);
av_channel_layout_default(&channel_in, params.in.audio.channels);
int ret = swr_alloc_set_opts2(
&swrContext_,
&channel_out,
(AVSampleFormat)params.out.audio.format,
params.out.audio.samples,
&channel_in,
(AVSampleFormat)params.in.audio.format,
params.in.audio.samples,
0,
logCtx_);
#else
swrContext_ = swr_alloc_set_opts(
nullptr,
av_get_default_channel_layout(params.out.audio.channels),
......@@ -58,6 +75,7 @@ bool AudioSampler::init(const SamplerParameters& params) {
params.in.audio.samples,
0,
logCtx_);
#endif
if (swrContext_ == nullptr) {
LOG(ERROR) << "Cannot allocate SwrContext";
return false;
......@@ -65,7 +83,7 @@ bool AudioSampler::init(const SamplerParameters& params) {
int result;
if ((result = swr_init(swrContext_)) < 0) {
LOG(ERROR) << "swr_init faield, err: " << Util::generateErrorDesc(result)
LOG(ERROR) << "swr_init failed, err: " << Util::generateErrorDesc(result)
<< ", in -> format: " << params.in.audio.format
<< ", channels: " << params.in.audio.channels
<< ", samples: " << params.in.audio.samples
......@@ -116,12 +134,12 @@ int AudioSampler::sample(
outNumSamples,
inPlanes,
inNumSamples)) < 0) {
LOG(ERROR) << "swr_convert faield, err: "
LOG(ERROR) << "swr_convert failed, err: "
<< Util::generateErrorDesc(result);
return result;
}
CHECK_LE(result, outNumSamples);
TORCH_CHECK_LE(result, outNumSamples);
if (result) {
if ((result = av_samples_get_buffer_size(
......@@ -132,7 +150,7 @@ int AudioSampler::sample(
1)) >= 0) {
out->append(result);
} else {
LOG(ERROR) << "av_samples_get_buffer_size faield, err: "
LOG(ERROR) << "av_samples_get_buffer_size failed, err: "
<< Util::generateErrorDesc(result);
}
}
......@@ -140,7 +158,7 @@ int AudioSampler::sample(
// allocate a temporary buffer
auto* tmpBuffer = static_cast<uint8_t*>(av_malloc(outBufferBytes));
if (!tmpBuffer) {
LOG(ERROR) << "av_alloc faield, for size: " << outBufferBytes;
LOG(ERROR) << "av_alloc failed, for size: " << outBufferBytes;
return -1;
}
......@@ -158,7 +176,7 @@ int AudioSampler::sample(
outNumSamples,
inPlanes,
inNumSamples)) < 0) {
LOG(ERROR) << "swr_convert faield, err: "
LOG(ERROR) << "swr_convert failed, err: "
<< Util::generateErrorDesc(result);
av_free(tmpBuffer);
return result;
......@@ -166,7 +184,7 @@ int AudioSampler::sample(
av_free(tmpBuffer);
CHECK_LE(result, outNumSamples);
TORCH_CHECK_LE(result, outNumSamples);
if (result) {
result = av_samples_get_buffer_size(
......
......@@ -6,26 +6,36 @@
namespace ffmpeg {
namespace {
static int get_nb_channels(const AVFrame* frame, const AVCodecContext* codec) {
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
return frame ? frame->ch_layout.nb_channels : codec->ch_layout.nb_channels;
#else
return frame ? frame->channels : codec->channels;
#endif
}
bool operator==(const AudioFormat& x, const AVFrame& y) {
return x.samples == y.sample_rate && x.channels == y.channels &&
return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(get_nb_channels(&y, nullptr)) &&
x.format == y.format;
}
bool operator==(const AudioFormat& x, const AVCodecContext& y) {
return x.samples == y.sample_rate && x.channels == y.channels &&
return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(get_nb_channels(nullptr, &y)) &&
x.format == y.sample_fmt;
}
AudioFormat& toAudioFormat(AudioFormat& x, const AVFrame& y) {
x.samples = y.sample_rate;
x.channels = y.channels;
x.channels = get_nb_channels(&y, nullptr);
x.format = y.format;
return x;
}
AudioFormat& toAudioFormat(AudioFormat& x, const AVCodecContext& y) {
x.samples = y.sample_rate;
x.channels = y.channels;
x.channels = get_nb_channels(nullptr, &y);
x.format = y.sample_fmt;
return x;
}
......@@ -54,9 +64,15 @@ int AudioStream::initFormat() {
if (format_.format.audio.samples == 0) {
format_.format.audio.samples = codecCtx_->sample_rate;
}
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->ch_layout.nb_channels;
}
#else
if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->channels;
}
#endif
if (format_.format.audio.format == AV_SAMPLE_FMT_NONE) {
format_.format.audio.format = codecCtx_->sample_fmt;
}
......@@ -68,6 +84,7 @@ int AudioStream::initFormat() {
: -1;
}
// copies audio sample bytes via swr_convert call in audio_sampler.cpp
int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) {
if (!sampler_) {
sampler_ = std::make_unique<AudioSampler>(codecCtx_);
......@@ -95,6 +112,8 @@ int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) {
<< ", channels: " << format_.format.audio.channels
<< ", format: " << format_.format.audio.format;
}
// calls to a sampler that converts the audio samples and copies them to the
// out buffer via ffmpeg::swr_convert
return sampler_->sample(flush ? nullptr : frame_, out);
}
......
#include "decoder.h"
#include <c10/util/Logging.h>
#include <libavutil/avutil.h>
#include <future>
#include <iostream>
#include <mutex>
......@@ -17,25 +18,6 @@ constexpr size_t kIoBufferSize = 96 * 1024;
constexpr size_t kIoPaddingSize = AV_INPUT_BUFFER_PADDING_SIZE;
constexpr size_t kLogBufferSize = 1024;
int ffmpeg_lock(void** mutex, enum AVLockOp op) {
std::mutex** handle = (std::mutex**)mutex;
switch (op) {
case AV_LOCK_CREATE:
*handle = new std::mutex();
break;
case AV_LOCK_OBTAIN:
(*handle)->lock();
break;
case AV_LOCK_RELEASE:
(*handle)->unlock();
break;
case AV_LOCK_DESTROY:
delete *handle;
break;
}
return 0;
}
bool mapFfmpegType(AVMediaType media, MediaType* type) {
switch (media) {
case AVMEDIA_TYPE_AUDIO:
......@@ -196,11 +178,11 @@ int64_t Decoder::seekCallback(int64_t offset, int whence) {
void Decoder::initOnce() {
static std::once_flag flagInit;
std::call_once(flagInit, []() {
#if LIBAVUTIL_VERSION_MAJOR < 56 // Before FFMPEG 4.0
av_register_all();
avcodec_register_all();
#endif
avformat_network_init();
// register ffmpeg lock manager
av_lockmgr_register(&ffmpeg_lock);
av_log_set_callback(Decoder::logFunction);
av_log_set_level(AV_LOG_ERROR);
VLOG(1) << "Registered ffmpeg libs";
......@@ -215,6 +197,12 @@ Decoder::~Decoder() {
cleanUp();
}
// Initialise the format context that holds information about the container and
// fill it with minimal information about the format (codecs are not opened
// here). Function reads in information about the streams from the container
// into inputCtx and then passes it to decoder::openStreams. Finally, if seek is
// specified within the decoder parameters, it seeks into the correct frame
// (note, the seek defined here is "precise" seek).
bool Decoder::init(
const DecoderParameters& params,
DecoderInCallback&& in,
......@@ -268,7 +256,7 @@ bool Decoder::init(
break;
}
fmt = av_find_input_format(fmtName);
fmt = (AVInputFormat*)av_find_input_format(fmtName);
}
const size_t avioCtxBufferSize = kIoBufferSize;
......@@ -324,6 +312,8 @@ bool Decoder::init(
}
}
av_dict_set_int(&options, "probesize", params_.probeSize, 0);
interrupted_ = false;
// ffmpeg avformat_open_input call can hang if media source doesn't respond
......@@ -381,7 +371,7 @@ bool Decoder::init(
cleanUp();
return false;
}
// SyncDecoder inherits Decoder which would override onInit.
onInit();
if (params.startOffset != 0) {
......@@ -396,11 +386,17 @@ bool Decoder::init(
return true;
}
// open appropriate CODEC for every type of stream and move it to the class
// variable `streams_` and make sure it is in range for decoding
bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) {
for (int i = 0; i < inputCtx_->nb_streams; i++) {
for (unsigned int i = 0; i < inputCtx_->nb_streams; i++) {
// - find the corresponding format in the params_.formats set
MediaFormat format;
#if LIBAVUTIL_VERSION_MAJOR < 56 // Before FFMPEG 4.0
const auto media = inputCtx_->streams[i]->codec->codec_type;
#else // FFMPEG 4.0+
const auto media = inputCtx_->streams[i]->codecpar->codec_type;
#endif
if (!mapFfmpegType(media, &format.type)) {
VLOG(1) << "Stream media: " << media << " at index " << i
<< " gets ignored, unknown type";
......@@ -424,20 +420,20 @@ bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) {
if (it->stream == -2 || // all streams of this type are welcome
(!stream && (it->stream == -1 || it->stream == i))) { // new stream
VLOG(1) << "Stream type: " << format.type << " found, at index: " << i;
auto stream = createStream(
auto stream_2 = createStream(
format.type,
inputCtx_,
i,
params_.convertPtsToWallTime,
it->format,
params_.loggingUuid);
CHECK(stream);
if (stream->openCodec(metadata) < 0) {
CHECK(stream_2);
if (stream_2->openCodec(metadata, params_.numThreads) < 0) {
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " open codec failed, stream_idx=" << i;
return false;
}
streams_.emplace(i, std::move(stream));
streams_.emplace(i, std::move(stream_2));
inRange_.set(i, true);
}
}
......@@ -478,6 +474,10 @@ void Decoder::cleanUp() {
seekableBuffer_.shutdown();
}
// This function does the actual work; the derived class calls it periodically
// from its working thread. On success it returns 0, ENODATA on EOF, ETIMEDOUT
// if no frames were decoded within the specified timeout,
// AVERROR_BUFFER_TOO_SMALL when a packet cannot be allocated, and an error
// code on any unrecoverable error.
int Decoder::getFrame(size_t workingTimeInMs) {
if (inRange_.none()) {
return ENODATA;
......@@ -486,10 +486,16 @@ int Decoder::getFrame(size_t workingTimeInMs) {
// once decode() method gets called and grab some bytes
// run this method again
// init package
AVPacket avPacket;
av_init_packet(&avPacket);
avPacket.data = nullptr;
avPacket.size = 0;
// update 03/22: moving memory management to ffmpeg
AVPacket* avPacket;
avPacket = av_packet_alloc();
if (avPacket == nullptr) {
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " decoder as not able to allocate the packet.";
return AVERROR_BUFFER_TOO_SMALL;
}
avPacket->data = nullptr;
avPacket->size = 0;
auto end = std::chrono::steady_clock::now() +
std::chrono::milliseconds(workingTimeInMs);
......@@ -501,28 +507,44 @@ int Decoder::getFrame(size_t workingTimeInMs) {
int result = 0;
size_t decodingErrors = 0;
bool decodedFrame = false;
while (!interrupted_ && inRange_.any() && !decodedFrame && watcher()) {
result = av_read_frame(inputCtx_, &avPacket);
while (!interrupted_ && inRange_.any() && !decodedFrame) {
if (watcher() == false) {
LOG(ERROR) << "uuid=" << params_.loggingUuid << " hit ETIMEDOUT";
result = ETIMEDOUT;
break;
}
result = av_read_frame(inputCtx_, avPacket);
if (result == AVERROR(EAGAIN)) {
VLOG(4) << "Decoder is busy...";
std::this_thread::yield();
result = 0; // reset error, EAGAIN is not an error at all
// reset the packet to default settings
av_packet_unref(avPacket);
continue;
} else if (result == AVERROR_EOF) {
flushStreams();
VLOG(1) << "End of stream";
result = ENODATA;
break;
} else if (
result == AVERROR(EPERM) && params_.skipOperationNotPermittedPackets) {
// reset error, lets skip packets with EPERM
result = 0;
// reset the packet to default settings
av_packet_unref(avPacket);
continue;
} else if (result < 0) {
flushStreams();
LOG(ERROR) << "Error detected: " << Util::generateErrorDesc(result);
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " error detected: " << Util::generateErrorDesc(result);
break;
}
// get stream
auto stream = findByIndex(avPacket.stream_index);
// get stream; if stream cannot be found reset the packet to
// default settings
auto stream = findByIndex(avPacket->stream_index);
if (stream == nullptr || !inRange_.test(stream->getIndex())) {
av_packet_unref(&avPacket);
av_packet_unref(avPacket);
continue;
}
......@@ -533,9 +555,10 @@ int Decoder::getFrame(size_t workingTimeInMs) {
bool gotFrame = false;
bool hasMsg = false;
// packet either got consumed completely or not at all
if ((result = processPacket(stream, &avPacket, &gotFrame, &hasMsg)) < 0) {
if ((result = processPacket(
stream, avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " processPacket failed with code=" << result;
<< " processPacket failed with code: " << result;
break;
}
......@@ -566,20 +589,18 @@ int Decoder::getFrame(size_t workingTimeInMs) {
result = 0;
av_packet_unref(&avPacket);
av_packet_unref(avPacket);
}
av_packet_unref(&avPacket);
av_packet_free(&avPacket);
VLOG(2) << "Interrupted loop"
<< ", interrupted_ " << interrupted_ << ", inRange_.any() "
<< inRange_.any() << ", decodedFrame " << decodedFrame << ", result "
<< result;
// loop can be terminated, either by:
// 1. explcitly iterrupted
// 2. terminated by workable timeout
// 3. unrecoverable error or ENODATA (end of stream)
// 1. explicitly interrupted
// 3. unrecoverable error or ENODATA (end of stream) or ETIMEDOUT (timeout)
// 4. decoded frames pts are out of the specified range
// 5. success decoded frame
if (interrupted_) {
......@@ -594,11 +615,13 @@ int Decoder::getFrame(size_t workingTimeInMs) {
return 0;
}
// find stream by stream index
Stream* Decoder::findByIndex(int streamIndex) const {
auto it = streams_.find(streamIndex);
return it != streams_.end() ? it->second.get() : nullptr;
}
// find stream by type; note finds only the first stream of a given type
Stream* Decoder::findByType(const MediaFormat& format) const {
for (auto& stream : streams_) {
if (stream.second->getMediaFormat().type == format.type) {
......@@ -608,11 +631,14 @@ Stream* Decoder::findByType(const MediaFormat& format) const {
return nullptr;
}
// given the stream and packet, decode the frame buffers into the
// DecoderOutputMessage data structure via stream::decodePacket function.
int Decoder::processPacket(
Stream* stream,
AVPacket* packet,
bool* gotFrame,
bool* hasMsg) {
bool* hasMsg,
bool fastSeek) {
// decode package
int result;
DecoderOutputMessage msg;
......@@ -625,7 +651,15 @@ int Decoder::processPacket(
bool endInRange =
params_.endOffset <= 0 || msg.header.pts <= params_.endOffset;
inRange_.set(stream->getIndex(), endInRange);
if (endInRange && msg.header.pts >= params_.startOffset) {
// if fastseek is enabled, we're returning the first
// frame that we decode after (potential) seek.
// By default, we perform accurate seek to the closest
// following frame
bool startCondition = true;
if (!fastSeek) {
startCondition = msg.header.pts >= params_.startOffset;
}
if (endInRange && startCondition) {
*hasMsg = true;
push(std::move(msg));
}
......
......@@ -59,11 +59,11 @@ class Decoder : public MediaDecoder {
private:
// mark below function for a proper invocation
virtual bool enableLogLevel(int level) const;
virtual void logCallback(int level, const std::string& message);
virtual int readCallback(uint8_t* buf, int size);
virtual int64_t seekCallback(int64_t offset, int whence);
virtual int shutdownCallback();
bool enableLogLevel(int level) const;
void logCallback(int level, const std::string& message);
int readCallback(uint8_t* buf, int size);
int64_t seekCallback(int64_t offset, int whence);
int shutdownCallback();
bool openStreams(std::vector<DecoderMetadata>* metadata);
Stream* findByIndex(int streamIndex) const;
......@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder {
Stream* stream,
AVPacket* packet,
bool* gotFrame,
bool* hasMsg);
bool* hasMsg,
bool fastSeek = false);
void flushStreams();
void cleanUp();
......
......@@ -165,7 +165,7 @@ struct MediaFormat {
struct DecoderParameters {
// local file, remote file, http url, rtmp stream uri, etc. anything that
// ffmpeg can recognize
std::string uri;
std::string uri{std::string()};
// timeout on getting bytes for decoding
size_t timeoutMs{1000};
// logging level, default AV_LOG_PANIC
......@@ -190,10 +190,15 @@ struct DecoderParameters {
bool listen{false};
// don't copy frame body, only header
bool headerOnly{false};
// enable fast seek (seek only to keyframes)
bool fastSeek{false};
// interrupt init method on timeout
bool preventStaleness{true};
// seek tolerated accuracy (us)
double seekAccuracy{1000000.0};
// Allow multithreaded decoding for numThreads > 1;
// numThreads = 0 sets up sensible defaults
int numThreads{1};
// what media types should be processed, default none
std::set<MediaFormat> formats;
......@@ -205,6 +210,15 @@ struct DecoderParameters {
std::string tlsCertFile;
std::string tlsKeyFile;
// Skip packets that fail with EPERM errors and continue decoding.
bool skipOperationNotPermittedPackets{false};
// probing size in bytes, i.e. the size of the data to analyze to get stream
// information. A higher value will enable detecting more information in case
// it is dispersed into the stream, but will increase latency. Must be an
// integer not less than 32. It is 5000000 by default.
int64_t probeSize{5000000};
};
struct DecoderHeader {
......@@ -287,7 +301,7 @@ struct DecoderMetadata {
};
/**
* Abstract class for decoding media bytes
* It has two diffrent modes. Internal media bytes retrieval for given uri and
* It has two different modes. Internal media bytes retrieval for given uri and
* external media bytes provider in case of memory streams
*/
class MediaDecoder {
......
GPU Decoder
===========
The GPU decoder depends on ffmpeg for demuxing, uses the NVDECODE APIs from the nvidia-video-codec sdk, and uses CUDA for processing on the GPU. To use it, follow these steps:
* Download the latest `nvidia-video-codec-sdk <https://developer.nvidia.com/nvidia-video-codec-sdk/download>`_
* Extract the zipped file.
* Set the TORCHVISION_INCLUDE environment variable to the location of the video codec headers (`nvcuvid.h` and `cuviddec.h`), which are under the `Interface` directory.
* Set the TORCHVISION_LIBRARY environment variable to the location of the video codec library (`libnvcuvid.so`), which is under the `Lib/linux/stubs/x86_64` directory.
* Install the latest ffmpeg from the `conda-forge` channel.
.. code:: bash
conda install -c conda-forge ffmpeg
* Set the CUDA_HOME environment variable to the CUDA root directory.
* Build torchvision from source:
.. code:: bash
python setup.py install
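Once built, the GPU decoder is used through the regular `VideoReader` API by passing a CUDA device. A minimal sketch (the file name is illustrative and assumes the build above succeeded):

.. code:: python

    import torch
    import torchvision

    # frames are decoded with NVDEC and returned as CUDA tensors
    reader = torchvision.io.VideoReader("example.mp4", device="cuda:0")
    for frame in reader:
        assert frame["data"].device == torch.device("cuda", 0)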