Commit bf491463 authored by limm's avatar limm
Browse files

add v0.19.1 release

parent e17f5ea2
import torch
from torchvision.transforms import Compose
import unittest
import random import random
import numpy as np
import warnings import warnings
from _assert_utils import assert_equal
import numpy as np
import pytest
import torch
from common_utils import assert_equal
from torchvision.transforms import Compose
try: try:
from scipy import stats from scipy import stats
...@@ -17,21 +18,22 @@ with warnings.catch_warnings(record=True): ...@@ -17,21 +18,22 @@ with warnings.catch_warnings(record=True):
import torchvision.transforms._transforms_video as transforms import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms(unittest.TestCase): class TestVideoTransforms:
def test_random_crop_video(self):
    """RandomCropVideo must produce a clip with exactly the requested spatial size."""
    numFrames = random.randint(4, 128)
    # Even dimensions so the (also even) crop target always fits.
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
    # Floor division keeps randint's upper bound an int (true division yields a float).
    oheight = random.randint(5, (height - 2) // 2) * 2
    owidth = random.randint(5, (width - 2) // 2) * 2
    clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
    result = Compose(
        [
            transforms.ToTensorVideo(),
            transforms.RandomCropVideo((oheight, owidth)),
        ]
    )(clip)
    assert result.size(2) == oheight
    assert result.size(3) == owidth

    # Smoke-check that __repr__ does not raise.
    transforms.RandomCropVideo((oheight, owidth)).__repr__()
...@@ -39,15 +41,17 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -39,15 +41,17 @@ class TestVideoTransforms(unittest.TestCase):
numFrames = random.randint(4, 128) numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2 oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.RandomResizedCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.RandomResizedCropVideo((oheight, owidth)),
self.assertEqual(result.size(2), oheight) ]
self.assertEqual(result.size(3), owidth) )(clip)
assert result.size(2) == oheight
assert result.size(3) == owidth
transforms.RandomResizedCropVideo((oheight, owidth)).__repr__() transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
...@@ -55,67 +59,77 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -55,67 +59,77 @@ class TestVideoTransforms(unittest.TestCase):
numFrames = random.randint(4, 128) numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2 oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255 clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
oh1 = (height - oheight) // 2 oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2 ow1 = (width - owidth) // 2
clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :] clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :]
clipNarrow.fill_(0) clipNarrow.fill_(0)
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
msg = "height: " + str(height) + " width: " \ )(clip)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg) msg = (
"height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
)
assert result.sum().item() == 0, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum1 = result.sum() sum1 = result.sum()
msg = "height: " + str(height) + " width: " \ msg = (
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg) )
assert sum1.item() > 1, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = Compose([ result = Compose(
transforms.ToTensorVideo(), [
transforms.CenterCropVideo((oheight, owidth)), transforms.ToTensorVideo(),
])(clip) transforms.CenterCropVideo((oheight, owidth)),
]
)(clip)
sum2 = result.sum() sum2 = result.sum()
msg = "height: " + str(height) + " width: " \ msg = (
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg) )
self.assertTrue(sum2.item() > sum1.item(), msg) assert sum2.item() > 1, msg
assert sum2.item() > sum1.item(), msg
@pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
@pytest.mark.parametrize("channels", [1, 3])
def test_normalize_video(self, channels):
    """Normalizing a clip with its own per-channel stats should yield ~N(0, 1) samples."""

    def samples_from_standard_normal(tensor):
        # Kolmogorov-Smirnov test against the standard normal; deliberately loose threshold.
        p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
        return p_value > 0.0001

    random_state = random.getstate()
    random.seed(42)
    numFrames = random.randint(4, 128)
    height = random.randint(32, 256)
    width = random.randint(32, 256)
    mean = random.random()
    std = random.random()
    clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
    # Re-derive the *empirical* per-channel statistics before normalizing.
    mean = [clip[c].mean().item() for c in range(channels)]
    std = [clip[c].std().item() for c in range(channels)]
    normalized = transforms.NormalizeVideo(mean, std)(clip)
    assert samples_from_standard_normal(normalized)
    random.setstate(random_state)
# Checking the optional in-place behaviour # Checking the optional in-place behaviour
...@@ -129,49 +143,36 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -129,49 +143,36 @@ class TestVideoTransforms(unittest.TestCase):
numFrames, height, width = 64, 4, 4 numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo() trans = transforms.ToTensorVideo()
with self.assertRaises(TypeError): with pytest.raises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist()) np_rng = np.random.RandomState(0)
trans(np_rng.rand(numFrames, height, width, 1).tolist())
with pytest.raises(TypeError):
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float)) trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
with self.assertRaises(ValueError): with pytest.raises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8)) trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((height, width, 3), dtype=torch.uint8)) trans(torch.ones((height, width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((width, 3), dtype=torch.uint8)) trans(torch.ones((width, 3), dtype=torch.uint8))
with pytest.raises(ValueError):
trans(torch.ones((3), dtype=torch.uint8)) trans(torch.ones((3), dtype=torch.uint8))
trans.__repr__() trans.__repr__()
@pytest.mark.parametrize("p", (0, 1))
def test_random_horizontal_flip_video(self, p):
    """With p=0 the clip is returned untouched; with p=1 it is always flipped on the last axis."""
    clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
    hclip = clip.flip(-1)

    out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
    if p == 0:
        torch.testing.assert_close(out, clip)
    elif p == 1:
        torch.testing.assert_close(out, hclip)

    # Smoke-check that __repr__ does not raise.
    transforms.RandomHorizontalFlipVideo().__repr__()
# Allow running this test module directly without invoking pytest from the CLI.
if __name__ == "__main__":
    pytest.main([__file__])
from copy import deepcopy
import pytest
import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image
from torchvision import tv_tensors
@pytest.fixture(autouse=True)
def restore_tensor_return_type():
    """Autouse safety net: reset the global return type to "Tensor" after every test."""
    # This is for security, as we should already be restoring the default manually in each test anyway
    # (at least at the time of writing...)
    yield
    tv_tensors.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
    """An Image built from a tensor or a PIL image is a 3-channel CHW tensor."""
    wrapped = tv_tensors.Image(data)
    assert isinstance(wrapped, torch.Tensor)
    assert wrapped.ndim == 3
    assert wrapped.shape[0] == 3
@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
    """A Mask built from a tensor or a grayscale PIL image is a single-channel CHW tensor."""
    wrapped = tv_tensors.Mask(data)
    assert isinstance(wrapped, torch.Tensor)
    assert wrapped.ndim == 3
    assert wrapped.shape[0] == 1
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
    "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
    """BoundingBoxes wraps tensors/lists into an Nx4 tensor and normalizes string formats to enum members."""
    bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
    assert isinstance(bboxes, torch.Tensor)
    assert bboxes.ndim == 2
    assert bboxes.shape[1] == 4
    # String inputs are resolved case-insensitively to the matching enum member.
    expected = tv_tensors.BoundingBoxFormat[format.upper()] if isinstance(format, str) else format
    assert bboxes.format == expected
def test_bbox_dim_error():
    """A 3D payload must be rejected with a descriptive ValueError."""
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        tv_tensors.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        # List payloads carry no grad flag of their own: None defaults to False.
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        # Tensor payloads: None inherits the input's flag, an explicit bool overrides it.
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    """Image(..., requires_grad=...) resolves the grad flag as tabulated above."""
    tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
    assert tv_tensor.requires_grad is expected_requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_isinstance(make_input):
    """Every TVTensor subclass instance is also a plain torch.Tensor."""
    sample = make_input()
    assert isinstance(sample, torch.Tensor)
def test_wrapping_no_copy():
    """Wrapping a tensor as an Image must share storage rather than copy it."""
    payload = torch.rand(3, 16, 16)
    assert tv_tensors.Image(payload).data_ptr() == payload.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
    """.to() keeps the TVTensor subclass while changing the dtype."""
    original = make_input()
    converted = original.to(torch.float64)
    assert type(converted) is type(original)
    assert converted.dtype is torch.float64
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_to_tv_tensor_reference(make_input, return_type):
    """tensor.to(tv_tensor) adopts the reference's dtype; wrapping of the result follows return_type."""
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    dp = make_input()
    with tv_tensors.set_return_type(return_type):
        tensor_to = tensor.to(dp)
    # Only under "TVTensor" does the result take on the reference's subclass.
    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
    assert tensor_to.dtype is dp.dtype
    # The source tensor itself is never re-wrapped.
    assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_clone_wrapping(make_input, return_type):
    """clone() keeps the subclass regardless of return_type and copies storage."""
    original = make_input()
    with tv_tensors.set_return_type(return_type):
        cloned = original.clone()
    assert type(cloned) is type(original)
    assert cloned.data_ptr() != original.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_requires_grad__wrapping(make_input, return_type):
    """requires_grad_() keeps the subclass and mutates the flag in place."""
    dp = make_input(dtype=torch.float)
    assert not dp.requires_grad
    with tv_tensors.set_return_type(return_type):
        dp_requires_grad = dp.requires_grad_(True)
    assert type(dp_requires_grad) is type(dp)
    # In-place op: both the returned handle and the original report the new flag.
    assert dp.requires_grad
    assert dp_requires_grad.requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_detach_wrapping(make_input, return_type):
    """detach() keeps the TVTensor subclass under any return_type setting."""
    original = make_input(dtype=torch.float).requires_grad_(True)
    with tv_tensors.set_return_type(return_type):
        detached = original.detach()
    assert type(detached) is type(original)
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_force_subclass_with_metadata(return_type):
    """Ops in _FORCE_TORCHFUNCTION_SUBCLASS must preserve BoundingBoxes metadata.

    Largely the same as the wrapping tests above, but additionally checks that
    ``format`` and ``canvas_size`` survive clone/to/detach/requires_grad_.
    """
    format, canvas_size = "XYXY", (32, 32)
    bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
    # The constructor normalizes the format string to the enum member; compare against that.
    expected_meta = (tv_tensors.BoundingBoxFormat[format], canvas_size)

    tv_tensors.set_return_type(return_type)
    bbox = bbox.clone()
    if return_type == "TVTensor":
        # BUG FIX: the original `assert bbox.format, bbox.canvas_size == (...)` parsed as
        # `assert bbox.format` with the comparison as a (never evaluated) assert message,
        # so the metadata was never actually checked.
        assert (bbox.format, bbox.canvas_size) == expected_meta

    bbox = bbox.to(torch.float64)
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == expected_meta

    bbox = bbox.detach()
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == expected_meta

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    if return_type == "TVTensor":
        assert (bbox.format, bbox.canvas_size) == expected_meta
        assert bbox.requires_grad

    tv_tensors.set_return_type("tensor")
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_other_op_no_wrapping(make_input, return_type):
    """Ordinary ops unwrap to plain Tensor unless return_type is "TVTensor"."""
    dp = make_input()
    with tv_tensors.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        result = dp * 2
    expected_cls = type(dp) if return_type == "TVTensor" else torch.Tensor
    assert type(result) is expected_cls
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
    """Ops whose result is not a Tensor must not come back as the TVTensor subclass."""
    dp = make_input()
    assert type(op(dp)) is not type(dp)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
    """In-place ops return per return_type but never change the input's class."""
    dp = make_input()
    cls_before = type(dp)
    with tv_tensors.set_return_type(return_type):
        result = dp.add_(0)
    assert type(result) is (cls_before if return_type == "TVTensor" else torch.Tensor)
    # The mutated input keeps its original subclass.
    assert type(dp) is cls_before
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_wrap(make_input):
    """tv_tensors.wrap re-wraps a plain result as the reference's subclass without copying."""
    reference = make_input()
    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    plain = reference * 2
    rewrapped = tv_tensors.wrap(plain, like=reference)
    assert type(rewrapped) is type(reference)
    assert rewrapped.data_ptr() == plain.data_ptr()
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
    """deepcopy yields an equal, storage-independent copy that keeps class and grad flag."""
    dp = make_input(dtype=torch.float)
    dp.requires_grad_(requires_grad)

    copied = deepcopy(dp)
    assert copied is not dp
    assert copied.data_ptr() != dp.data_ptr()
    assert_equal(copied, dp)
    assert type(copied) is type(dp)
    assert copied.requires_grad is requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
@pytest.mark.parametrize(
    "op",
    (
        # binary arithmetic with tensors and scalars, in both operand orders
        lambda dp: dp + torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) + dp,
        lambda dp: dp * torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) * dp,
        lambda dp: dp + 3,
        lambda dp: 3 + dp,
        lambda dp: dp + dp,
        # reductions, reshapes, dtype casts
        lambda dp: dp.sum(),
        lambda dp: dp.reshape(-1),
        lambda dp: dp.int(),
        # multi-tensor ops
        lambda dp: torch.stack([dp, dp]),
        lambda dp: torch.chunk(dp, 2)[0],
        lambda dp: torch.unbind(dp)[0],
    ),
)
def test_usual_operations(make_input, return_type, op):
    """Common tensor ops wrap per return_type; BoundingBoxes keep their metadata attributes."""
    dp = make_input()
    with tv_tensors.set_return_type(return_type):
        out = op(dp)
        assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
        if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
            # Only the presence of the metadata is checked; values may change with the op.
            assert hasattr(out, "format")
            assert hasattr(out, "canvas_size")
def test_subclasses():
    """Mixing different TVTensor subclasses in one op is rejected."""
    with pytest.raises(TypeError, match="unsupported operand"):
        make_image() + make_segmentation_mask()
def test_set_return_type():
    """Exercise set_return_type as both a global setter and a context manager."""
    img = make_image()

    # Default: plain ops unwrap to torch.Tensor.
    assert type(img + 3) is torch.Tensor

    # As a context manager: TVTensor inside, restored on exit.
    with tv_tensors.set_return_type("TVTensor"):
        assert type(img + 3) is tv_tensors.Image
    assert type(img + 3) is torch.Tensor

    # As a global setter: sticks until changed.
    tv_tensors.set_return_type("TVTensor")
    assert type(img + 3) is tv_tensors.Image

    with tv_tensors.set_return_type("tensor"):
        assert type(img + 3) is torch.Tensor
        with tv_tensors.set_return_type("TVTensor"):
            assert type(img + 3) is tv_tensors.Image
            # Global call inside the inner context overrides it immediately...
            tv_tensors.set_return_type("tensor")
            assert type(img + 3) is torch.Tensor
        assert type(img + 3) is torch.Tensor

    # Exiting a context manager will restore the return type as it was prior to entering it,
    # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
    assert type(img + 3) is tv_tensors.Image

    tv_tensors.set_return_type("tensor")
def test_return_type_input():
    """set_return_type accepts any casing of "TVTensor" and rejects unknown names."""
    img = make_image()

    # Case-insensitive
    with tv_tensors.set_return_type("tvtensor"):
        assert type(img + 3) is tv_tensors.Image

    with pytest.raises(ValueError, match="return_type must be"):
        tv_tensors.set_return_type("typo")

    tv_tensors.set_return_type("tensor")
import pytest
import numpy as np
import os import os
import re
import sys import sys
import tempfile import tempfile
from io import BytesIO
import numpy as np
import pytest
import torch import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
from _assert_utils import assert_equal
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
def test_make_grid_not_inplace(): def test_make_grid_not_inplace():
...@@ -23,13 +26,13 @@ def test_make_grid_not_inplace(): ...@@ -23,13 +26,13 @@ def test_make_grid_not_inplace():
t_clone = t.clone() t_clone = t.clone()
utils.make_grid(t, normalize=False) utils.make_grid(t, normalize=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=False) utils.make_grid(t, normalize=True, scale_each=False)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
utils.make_grid(t, normalize=True, scale_each=True) utils.make_grid(t, normalize=True, scale_each=True)
assert_equal(t, t_clone, msg='make_grid modified tensor in-place') assert_equal(t, t_clone, msg="make_grid modified tensor in-place")
def test_normalize_in_make_grid(): def test_normalize_in_make_grid():
...@@ -43,51 +46,51 @@ def test_normalize_in_make_grid(): ...@@ -43,51 +46,51 @@ def test_normalize_in_make_grid():
# Rounding the result to one decimal for comparison # Rounding the result to one decimal for comparison
n_digits = 1 n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_max = torch.round(grid_max * 10**n_digits) / (10**n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10**n_digits) / (10**n_digits)
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1') assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1")
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0') assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image(): def test_save_image():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64) t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The image is not present after save' assert os.path.exists(f.name), "The image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel(): def test_save_image_single_pixel():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1) t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name) utils.save_image(t, f.name)
assert os.path.exists(f.name), 'The pixel image is not present after save' assert os.path.exists(f.name), "The pixel image is not present after save"
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object(): def test_save_image_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(2, 3, 64, 64) t = torch.rand(2, 3, 64, 64)
utils.save_image(t, f.name) utils.save_image(t, f.name)
img_orig = Image.open(f.name) img_orig = Image.open(f.name)
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object(): def test_save_image_single_pixel_file_object():
with tempfile.NamedTemporaryFile(suffix='.png') as f: with tempfile.NamedTemporaryFile(suffix=".png") as f:
t = torch.rand(1, 3, 1, 1) t = torch.rand(1, 3, 1, 1)
utils.save_image(t, f.name) utils.save_image(t, f.name)
img_orig = Image.open(f.name) img_orig = Image.open(f.name)
fp = BytesIO() fp = BytesIO()
utils.save_image(t, fp, format='png') utils.save_image(t, fp, format="png")
img_bytes = Image.open(fp) img_bytes = Image.open(fp)
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
def test_draw_boxes(): def test_draw_boxes():
...@@ -103,7 +106,7 @@ def test_draw_boxes(): ...@@ -103,7 +106,7 @@ def test_draw_boxes():
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path) res.save(path)
if PILLOW_VERSION >= (8, 2): if PILLOW_VERSION >= (10, 1):
# The reference image is only valid for new PIL versions # The reference image is only valid for new PIL versions
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected) assert_equal(result, expected)
...@@ -113,11 +116,37 @@ def test_draw_boxes(): ...@@ -113,11 +116,37 @@ def test_draw_boxes():
assert_equal(img, img_cp) assert_equal(img, img_cp)
@pytest.mark.parametrize("fill", [True, False])
def test_draw_boxes_dtypes(fill):
    """Drawing works on uint8 and float inputs, never in place, and both agree within one LSB."""
    base_u8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    drawn_u8 = utils.draw_bounding_boxes(base_u8, boxes, fill=fill)
    assert drawn_u8 is not base_u8
    assert drawn_u8.dtype == torch.uint8

    base_f = to_dtype(base_u8, torch.float, scale=True)
    drawn_f = utils.draw_bounding_boxes(base_f, boxes, fill=fill)
    assert drawn_f is not base_f
    assert drawn_f.is_floating_point()

    # Both dtypes should produce the same pixels up to rounding.
    torch.testing.assert_close(drawn_u8, to_dtype(drawn_f, torch.uint8, scale=True), rtol=0, atol=1)
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
    """Every supported color spec is accepted; an empty color list is rejected."""
    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    utils.draw_bounding_boxes(canvas, boxes, fill=False, width=7, colors=colors)

    with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
        utils.draw_bounding_boxes(image=canvas, boxes=boxes, colors=[])
def test_draw_boxes_vanilla(): def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone() img_cp = img.clone()
boxes_cp = boxes.clone() boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7) result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path): if not os.path.exists(path):
...@@ -131,39 +160,75 @@ def test_draw_boxes_vanilla(): ...@@ -131,39 +160,75 @@ def test_draw_boxes_vanilla():
assert_equal(img, img_cp) assert_equal(img, img_cp)
def test_draw_boxes_grayscale():
    """A single-channel image is promoted to a 3-channel result by draw_bounding_boxes."""
    gray = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
    one_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
    drawn = utils.draw_bounding_boxes(image=gray, boxes=one_box, colors=["#1BBC9B"])
    assert drawn.size(0) == 3
def test_draw_invalid_boxes():
    """Each invalid input to draw_bounding_boxes raises its documented error."""
    img_tp = ((1, 1, 1), (1, 2, 3))
    img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
    img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
    boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
    # x2 < x1 / y2 < y1 violate the xyxy convention.
    boxes_wrong = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)
    labels_wrong = ["one", "two"]
    colors_wrong = ["pink", "blue"]

    with pytest.raises(TypeError, match="Tensor expected"):
        utils.draw_bounding_boxes(img_tp, boxes)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        utils.draw_bounding_boxes(img_wrong2, boxes)
    with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
        utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
    with pytest.raises(ValueError, match="Number of boxes"):
        utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
    with pytest.raises(ValueError, match="Number of colors"):
        utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
    with pytest.raises(ValueError, match="Boxes need to be in"):
        utils.draw_bounding_boxes(img_correct, boxes_wrong)
def test_draw_boxes_warning():
    """Passing font_size without a font must emit a UserWarning."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    expected = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")
    with pytest.warns(UserWarning, match=expected):
        utils.draw_bounding_boxes(canvas, boxes, font_size=11)
@pytest.mark.parametrize('colors', [ def test_draw_no_boxes():
None, img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
['red', 'blue'], boxes = torch.full((0, 4), 0, dtype=torch.float)
['#FF00FF', (1, 34, 122)], with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")):
]) res = utils.draw_bounding_boxes(img, boxes)
@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) # Check that the function didn't change the image
def test_draw_segmentation_masks(colors, alpha): assert res.eq(img).all()
@pytest.mark.parametrize(
"colors",
[
None,
"blue",
"#FF00FF",
(1, 34, 122),
["red", "blue"],
["#FF00FF", (1, 34, 122)],
],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
"""This test makes sure that masks draw their corresponding color where they should""" """This test makes sure that masks draw their corresponding color where they should"""
num_masks, h, w = 2, 100, 100 num_masks, h, w = 2, 100, 100
dtype = torch.uint8 dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
masks[0, 10:20, 10:20] = True
masks[1, 15:25, 15:25] = True
# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when
# masks overlap, but this makes testing slightly harder so we don't really
# care
overlap = masks[0] & masks[1] overlap = masks[0] & masks[1]
masks[:, overlap] = False
out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha) out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype assert out.dtype == dtype
...@@ -175,27 +240,53 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -175,27 +240,53 @@ def test_draw_segmentation_masks(colors, alpha):
if colors is None: if colors is None:
colors = utils._generate_color_palette(num_masks) colors = utils._generate_color_palette(num_masks)
elif isinstance(colors, str) or isinstance(colors, tuple):
colors = [colors]
# Make sure each mask draws with its own color # Make sure each mask draws with its own color
for mask, color in zip(masks, colors): for mask, color in zip(masks, colors):
if isinstance(color, str): if isinstance(color, str):
color = ImageColor.getrgb(color) color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=dtype) color = torch.tensor(color, dtype=dtype, device=device)
if alpha == 1: if alpha == 1:
assert (out[:, mask] == color[:, None]).all() assert (out[:, mask & ~overlap] == color[:, None]).all()
elif alpha == 0: elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all() assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)
interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_dtypes():
num_masks, h, w = 2, 100, 100
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)
assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8
img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_segmentation_masks(img_float, masks)
assert img_float is not out_float
assert out_float.is_floating_point()
interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype) torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_errors(): @pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
h, w = 10, 10 h, w = 10, 10
masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool) masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)
with pytest.raises(TypeError, match="The image must be a tensor"): with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks) utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
...@@ -217,15 +308,236 @@ def test_draw_segmentation_masks_errors(): ...@@ -217,15 +308,236 @@ def test_draw_segmentation_masks_errors():
with pytest.raises(ValueError, match="must have the same height and width"): with pytest.raises(ValueError, match="must have the same height and width"):
masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool) masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
with pytest.raises(ValueError, match="There are more masks"): with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
bad_colors = np.array(['red', 'blue']) # should be a list bad_colors = np.array(["red", "blue"]) # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):
bad_colors = ('red', 'blue') # should be a list bad_colors = ("red", "blue") # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_no_segmention_mask(device):
    """Drawing an empty mask set must warn and leave the image untouched.

    NOTE: the typo in the function name ("segmention") is kept — it is the
    public test id used for selection/deselection.
    """
    image = torch.zeros((3, 100, 100), dtype=torch.uint8, device=device)
    empty_masks = torch.zeros((0, 100, 100), dtype=torch.bool, device=device)
    with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
        out = utils.draw_segmentation_masks(image, empty_masks)
    # The input image must come back pixel-for-pixel identical.
    assert torch.equal(out, image)
def test_draw_keypoints_vanilla():
    """Drawing keypoints must match the checked-in fakedata image and not mutate its inputs."""
    # `keypoints` is a module-level fixture tensor; snapshot inputs to detect in-place edits.
    keypoints_before = keypoints.clone()
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    result = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=[(0, 1)])

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
    # Regenerate the reference image on first run so the comparison below has a baseline.
    if not os.path.exists(path):
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)
    # Neither the keypoints nor the image may be modified in place.
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)
def test_draw_keypoins_K_equals_one():
    # Non-regression test for https://github.com/pytorch/vision/pull/8439
    # (name typo "keypoins" kept: it is the public test id).
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    single_keypoint = torch.tensor([[[10, 10]]], dtype=torch.float)
    utils.draw_keypoints(canvas, single_keypoint)
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
    """draw_keypoints must accept named, hex, and RGB-tuple colors without mutating inputs."""
    # `keypoints` is a module-level fixture tensor; snapshot inputs to detect in-place edits.
    keypoints_before = keypoints.clone()
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    result = utils.draw_keypoints(canvas, keypoints, colors=colors, connectivity=[(0, 1)])

    # Output stays a 3-channel image and no input was modified in place.
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)
@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
    "vis",
    [
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
        torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
    ],
)
def test_draw_keypoints_visibility(connectivity, vis):
    """Keypoints hidden by `visibility` must not be drawn, and no input may be mutated."""
    # `keypoints` is a module-level fixture tensor; snapshot everything that
    # draw_keypoints receives so in-place modification can be detected below.
    keypoints_before = keypoints.clone()
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()
    vis_before = vis if vis is None else vis.clone()

    result = utils.draw_keypoints(
        image=canvas,
        keypoints=keypoints,
        connectivity=connectivity,
        colors="red",
        visibility=vis,
    )
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)

    # compare with a fakedata image
    # connect the key points 0 to 1 for both skeletons and do not show the other key points
    path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png"
    )
    # Regenerate the reference image on first run so the comparison has a baseline.
    if not os.path.exists(path):
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)
    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)

    if vis_before is None:
        assert vis is None
    else:
        assert_equal(vis, vis_before)
        assert vis.dtype == vis_before.dtype
def test_draw_keypoints_visibility_default():
    """With visibility=None every keypoint is drawn; output matches the vanilla fakedata image."""
    # `keypoints` is a module-level fixture tensor; snapshot inputs to detect in-place edits.
    keypoints_before = keypoints.clone()
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)
    canvas_before = canvas.clone()

    result = utils.draw_keypoints(
        image=canvas,
        keypoints=keypoints,
        connectivity=[(0, 1)],
        colors="red",
        visibility=None,
    )
    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)

    # compare against fakedata image, which connects 0->1 for both key-point skeletons
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)
def test_draw_keypoints_dtypes():
    """uint8 and float inputs must give a matching drawing and never alias the input tensor."""
    image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
    image_float = to_dtype(image_uint8, torch.float, scale=True)

    out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
    out_float = utils.draw_keypoints(image_float, keypoints)

    # Each call returns a fresh tensor in the same dtype family as its input.
    assert out_uint8 is not image_uint8
    assert out_float is not image_float
    assert out_uint8.dtype == torch.uint8
    assert out_float.is_floating_point()

    # Allow one quantization step of difference between the uint8 and float paths.
    torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
def test_draw_keypoints_errors():
    """Every invalid input combination must raise with a descriptive message."""
    h, w = 10, 10
    canvas = torch.zeros((3, 100, 100), dtype=torch.uint8)

    # --- invalid image arguments ---
    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)
    with pytest.raises(ValueError, match="The image dtype must be"):
        wrong_dtype = torch.full((3, h, w), 0, dtype=torch.int64)
        utils.draw_keypoints(image=wrong_dtype, keypoints=keypoints)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        batched = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
        utils.draw_keypoints(image=batched, keypoints=keypoints)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        grayscale = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
        utils.draw_keypoints(image=grayscale, keypoints=keypoints)

    # --- malformed keypoints / visibility tensors ---
    with pytest.raises(ValueError, match="keypoints must be of shape"):
        bad_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
        utils.draw_keypoints(image=canvas, keypoints=bad_keypoints)
    with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
        flat_visibility = torch.tensor([True, True, True], dtype=torch.bool)
        utils.draw_keypoints(image=canvas, keypoints=keypoints, visibility=flat_visibility)
    with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
        deep_visibility = torch.ones((2, 3, 4), dtype=torch.bool)
        utils.draw_keypoints(image=canvas, keypoints=keypoints, visibility=deep_visibility)
    with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
        visibility_bad_n = torch.ones((3, 3), dtype=torch.bool)
        utils.draw_keypoints(image=canvas, keypoints=keypoints, visibility=visibility_bad_n)
    with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
        visibility_bad_k = torch.ones((2, 4), dtype=torch.bool)
        utils.draw_keypoints(image=canvas, keypoints=keypoints, visibility=visibility_bad_k)
@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """flow_to_image must produce the expected RGB image for a synthetic flow field.

    Covers both the single-flow (C, H, W) and batched (N, C, H, W) input shapes
    and compares against the stored ``expected_flow.pt`` asset.
    """
    h, w = 100, 100
    # Build a flow field whose displacement is each pixel's offset from the image center.
    flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    flow = torch.stack(flow[::-1], dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # BUGFIX: the original `assert img.shape == (2, 3, h, w) if batch else (3, h, w)`
    # parsed as `assert (...) if batch else (3, h, w)` — the non-batched branch
    # "asserted" a truthy tuple and never checked the shape. Parenthesizing the
    # conditional expression makes both branches real shape checks.
    assert img.shape == ((2, 3, h, w) if batch else (3, h, w))

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu", weights_only=True)

    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)
# Each case pairs an invalid flow tensor with the start of the expected error
# message: the first four have the wrong channel count and/or rank, the last
# has a non-float dtype.
@pytest.mark.parametrize(
    "input_flow, match",
    (
        (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
    ),
)
def test_flow_to_image_errors(input_flow, match):
    """flow_to_image must reject flows with an invalid shape or dtype."""
    with pytest.raises(ValueError, match=match):
        utils.flow_to_image(flow=input_flow)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
import math
import os
import pytest
import torch
import torchvision
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
try:
import av
except ImportError:
av = None
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
@pytest.mark.skipif(_HAS_GPU_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
    """Validates the CUDA video decoder against PyAV's CPU decoding of the same files."""

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize(
        "video_file",
        [
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
            "v_SoccerJuggling_g23_c01.avi",
            "v_SoccerJuggling_g24_c01.avi",
            "R6llTwEh07w.mp4",
            "SOX5yA1l24A.mp4",
            "WUzgd7C1pWA.mp4",
        ],
    )
    def test_frame_reading(self, video_file):
        """Every GPU-decoded frame must be close (mean abs diff < 0.75) to PyAV's RGB frame."""
        torchvision.set_video_backend("cuda")
        full_path = os.path.join(VIDEO_DIR, video_file)
        decoder = VideoReader(full_path)
        with av.open(full_path) as container:
            # Iterate PyAV and the GPU decoder in lockstep, one frame at a time.
            for av_frame in container.decode(container.streams.video[0]):
                # PyAV reference frame converted to RGB with the same colorspace
                # the comparison expects.
                av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
                vision_frames = next(decoder)["data"]
                mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
                assert mean_delta < 0.75

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("keyframes", [True, False])
    @pytest.mark.parametrize(
        "full_path, duration",
        [
            (os.path.join(VIDEO_DIR, x), y)
            for x, y in [
                ("v_SoccerJuggling_g23_c01.avi", 8.0),
                ("v_SoccerJuggling_g24_c01.avi", 8.0),
                ("R6llTwEh07w.mp4", 10.0),
                ("SOX5yA1l24A.mp4", 11.0),
                ("WUzgd7C1pWA.mp4", 11.0),
            ]
        ],
    )
    def test_seek_reading(self, keyframes, full_path, duration):
        """After seeking to the middle of the video, decoded frames must still match PyAV."""
        torchvision.set_video_backend("cuda")
        decoder = VideoReader(full_path)
        time = duration / 2
        decoder.seek(time, keyframes_only=keyframes)
        with av.open(full_path) as container:
            # NOTE(review): the 1e6 factor assumes container.seek() takes its offset
            # in microseconds (time is in seconds) — confirm against PyAV docs.
            # any_frame mirrors keyframes_only: when not restricted to keyframes,
            # seek to the exact frame.
            container.seek(int(time * 1000000), any_frame=not keyframes, backward=False)
            for av_frame in container.decode(container.streams.video[0]):
                av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
                vision_frames = next(decoder)["data"]
                mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
                assert mean_delta < 0.75

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize(
        "video_file",
        [
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
            "v_SoccerJuggling_g23_c01.avi",
            "v_SoccerJuggling_g24_c01.avi",
            "R6llTwEh07w.mp4",
            "SOX5yA1l24A.mp4",
            "WUzgd7C1pWA.mp4",
        ],
    )
    def test_metadata(self, video_file):
        """Duration and fps reported by the GPU decoder must agree with PyAV within 1%."""
        torchvision.set_video_backend("cuda")
        full_path = os.path.join(VIDEO_DIR, video_file)
        decoder = VideoReader(full_path)
        video_metadata = decoder.get_metadata()["video"]
        with av.open(full_path) as container:
            video = container.streams.video[0]
            # Stream duration is in time_base units; convert to seconds for comparison.
            av_duration = float(video.duration * video.time_base)
            assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2)
            assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2)
if __name__ == "__main__":
pytest.main([__file__])
import collections import collections
import math import math
import os import os
import time
import unittest
from fractions import Fraction from fractions import Fraction
import numpy as np import numpy as np
import pytest
import torch import torch
import torchvision.io as io import torchvision.io as io
from common_utils import assert_equal
from numpy.random import randint from numpy.random import randint
from pytest import approx
from torchvision import set_video_backend
from torchvision.io import _HAS_VIDEO_OPT from torchvision.io import _HAS_VIDEO_OPT
from common_utils import PY39_SKIP
from _assert_utils import assert_equal
try: try:
...@@ -23,9 +23,6 @@ except ImportError: ...@@ -23,9 +23,6 @@ except ImportError:
av = None av = None
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [ CheckerConfig = [
...@@ -110,18 +107,14 @@ test_videos = { ...@@ -110,18 +107,14 @@ test_videos = {
} }
DecoderResult = collections.namedtuple( DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase")
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
"""av_seek_frame is imprecise so seek to a timestamp earlier by a margin # av_seek_frame is imprecise so seek to a timestamp earlier by a margin
The unit of margin is second""" # The unit of margin is second
seek_frame_margin = 0.25 SEEK_FRAME_MARGIN = 0.25
def _read_from_stream( def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
""" """
Args: Args:
container: pyav container container: pyav container
...@@ -134,7 +127,7 @@ def _read_from_stream( ...@@ -134,7 +127,7 @@ def _read_from_stream(
ascending order. We need to decode more frames even when we meet end ascending order. We need to decode more frames even when we meet end
pts pts
""" """
# seeking in the stream is imprecise. Thus, seek to an ealier PTS by a margin # seeking in the stream is imprecise. Thus, seek to an earlier PTS by a margin
margin = 1 margin = 1
seek_offset = max(start_pts - margin, 0) seek_offset = max(start_pts - margin, 0)
...@@ -233,9 +226,7 @@ def _decode_frames_by_av_module( ...@@ -233,9 +226,7 @@ def _decode_frames_by_av_module(
else: else:
aframes = torch.empty((1, 0), dtype=torch.float32) aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor( aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult( return DecoderResult(
vframes=vframes, vframes=vframes,
...@@ -266,64 +257,64 @@ def _get_video_tensor(video_dir, video_file): ...@@ -266,64 +257,64 @@ def _get_video_tensor(video_dir, video_file):
assert os.path.exists(full_path), "File not found: %s" % full_path assert os.path.exists(full_path), "File not found: %s" % full_path
with open(full_path, "rb") as fp: with open(full_path, "rb") as fp:
video_tensor = torch.from_numpy(np.frombuffer(fp.read(), dtype=np.uint8)) video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8)
return full_path, video_tensor return full_path, video_tensor
@unittest.skipIf(av is None, "PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg") @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoReader(unittest.TestCase): class TestVideoReader:
def check_separate_decoding_result(self, tv_result, config): def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder """check the decoding results from TorchVision decoder"""
""" (
vframes, vframe_pts, vtimebase, vfps, vduration, \ vframes,
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vframe_pts,
tv_result vtimebase,
) vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
assert video_duration == approx(config.duration, abs=0.5)
assert vfps.item() == approx(config.video_fps, abs=0.5)
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate) assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction( audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
atimebase[0].item(), atimebase[1].item() assert audio_duration == approx(config.duration, abs=0.5)
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
# check if pts of video frames are sorted in ascending order # check if pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1): for i in range(len(vframe_pts) - 1):
self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True) assert vframe_pts[i] < vframe_pts[i + 1]
if len(aframe_pts) > 1: if len(aframe_pts) > 1:
# check if pts of audio frames are sorted in ascending order # check if pts of audio frames are sorted in ascending order
for i in range(len(aframe_pts) - 1): for i in range(len(aframe_pts) - 1):
self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True) assert aframe_pts[i] < aframe_pts[i + 1]
def check_probe_result(self, result, config): def check_probe_result(self, result, config):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
video_duration = vduration.item() * Fraction( video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
vtimebase[0].item(), vtimebase[1].item() assert video_duration == approx(config.duration, abs=0.5)
) assert vfps.item() == approx(config.video_fps, abs=0.5)
self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0: if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate) assert asample_rate.item() == config.audio_sample_rate
audio_duration = aduration.item() * Fraction( audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
atimebase[0].item(), atimebase[1].item() assert audio_duration == approx(config.duration, abs=0.5)
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
def check_meta_result(self, result, config): def check_meta_result(self, result, config):
self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5) assert result.video_duration == approx(config.duration, abs=0.5)
self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5) assert result.video_fps == approx(config.video_fps, abs=0.5)
if result.has_audio > 0: if result.has_audio > 0:
self.assertEqual(result.audio_sample_rate, config.audio_sample_rate) assert result.audio_sample_rate == config.audio_sample_rate
self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5) assert result.audio_duration == approx(config.duration, abs=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config): def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
""" """
...@@ -334,10 +325,18 @@ class TestVideoReader(unittest.TestCase): ...@@ -334,10 +325,18 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1 decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker config: config of decoding results checker
""" """
vframes, vframe_pts, vtimebase, _vfps, _vduration, \ (
aframes, aframe_pts, atimebase, _asample_rate, _aduration = ( vframes,
tv_result vframe_pts,
) vtimebase,
_vfps,
_vduration,
aframes,
aframe_pts,
atimebase,
_asample_rate,
_aduration,
) = tv_result
if isinstance(ref_result, list): if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder # the ref_result is from new video_reader decoder
ref_result = DecoderResult( ref_result = DecoderResult(
...@@ -350,43 +349,32 @@ class TestVideoReader(unittest.TestCase): ...@@ -350,43 +349,32 @@ class TestVideoReader(unittest.TestCase):
) )
if vframes.numel() > 0 and ref_result.vframes.numel() > 0: if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean( mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
torch.abs(vframes.float() - ref_result.vframes.float()) assert mean_delta == approx(0.0, abs=8.0)
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean( mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()) assert mean_delta == approx(0.0, abs=1.0)
)
self.assertAlmostEqual(mean_delta, 0, delta=1.0)
assert_equal(vtimebase, ref_result.vtimebase) assert_equal(vtimebase, ref_result.vtimebase)
if ( if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return """Audio stream is available and audio frame is required to return
from decoder""" from decoder"""
assert_equal(aframes, ref_result.aframes) assert_equal(aframes, ref_result.aframes)
if ( if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
"""Audio stream is available""" """Audio stream is available"""
assert_equal(aframe_pts, ref_result.aframe_pts) assert_equal(aframe_pts, ref_result.aframe_pts)
assert_equal(atimebase, ref_result.atimebase) assert_equal(atimebase, ref_result.atimebase)
@unittest.skip( @pytest.mark.parametrize("test_video", test_videos.keys())
"This stress test will iteratively decode the same set of videos." def test_stress_test_read_video_from_file(self, test_video):
"It helps to detect memory leak but it takes lots of time to run." pytest.skip(
"By default, it is disabled" "This stress test will iteratively decode the same set of videos."
) "It helps to detect memory leak but it takes lots of time to run."
def test_stress_test_read_video_from_file(self): "By default, it is disabled"
)
num_iter = 10000 num_iter = 10000
# video related # video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0 width, height, min_dimension, max_dimension = 0, 0, 0, 0
...@@ -398,53 +386,12 @@ class TestVideoReader(unittest.TestCase): ...@@ -398,53 +386,12 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for _i in range(num_iter): for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
@PY39_SKIP
def test_read_video_from_file(self):
"""
Test the case when decoder starts with a video file to decode frames.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder # pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file( torch.ops.video_reader.read_video_from_file(
full_path, full_path,
seek_frame_margin, SEEK_FRAME_MARGIN,
0, # getPtsOnly 0, # getPtsOnly
1, # readVideoStream 1, # readVideoStream
width, width,
...@@ -463,15 +410,57 @@ class TestVideoReader(unittest.TestCase): ...@@ -463,15 +410,57 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
@PY39_SKIP @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_file_read_single_stream_only(self): def test_read_video_from_file(self, test_video, config):
"""
Test the case when decoder starts with a video file to decode frames.
"""
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
def test_read_video_from_file_read_single_stream_only(
self, test_video, config, read_video_stream, read_audio_stream
):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
only reads video stream and ignores audio stream only reads video stream and ignores audio stream
...@@ -485,51 +474,56 @@ class TestVideoReader(unittest.TestCase): ...@@ -485,51 +474,56 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video) # decode all frames using new decoder
for readVideoStream, readAudioStream in [(1, 0), (0, 1)]: tv_result = torch.ops.video_reader.read_video_from_file(
# decode all frames using new decoder full_path,
tv_result = torch.ops.video_reader.read_video_from_file( SEEK_FRAME_MARGIN,
full_path, 0, # getPtsOnly
seek_frame_margin, read_video_stream,
0, # getPtsOnly width,
readVideoStream, height,
width, min_dimension,
height, max_dimension,
min_dimension, video_start_pts,
max_dimension, video_end_pts,
video_start_pts, video_timebase_num,
video_end_pts, video_timebase_den,
video_timebase_num, read_audio_stream,
video_timebase_den, samples,
readAudioStream, channels,
samples, audio_start_pts,
channels, audio_end_pts,
audio_start_pts, audio_timebase_num,
audio_end_pts, audio_timebase_den,
audio_timebase_num, )
audio_timebase_den,
) (
vframes,
vframes, vframe_pts, vtimebase, vfps, vduration, \ vframe_pts,
aframes, aframe_pts, atimebase, asample_rate, aduration = ( vtimebase,
tv_result vfps,
) vduration,
aframes,
self.assertEqual(vframes.numel() > 0, readVideoStream) aframe_pts,
self.assertEqual(vframe_pts.numel() > 0, readVideoStream) atimebase,
self.assertEqual(vtimebase.numel() > 0, readVideoStream) asample_rate,
self.assertEqual(vfps.numel() > 0, readVideoStream) aduration,
) = tv_result
expect_audio_data = (
readAudioStream == 1 and config.audio_sample_rate is not None assert (vframes.numel() > 0) is bool(read_video_stream)
) assert (vframe_pts.numel() > 0) is bool(read_video_stream)
self.assertEqual(aframes.numel() > 0, expect_audio_data) assert (vtimebase.numel() > 0) is bool(read_video_stream)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data) assert (vfps.numel() > 0) is bool(read_video_stream)
self.assertEqual(atimebase.numel() > 0, expect_audio_data)
self.assertEqual(asample_rate.numel() > 0, expect_audio_data) expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
assert (aframes.numel() > 0) is bool(expect_audio_data)
def test_read_video_from_file_rescale_min_dimension(self): assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
assert (atimebase.numel() > 0) is bool(expect_audio_data)
assert (asample_rate.numel() > 0) is bool(expect_audio_data)
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_min_dimension(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set. video min dimension between height and width is set.
...@@ -543,35 +537,33 @@ class TestVideoReader(unittest.TestCase): ...@@ -543,35 +537,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_max_dimension(self): @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_max_dimension(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set. video min dimension between height and width is set.
...@@ -585,35 +577,33 @@ class TestVideoReader(unittest.TestCase): ...@@ -585,35 +577,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_both_min_max_dimension(self): @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set. video min dimension between height and width is set.
...@@ -627,38 +617,34 @@ class TestVideoReader(unittest.TestCase): ...@@ -627,38 +617,34 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
self.assertEqual( assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
self.assertEqual(
max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_width(self): @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
video width is set. video width is set.
...@@ -672,33 +658,33 @@ class TestVideoReader(unittest.TestCase): ...@@ -672,33 +658,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert tv_result[0].size(2) == width
self.assertEqual(tv_result[0].size(2), width)
def test_read_video_from_file_rescale_height(self): @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_height(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
video height is set. video height is set.
...@@ -712,33 +698,33 @@ class TestVideoReader(unittest.TestCase): ...@@ -712,33 +698,33 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert tv_result[0].size(1) == height
self.assertEqual(tv_result[0].size(1), height)
def test_read_video_from_file_rescale_width_and_height(self): @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width_and_height(self, test_video):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
both video height and width are set. both video height and width are set.
...@@ -752,95 +738,92 @@ class TestVideoReader(unittest.TestCase): ...@@ -752,95 +738,92 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
tv_result = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert tv_result[0].size(1) == height
self.assertEqual(tv_result[0].size(1), height) assert tv_result[0].size(2) == width
self.assertEqual(tv_result[0].size(2), width)
@PY39_SKIP @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_audio_resampling(self): @pytest.mark.parametrize("samples", [9600, 96000])
def test_read_video_from_file_audio_resampling(self, test_video, samples):
""" """
Test the case when decoder starts with a video file to decode frames, and Test the case when decoder starts with a video file to decode frames, and
audio waveform are resampled audio waveform are resampled
""" """
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
channels = 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for samples in [9600, 96000]: # downsampling # upsampling full_path = os.path.join(VIDEO_DIR, test_video)
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0 tv_result = torch.ops.video_reader.read_video_from_file(
video_start_pts, video_end_pts = 0, -1 full_path,
video_timebase_num, video_timebase_den = 0, 1 SEEK_FRAME_MARGIN,
# audio related 0, # getPtsOnly
channels = 0 1, # readVideoStream
audio_start_pts, audio_end_pts = 0, -1 width,
audio_timebase_num, audio_timebase_den = 0, 1 height,
min_dimension,
for test_video, _config in test_videos.items(): max_dimension,
full_path = os.path.join(VIDEO_DIR, test_video) video_start_pts,
video_end_pts,
tv_result = torch.ops.video_reader.read_video_from_file( video_timebase_num,
full_path, video_timebase_den,
seek_frame_margin, 1, # readAudioStream
0, # getPtsOnly samples,
1, # readVideoStream channels,
width, audio_start_pts,
height, audio_end_pts,
min_dimension, audio_timebase_num,
max_dimension, audio_timebase_den,
video_start_pts, )
video_end_pts, (
video_timebase_num, vframes,
video_timebase_den, vframe_pts,
1, # readAudioStream vtimebase,
samples, vfps,
channels, vduration,
audio_start_pts, aframes,
audio_end_pts, aframe_pts,
audio_timebase_num, atimebase,
audio_timebase_den, asample_rate,
) aduration,
vframes, vframe_pts, vtimebase, vfps, vduration, \ ) = tv_result
aframes, aframe_pts, atimebase, asample_rate, aduration = ( if aframes.numel() > 0:
tv_result assert samples == asample_rate.item()
) assert 1 == aframes.size(1)
if aframes.numel() > 0: # when audio stream is found
self.assertEqual(samples, asample_rate.item()) duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertEqual(1, aframes.size(1)) assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
# when audio stream is found
duration = ( @pytest.mark.parametrize("test_video,config", test_videos.items())
float(aframe_pts[-1]) def test_compare_read_video_from_memory_and_file(self, test_video, config):
* float(atimebase[0])
/ float(atimebase[1])
)
self.assertAlmostEqual(
aframes.size(0),
int(duration * asample_rate.item()),
delta=0.1 * asample_rate.item(),
)
@PY39_SKIP
def test_compare_read_video_from_memory_and_file(self):
""" """
Test the case when video is already in memory, and decoder reads data in memory Test the case when video is already in memory, and decoder reads data in memory
""" """
...@@ -853,61 +836,60 @@ class TestVideoReader(unittest.TestCase): ...@@ -853,61 +836,60 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
# pass 1: decode all frames using cpp decoder tv_result_memory = torch.ops.video_reader.read_video_from_memory(
tv_result_memory = torch.ops.video_reader.read_video_from_memory( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) self.check_separate_decoding_result(tv_result_memory, config)
self.check_separate_decoding_result(tv_result_memory, config) # pass 2: decode all frames from file
# pass 2: decode all frames from file tv_result_file = torch.ops.video_reader.read_video_from_file(
tv_result_file = torch.ops.video_reader.read_video_from_file( full_path,
full_path, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
)
self.check_separate_decoding_result(tv_result_file, config) self.check_separate_decoding_result(tv_result_file, config)
# finally, compare results decoded from memory and file # finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file) self.compare_decoding_result(tv_result_memory, tv_result_file)
@PY39_SKIP @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory(self): def test_read_video_from_memory(self, test_video, config):
""" """
Test the case when video is already in memory, and decoder reads data in memory Test the case when video is already in memory, and decoder reads data in memory
""" """
...@@ -920,39 +902,38 @@ class TestVideoReader(unittest.TestCase): ...@@ -920,39 +902,38 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
# pass 1: decode all frames using cpp decoder tv_result = torch.ops.video_reader.read_video_from_memory(
tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) # pass 2: decode all frames using av
# pass 2: decode all frames using av pyav_result = _decode_frames_by_av_module(full_path)
pyav_result = _decode_frames_by_av_module(full_path)
self.check_separate_decoding_result(tv_result, config) self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config) self.compare_decoding_result(tv_result, pyav_result, config)
@PY39_SKIP @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory_get_pts_only(self): def test_read_video_from_memory_get_pts_only(self, test_video, config):
""" """
Test the case when video is already in memory, and decoder reads data in memory. Test the case when video is already in memory, and decoder reads data in memory.
Compare frame pts between decoding for pts only and full decoding Compare frame pts between decoding for pts only and full decoding
...@@ -967,238 +948,234 @@ class TestVideoReader(unittest.TestCase): ...@@ -967,238 +948,234 @@ class TestVideoReader(unittest.TestCase):
audio_start_pts, audio_end_pts = 0, -1 audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items(): _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
# pass 1: decode all frames using cpp decoder tv_result = torch.ops.video_reader.read_video_from_memory(
tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) assert abs(config.video_fps - tv_result[3].item()) < 0.01
self.assertAlmostEqual(config.video_fps, tv_result[3].item(), delta=0.01)
# pass 2: decode all frames to get PTS only using cpp decoder
# pass 2: decode all frames to get PTS only using cpp decoder tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 1, # getPtsOnly
1, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
)
self.assertEqual(tv_result_pts_only[0].numel(), 0) assert not tv_result_pts_only[0].numel()
self.assertEqual(tv_result_pts_only[5].numel(), 0) assert not tv_result_pts_only[5].numel()
self.compare_decoding_result(tv_result, tv_result_pts_only) self.compare_decoding_result(tv_result, tv_result_pts_only)
@PY39_SKIP @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_in_range_from_memory(self): @pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
""" """
Test the case when video is already in memory, and decoder reads data in memory. Test the case when video is already in memory, and decoder reads data in memory.
In addition, decoder takes meaningful start- and end PTS as input, and decode In addition, decoder takes meaningful start- and end PTS as input, and decode
frames within that interval frames within that interval
""" """
for test_video, config in test_videos.items(): full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # video related
# video related width, height, min_dimension, max_dimension = 0, 0, 0, 0
width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1
video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1
video_timebase_num, video_timebase_den = 0, 1 # audio related
# audio related samples, channels = 0, 0
samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1
audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1
audio_timebase_num, audio_timebase_den = 0, 1 # pass 1: decode all frames using new decoder
# pass 1: decode all frames using new decoder tv_result = torch.ops.video_reader.read_video_from_memory(
tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 0, # getPtsOnly
0, # getPtsOnly 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, video_start_pts,
video_start_pts, video_end_pts,
video_end_pts, video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, audio_start_pts,
audio_start_pts, audio_end_pts,
audio_end_pts, audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
(
vframes,
vframe_pts,
vtimebase,
vfps,
vduration,
aframes,
aframe_pts,
atimebase,
asample_rate,
aduration,
) = tv_result
assert abs(config.video_fps - vfps.item()) < 0.01
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
return
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
video_start_pts = vframe_pts[start_pts_ind]
video_end_pts = vframe_pts[end_pts_ind]
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
if len(atimebase) > 0:
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.ceil,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = ( pyav_result = _decode_frames_by_av_module(
tv_result full_path,
) video_start_pts_av,
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01) video_end_pts_av,
audio_start_pts,
for num_frames in [4, 8, 16, 32, 64, 128]: audio_end_pts,
start_pts_ind_max = vframe_pts.size(0) - num_frames )
if start_pts_ind_max <= 0:
continue assert tv_result[0].size(0) == num_frames
# randomly pick start pts if pyav_result.vframes.size(0) == num_frames:
start_pts_ind = randint(0, start_pts_ind_max) # if PyAv decodes a different number of video frames, skip
end_pts_ind = start_pts_ind + num_frames - 1 # comparing the decoding results between Torchvision video reader
video_start_pts = vframe_pts[start_pts_ind] # and PyAv
video_end_pts = vframe_pts[end_pts_ind] self.compare_decoding_result(tv_result, pyav_result, config)
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1] @pytest.mark.parametrize("test_video,config", test_videos.items())
if len(atimebase) > 0: def test_probe_video_from_file(self, test_video, config):
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
max_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(
full_path
)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil,
)
pyav_result = _decode_frames_by_av_module(
full_path,
video_start_pts_av,
video_end_pts_av,
audio_start_pts,
audio_end_pts,
)
self.assertEqual(tv_result[0].size(0), num_frames)
if pyav_result.vframes.size(0) == num_frames:
# if PyAv decodes a different number of video frames, skip
# comparing the decoding results between Torchvision video reader
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
""" """
Test the case when decoder probes a video file Test the case when decoder probes a video file
""" """
for test_video, config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video) probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path) self.check_probe_result(probe_result, config)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory(self): @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory(self, test_video, config):
""" """
Test the case when decoder probes a video in memory Test the case when decoder probes a video in memory
""" """
for test_video, config in test_videos.items(): _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor) self.check_probe_result(probe_result, config)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory_script(self): @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory_script(self, test_video, config):
scripted_fun = torch.jit.script(io._probe_video_from_memory) scripted_fun = torch.jit.script(io._probe_video_from_memory)
self.assertIsNotNone(scripted_fun) assert scripted_fun is not None
for test_video, config in test_videos.items(): _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) probe_result = scripted_fun(video_tensor)
probe_result = scripted_fun(video_tensor) self.check_meta_result(probe_result, config)
self.check_meta_result(probe_result, config)
@PY39_SKIP @pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_memory_scripted(self): def test_read_video_from_memory_scripted(self, test_video):
""" """
Test the case when video is already in memory, and decoder reads data in memory Test the case when video is already in memory, and decoder reads data in memory
""" """
...@@ -1212,71 +1189,66 @@ class TestVideoReader(unittest.TestCase): ...@@ -1212,71 +1189,66 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_den = 0, 1 audio_timebase_num, audio_timebase_den = 0, 1
scripted_fun = torch.jit.script(io._read_video_from_memory) scripted_fun = torch.jit.script(io._read_video_from_memory)
self.assertIsNotNone(scripted_fun) assert scripted_fun is not None
for test_video, _config in test_videos.items(): _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
# decode all frames using cpp decoder scripted_fun(
scripted_fun( video_tensor,
video_tensor, SEEK_FRAME_MARGIN,
seek_frame_margin, 1, # readVideoStream
1, # readVideoStream width,
width, height,
height, min_dimension,
min_dimension, max_dimension,
max_dimension, [video_start_pts, video_end_pts],
[video_start_pts, video_end_pts], video_timebase_num,
video_timebase_num, video_timebase_den,
video_timebase_den, 1, # readAudioStream
1, # readAudioStream samples,
samples, channels,
channels, [audio_start_pts, audio_end_pts],
[audio_start_pts, audio_end_pts], audio_timebase_num,
audio_timebase_num, audio_timebase_den,
audio_timebase_den, )
) # FUTURE: check value of video / audio frames
# FUTURE: check value of video / audio frames
def test_invalid_file(self):
def test_audio_video_sync(self): set_video_backend("video_reader")
"""Test if audio/video are synchronised with pyav output.""" with pytest.raises(RuntimeError):
for test_video, config in test_videos.items(): io.read_video("foo.mp4")
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path) set_video_backend("pyav")
if not container.streams.audio: with pytest.raises(RuntimeError):
# Skip if no audio stream io.read_video("foo.mp4")
continue
start_pts_val, cutoff = 0, 1 @pytest.mark.parametrize("test_video", test_videos.keys())
if container.streams.video: @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
video = container.streams.video[0] @pytest.mark.parametrize("start_offset", [0, 500])
arr = [] @pytest.mark.parametrize("end_offset", [3000, None])
for index, frame in enumerate(container.decode(video)): def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
if index == cutoff: """Test if audio frames are returned with pts unit."""
start_pts_val = frame.pts full_path = os.path.join(VIDEO_DIR, test_video)
if index >= cutoff: container = av.open(full_path)
arr.append(frame.to_rgb().to_ndarray()) if container.streams.audio:
visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts') set_video_backend(backend)
self.assertAlmostEqual( _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
config.video_fps, info['video_fps'], delta=0.0001 assert all([dimension > 0 for dimension in audio.shape[:2]])
)
arr = torch.Tensor(arr) @pytest.mark.parametrize("test_video", test_videos.keys())
if arr.shape == visual.shape: @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
self.assertGreaterEqual( @pytest.mark.parametrize("start_offset", [0, 0.1])
torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99) @pytest.mark.parametrize("end_offset", [0.3, None])
def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
container = av.open(full_path) """Test if audio frames are returned with sec unit."""
if container.streams.audio: full_path = os.path.join(VIDEO_DIR, test_video)
audio = container.streams.audio[0] container = av.open(full_path)
arr = [] if container.streams.audio:
for index, frame in enumerate(container.decode(audio)): set_video_backend(backend)
if index >= cutoff: _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
arr.append(frame.to_ndarray()) assert all([dimension > 0 for dimension in audio.shape[:2]])
_, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
arr = torch.as_tensor(np.concatenate(arr, axis=1))
if arr.shape == audio.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() pytest.main([__file__])
import collections import collections
import os import os
import unittest import urllib
import pytest
import torch import torch
import torchvision import torchvision
from torchvision.io import _HAS_VIDEO_OPT, VideoReader from pytest import approx
from torchvision.datasets.utils import download_url from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
# WARNING: these tests have been skipped forever on the CI because the video ops
# are never properly available. This is bad, but things have been in a terrible
# state for a long time already as we write this comment, and we'll hopefully be
# able to get rid of this all soon.
from common_utils import PY39_SKIP
try: try:
import av import av
...@@ -24,6 +31,13 @@ CheckerConfig = ["duration", "video_fps", "audio_sample_rate"] ...@@ -24,6 +31,13 @@ CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig)) GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
def backends():
backends_ = ["video_reader"]
if av is not None:
backends_.append("pyav")
return backends_
def fate(name, path="."): def fate(name, path="."):
"""Download and return a path to a sample from the FFmpeg test suite. """Download and return a path to a sample from the FFmpeg test suite.
See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_ See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_
...@@ -35,166 +49,264 @@ def fate(name, path="."): ...@@ -35,166 +49,264 @@ def fate(name, path="."):
test_videos = { test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
duration=2.0, video_fps=30.0, audio_sample_rate=None
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0, video_fps=30.0, audio_sample_rate=None duration=2.0, video_fps=30.0, audio_sample_rate=None
), ),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
duration=2.0, video_fps=30.0, audio_sample_rate=None "v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
), "v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
"v_SoccerJuggling_g23_c01.avi": GroundTruth( "R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100),
duration=8.0, video_fps=29.97, audio_sample_rate=None "SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
), "WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0, video_fps=29.97, audio_sample_rate=None
),
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0, video_fps=30.0, audio_sample_rate=44100
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0, video_fps=29.97, audio_sample_rate=48000
),
} }
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg") @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
@PY39_SKIP class TestVideoApi:
class TestVideoApi(unittest.TestCase): @pytest.mark.skipif(av is None, reason="PyAV unavailable")
@unittest.skipIf(av is None, "PyAV unavailable") @pytest.mark.parametrize("test_video", test_videos.keys())
def test_frame_reading(self): @pytest.mark.parametrize("backend", backends())
for test_video, config in test_videos.items(): def test_frame_reading(self, test_video, backend):
full_path = os.path.join(VIDEO_DIR, test_video) torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
av_reader = av.open(full_path) with av.open(full_path) as av_reader:
if av_reader.streams.video: if av_reader.streams.video:
video_reader = VideoReader(full_path, "video") av_frames, vr_frames = [], []
av_pts, vr_pts = [], []
# get av frames
for av_frame in av_reader.decode(av_reader.streams.video[0]): for av_frame in av_reader.decode(av_reader.streams.video[0]):
vr_frame = next(video_reader) av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
av_pts.append(av_frame.pts * av_frame.time_base)
self.assertAlmostEqual(
float(av_frame.pts * av_frame.time_base), # get vr frames
vr_frame["pts"], video_reader = VideoReader(full_path, "video")
delta=0.1, for vr_frame in video_reader:
) vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(
2, 0, 1 # same number of frames
) assert len(vr_frames) == len(av_frames)
vr_array = vr_frame["data"] assert len(vr_pts) == len(av_pts)
mean_delta = torch.mean(
torch.abs(av_array.float() - vr_array.float()) # compare the frames and ptss
) for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
# on average the difference is very small and caused # on average the difference is very small and caused
# by decoding (around 1%) # by decoding (around 1%)
# TODO: asses empirically how to set this? atm it's 1% # TODO: asses empirically how to set this? atm it's 1%
# averaged over all frames # averaged over all frames
self.assertTrue(mean_delta.item() < 2.5) assert mean_delta.item() < 2.55
del vr_frames, av_frames, vr_pts, av_pts
av_reader = av.open(full_path) # test audio reading compared to PYAV
with av.open(full_path) as av_reader:
if av_reader.streams.audio: if av_reader.streams.audio:
video_reader = VideoReader(full_path, "audio") av_frames, vr_frames = [], []
av_pts, vr_pts = [], []
# get av frames
for av_frame in av_reader.decode(av_reader.streams.audio[0]): for av_frame in av_reader.decode(av_reader.streams.audio[0]):
vr_frame = next(video_reader) av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
self.assertAlmostEqual( av_pts.append(av_frame.pts * av_frame.time_base)
float(av_frame.pts * av_frame.time_base), av_reader.close()
vr_frame["pts"],
delta=0.1,
)
av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0)
vr_array = vr_frame["data"]
max_delta = torch.max(
torch.abs(av_array.float() - vr_array.float())
)
# we assure that there is never more than 1% difference in signal
self.assertTrue(max_delta.item() < 0.001)
def test_metadata(self): # get vr frames
video_reader = VideoReader(full_path, "audio")
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# same number of frames
assert len(vr_frames) == len(av_frames)
assert len(vr_pts) == len(av_pts)
# compare the frames and ptss
for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float()))
# we assure that there is never more than 1% difference in signal
assert max_delta.item() < 0.001
@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", backends())
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path)
reader_md = reader.get_metadata()
if stream in reader_md:
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)
for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])
# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)
# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: asses empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
else:
del reader, reader_md
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", backends())
def test_metadata(self, test_video, config, backend):
""" """
Test that the metadata returned via pyav corresponds to the one returned Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API by the new video decoder API
""" """
for test_video, config in test_videos.items(): torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video") reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata() reader_md = reader.get_metadata()
self.assertAlmostEqual( assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001 assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
)
self.assertAlmostEqual( @pytest.mark.parametrize("test_video", test_videos.keys())
config.duration, reader_md["video"]["duration"][0], delta=0.5 @pytest.mark.parametrize("backend", backends())
) def test_seek_start(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_seek_start(self): full_path = os.path.join(VIDEO_DIR, test_video)
for test_video, config in test_videos.items(): video_reader = VideoReader(full_path, "video")
full_path = os.path.join(VIDEO_DIR, test_video) num_frames = 0
for _ in video_reader:
video_reader = VideoReader(full_path, "video") num_frames += 1
# now seek the container to 0 and do it again
# It's often that starting seek can be inprecise
# this way and it doesn't start at 0
video_reader.seek(0)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
# now seek the container to < 0 to check for unexpected behaviour
video_reader.seek(-1)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader"])
def test_accurateseek_middle(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video"
video_reader = VideoReader(full_path, stream)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
num_frames = 0 num_frames = 0
for frame in video_reader: for _ in video_reader:
num_frames += 1 num_frames += 1
# now seek the container to 0 and do it again video_reader.seek(duration / 2)
# It's often that starting seek can be inprecise middle_num_frames = 0
# this way and it doesn't start at 0 for _ in video_reader:
video_reader.seek(0) middle_num_frames += 1
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
self.assertEqual(start_num_frames, num_frames)
# now seek the container to < 0 to check for unexpected behaviour assert middle_num_frames < num_frames
video_reader.seek(-1) assert middle_num_frames == approx(num_frames // 2, abs=1)
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
self.assertEqual(start_num_frames, num_frames) video_reader.seek(duration / 2)
frame = next(video_reader)
lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
assert (lb <= frame["pts"]) and (ub >= frame["pts"])
def test_accurateseek_middle(self): def test_fate_suite(self):
for test_video, config in test_videos.items(): # TODO: remove the try-except statement once the connectivity issues are resolved
full_path = os.path.join(VIDEO_DIR, test_video) try:
video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
except (urllib.error.URLError, ConnectionError) as error:
pytest.skip(f"Skipping due to connectivity issues: {error}")
vr = VideoReader(video_path)
metadata = vr.get_metadata()
stream = "video" assert metadata["subtitles"]["duration"] is not None
video_reader = VideoReader(full_path, stream) os.remove(video_path)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
num_frames = 0 @pytest.mark.skipif(av is None, reason="PyAV unavailable")
for frame in video_reader: @pytest.mark.parametrize("test_video,config", test_videos.items())
num_frames += 1 @pytest.mark.parametrize("backend", backends())
def test_keyframe_reading(self, test_video, config, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader.seek(duration / 2) av_reader = av.open(full_path)
middle_num_frames = 0 # reduce streams to only keyframes
for frame in video_reader: av_stream = av_reader.streams.video[0]
middle_num_frames += 1 av_stream.codec_context.skip_frame = "NONKEY"
self.assertTrue(middle_num_frames < num_frames) av_keyframes = []
self.assertAlmostEqual(middle_num_frames, num_frames // 2, delta=1) vr_keyframes = []
if av_reader.streams.video:
video_reader.seek(duration / 2) # get all keyframes using pyav. Then, seek randomly into video reader
frame = next(video_reader) # and assert that all the returned values are in AV_KEYFRAMES
lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
self.assertTrue((lb <= frame["pts"]) & (ub >= frame["pts"]))
def test_fate_suite(self): for av_frame in av_reader.decode(av_stream):
video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR) av_keyframes.append(float(av_frame.pts * av_frame.time_base))
vr = VideoReader(video_path)
metadata = vr.get_metadata()
self.assertTrue(metadata["subtitles"]["duration"] is not None) if len(av_keyframes) > 1:
os.remove(video_path) video_reader = VideoReader(full_path, "video")
for i in range(1, len(av_keyframes)):
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
data = next(video_reader.seek(seek_val, True))
vr_keyframes.append(data["pts"])
data = next(video_reader.seek(config.duration, True))
vr_keyframes.append(data["pts"])
assert len(av_keyframes) == len(vr_keyframes)
# NOTE: this video gets different keyframe with different
# loaders (0.333 pyav, 0.666 for us)
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
for i in range(len(av_keyframes)):
assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)
def test_src(self):
with pytest.raises(ValueError, match="src cannot be empty"):
VideoReader(src="")
with pytest.raises(ValueError, match="src must be either string"):
VideoReader(src=2)
with pytest.raises(TypeError, match="unexpected keyword argument"):
VideoReader(path="path")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() pytest.main([__file__])
# Builds a standalone C++ executable that loads a TorchScript-exported
# Faster R-CNN model and runs it through libtorch / TorchVision C++ APIs.
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(test_frcnn_tracing)
# Both the libtorch and the TorchVision C++ packages must be installed.
find_package(Torch REQUIRED)
find_package(TorchVision REQUIRED)
# This due to some headers importing Python.h
find_package(Python3 COMPONENTS Development)
add_executable(test_frcnn_tracing test_frcnn_tracing.cpp)
target_compile_features(test_frcnn_tracing PUBLIC cxx_range_for)
# Link libtorch, the TorchVision C++ library, and the Python development
# library (required because some included headers pull in Python.h).
target_link_libraries(test_frcnn_tracing ${TORCH_LIBRARIES} TorchVision::TorchVision Python3::Python)
set_property(TARGET test_frcnn_tracing PROPERTY CXX_STANDARD 14)
import os.path as osp
import torch
import torchvision
# Directory anchors; not used below — presumably kept for parity with
# sibling test scripts (TODO confirm).
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))
# Build an uninitialized (pretrained=False) Faster R-CNN and export it as
# TorchScript, presumably for the adjacent C++ tracing test to consume.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
model.eval()
# NOTE(review): despite the variable name, this *scripts* the model with
# torch.jit.script rather than tracing it with torch.jit.trace.
traced_model = torch.jit.script(model)
traced_model.save("fasterrcnn_resnet50_fpn.pt")
import warnings
import os import os
import warnings
from .extension import _HAS_OPS from modulefinder import Module
from torchvision import models
from torchvision import datasets
from torchvision import ops
from torchvision import transforms
from torchvision import utils
from torchvision import io
import torch import torch
# Don't re-order these, we need to load the _C extension (done when importing
# .extensions) before entering _meta_registrations.
from .extension import _HAS_OPS # usort:skip
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip
try: try:
from .version import __version__ # noqa: F401 from .version import __version__ # noqa: F401
except ImportError: except ImportError:
pass pass
# Check if torchvision is being imported within the root folder # Check if torchvision is being imported within the root folder
if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.join(os.path.realpath(os.getcwd()), 'torchvision')): os.path.realpath(os.getcwd()), "torchvision"
message = ('You are importing torchvision within its own root folder ({}). ' ):
'This is not expected to work and may give errors. Please exit the ' message = (
'torchvision project source and relaunch your python interpreter.') "You are importing torchvision within its own root folder ({}). "
"This is not expected to work and may give errors. Please exit the "
"torchvision project source and relaunch your python interpreter."
)
warnings.warn(message.format(os.getcwd())) warnings.warn(message.format(os.getcwd()))
_image_backend = 'PIL' _image_backend = "PIL"
_video_backend = "pyav" _video_backend = "pyav"
...@@ -40,9 +41,8 @@ def set_image_backend(backend): ...@@ -40,9 +41,8 @@ def set_image_backend(backend):
generally faster than PIL, but does not support as many operations. generally faster than PIL, but does not support as many operations.
""" """
global _image_backend global _image_backend
if backend not in ['PIL', 'accimage']: if backend not in ["PIL", "accimage"]:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'" raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
.format(backend))
_image_backend = backend _image_backend = backend
...@@ -63,23 +63,23 @@ def set_video_backend(backend): ...@@ -63,23 +63,23 @@ def set_video_backend(backend):
binding for the FFmpeg libraries. binding for the FFmpeg libraries.
The :mod:`video_reader` package includes a native C++ implementation on The :mod:`video_reader` package includes a native C++ implementation on
top of FFMPEG libraries, and a python API of TorchScript custom operator. top of FFMPEG libraries, and a python API of TorchScript custom operator.
It is generally decoding faster than :mod:`pyav`, but perhaps is less robust. It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
.. note:: .. note::
Building with FFMPEG is disabled by default in the latest master. If you want to use the 'video_reader' Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
backend, please compile torchvision from source. backend, please compile torchvision from source.
""" """
global _video_backend global _video_backend
if backend not in ["pyav", "video_reader"]: if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError( raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
if backend == "video_reader" and not io._HAS_VIDEO_OPT: if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = ( # TODO: better messages
"video_reader video backend is not available." message = "video_reader video backend is not available. Please compile torchvision from source and try again"
" Please compile torchvision from source and try again" raise RuntimeError(message)
) elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
warnings.warn(message) # TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
else: else:
_video_backend = backend _video_backend = backend
...@@ -97,3 +97,9 @@ def get_video_backend(): ...@@ -97,3 +97,9 @@ def get_video_backend():
def _is_tracing(): def _is_tracing():
return torch._C._get_tracing_state() return torch._C._get_tracing_state()
def disable_beta_transforms_warning():
# Noop, only exists to avoid breaking existing code.
# See https://github.com/pytorch/vision/issues/7896
pass
import importlib.machinery
import os
from torch.hub import _get_torch_home
# Root directory for torchvision's vision datasets inside the torch home
# (e.g. ~/.cache/torch/datasets/vision).
_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
# Flag for sharded-dataset storage; disabled here — presumably toggled by
# internal tooling (TODO confirm).
_USE_SHARDED_DATASETS = False
def _download_file_from_remote_location(fpath: str, url: str) -> None:
    """No-op placeholder: this build has no remote file store to fetch from."""
    pass
def _is_remote_location_available() -> bool:
    """Always ``False``: no remote dataset location exists in this build."""
    return False
try:
    from torch.hub import load_state_dict_from_url  # noqa: 401
except ImportError:
    # Fallback for older torch versions where the helper lived in model_zoo.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
os.add_dll_directory(lib_dir)
kernel32.SetErrorMode(prev_error_mode)
loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
return ext_specs.origin
import functools
import torch
import torch._custom_ops
import torch.library
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401
@functools.cache
def get_meta_lib():
    """Return the process-wide "Meta" dispatch library for the torchvision
    namespace, creating it exactly once (memoized)."""
    return torch.library.Library("torchvision", "IMPL", "Meta")
def register_meta(op_name, overload_name="default"):
    """Decorator factory: register the wrapped function as the Meta kernel
    for ``torchvision::<op_name>.<overload_name>``.

    Registration is skipped when the compiled torchvision ops are not
    available; the function is returned unchanged either way.
    """

    def wrapper(fn):
        if torchvision.extension._has_ops():
            op = getattr(torch.ops.torchvision, op_name)
            get_meta_lib().impl(getattr(op, overload_name), fn)
        return fn

    return wrapper
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
channels = input.size(1)
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("ps_roi_align")
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)
num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
@register_meta("_ps_roi_align_backward")
def meta_ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width,
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("roi_pool")
def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
channels = input.size(1)
out_size = (num_rois, channels, pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
@register_meta("_roi_pool_backward")
def meta_roi_pool_backward(
grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@register_meta("ps_roi_pool")
def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)
num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
@register_meta("_ps_roi_pool_backward")
def meta_ps_roi_pool_backward(
grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
torch._check(
dets.size(0) == scores.size(0),
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
)
ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long)
@register_meta("deform_conv2d")
def meta_deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
use_mask,
):
out_height, out_width = offset.shape[-2:]
out_channels = weight.shape[0]
batch_size = input.shape[0]
return input.new_empty((batch_size, out_channels, out_height, out_width))
@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask,
):
grad_input = input.new_empty(input.shape)
grad_weight = weight.new_empty(weight.shape)
grad_offset = offset.new_empty(offset.shape)
grad_mask = mask.new_empty(mask.shape)
grad_bias = bias.new_empty(bias.shape)
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
import enum
from typing import Sequence, Type, TypeVar
T = TypeVar("T", bound=enum.Enum)
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto
def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
try:
return self[member]
except KeyError:
# TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
# soon as it is migrated.
raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    """Enum base class whose members can be looked up by name via ``from_str``."""

    pass
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
...@@ -48,6 +48,23 @@ bool AudioSampler::init(const SamplerParameters& params) { ...@@ -48,6 +48,23 @@ bool AudioSampler::init(const SamplerParameters& params) {
return false; return false;
} }
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
SwrContext* swrContext_ = NULL;
AVChannelLayout channel_out;
AVChannelLayout channel_in;
av_channel_layout_default(&channel_out, params.out.audio.channels);
av_channel_layout_default(&channel_in, params.in.audio.channels);
int ret = swr_alloc_set_opts2(
&swrContext_,
&channel_out,
(AVSampleFormat)params.out.audio.format,
params.out.audio.samples,
&channel_in,
(AVSampleFormat)params.in.audio.format,
params.in.audio.samples,
0,
logCtx_);
#else
swrContext_ = swr_alloc_set_opts( swrContext_ = swr_alloc_set_opts(
nullptr, nullptr,
av_get_default_channel_layout(params.out.audio.channels), av_get_default_channel_layout(params.out.audio.channels),
...@@ -58,6 +75,7 @@ bool AudioSampler::init(const SamplerParameters& params) { ...@@ -58,6 +75,7 @@ bool AudioSampler::init(const SamplerParameters& params) {
params.in.audio.samples, params.in.audio.samples,
0, 0,
logCtx_); logCtx_);
#endif
if (swrContext_ == nullptr) { if (swrContext_ == nullptr) {
LOG(ERROR) << "Cannot allocate SwrContext"; LOG(ERROR) << "Cannot allocate SwrContext";
return false; return false;
...@@ -65,7 +83,7 @@ bool AudioSampler::init(const SamplerParameters& params) { ...@@ -65,7 +83,7 @@ bool AudioSampler::init(const SamplerParameters& params) {
int result; int result;
if ((result = swr_init(swrContext_)) < 0) { if ((result = swr_init(swrContext_)) < 0) {
LOG(ERROR) << "swr_init faield, err: " << Util::generateErrorDesc(result) LOG(ERROR) << "swr_init failed, err: " << Util::generateErrorDesc(result)
<< ", in -> format: " << params.in.audio.format << ", in -> format: " << params.in.audio.format
<< ", channels: " << params.in.audio.channels << ", channels: " << params.in.audio.channels
<< ", samples: " << params.in.audio.samples << ", samples: " << params.in.audio.samples
...@@ -116,12 +134,12 @@ int AudioSampler::sample( ...@@ -116,12 +134,12 @@ int AudioSampler::sample(
outNumSamples, outNumSamples,
inPlanes, inPlanes,
inNumSamples)) < 0) { inNumSamples)) < 0) {
LOG(ERROR) << "swr_convert faield, err: " LOG(ERROR) << "swr_convert failed, err: "
<< Util::generateErrorDesc(result); << Util::generateErrorDesc(result);
return result; return result;
} }
CHECK_LE(result, outNumSamples); TORCH_CHECK_LE(result, outNumSamples);
if (result) { if (result) {
if ((result = av_samples_get_buffer_size( if ((result = av_samples_get_buffer_size(
...@@ -132,7 +150,7 @@ int AudioSampler::sample( ...@@ -132,7 +150,7 @@ int AudioSampler::sample(
1)) >= 0) { 1)) >= 0) {
out->append(result); out->append(result);
} else { } else {
LOG(ERROR) << "av_samples_get_buffer_size faield, err: " LOG(ERROR) << "av_samples_get_buffer_size failed, err: "
<< Util::generateErrorDesc(result); << Util::generateErrorDesc(result);
} }
} }
...@@ -140,7 +158,7 @@ int AudioSampler::sample( ...@@ -140,7 +158,7 @@ int AudioSampler::sample(
// allocate a temporary buffer // allocate a temporary buffer
auto* tmpBuffer = static_cast<uint8_t*>(av_malloc(outBufferBytes)); auto* tmpBuffer = static_cast<uint8_t*>(av_malloc(outBufferBytes));
if (!tmpBuffer) { if (!tmpBuffer) {
LOG(ERROR) << "av_alloc faield, for size: " << outBufferBytes; LOG(ERROR) << "av_alloc failed, for size: " << outBufferBytes;
return -1; return -1;
} }
...@@ -158,7 +176,7 @@ int AudioSampler::sample( ...@@ -158,7 +176,7 @@ int AudioSampler::sample(
outNumSamples, outNumSamples,
inPlanes, inPlanes,
inNumSamples)) < 0) { inNumSamples)) < 0) {
LOG(ERROR) << "swr_convert faield, err: " LOG(ERROR) << "swr_convert failed, err: "
<< Util::generateErrorDesc(result); << Util::generateErrorDesc(result);
av_free(tmpBuffer); av_free(tmpBuffer);
return result; return result;
...@@ -166,7 +184,7 @@ int AudioSampler::sample( ...@@ -166,7 +184,7 @@ int AudioSampler::sample(
av_free(tmpBuffer); av_free(tmpBuffer);
CHECK_LE(result, outNumSamples); TORCH_CHECK_LE(result, outNumSamples);
if (result) { if (result) {
result = av_samples_get_buffer_size( result = av_samples_get_buffer_size(
......
...@@ -6,26 +6,36 @@ ...@@ -6,26 +6,36 @@
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
static int get_nb_channels(const AVFrame* frame, const AVCodecContext* codec) {
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
return frame ? frame->ch_layout.nb_channels : codec->ch_layout.nb_channels;
#else
return frame ? frame->channels : codec->channels;
#endif
}
bool operator==(const AudioFormat& x, const AVFrame& y) { bool operator==(const AudioFormat& x, const AVFrame& y) {
return x.samples == y.sample_rate && x.channels == y.channels && return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(get_nb_channels(&y, nullptr)) &&
x.format == y.format; x.format == y.format;
} }
bool operator==(const AudioFormat& x, const AVCodecContext& y) { bool operator==(const AudioFormat& x, const AVCodecContext& y) {
return x.samples == y.sample_rate && x.channels == y.channels && return x.samples == static_cast<size_t>(y.sample_rate) &&
x.channels == static_cast<size_t>(get_nb_channels(nullptr, &y)) &&
x.format == y.sample_fmt; x.format == y.sample_fmt;
} }
AudioFormat& toAudioFormat(AudioFormat& x, const AVFrame& y) { AudioFormat& toAudioFormat(AudioFormat& x, const AVFrame& y) {
x.samples = y.sample_rate; x.samples = y.sample_rate;
x.channels = y.channels; x.channels = get_nb_channels(&y, nullptr);
x.format = y.format; x.format = y.format;
return x; return x;
} }
AudioFormat& toAudioFormat(AudioFormat& x, const AVCodecContext& y) { AudioFormat& toAudioFormat(AudioFormat& x, const AVCodecContext& y) {
x.samples = y.sample_rate; x.samples = y.sample_rate;
x.channels = y.channels; x.channels = get_nb_channels(nullptr, &y);
x.format = y.sample_fmt; x.format = y.sample_fmt;
return x; return x;
} }
...@@ -54,9 +64,15 @@ int AudioStream::initFormat() { ...@@ -54,9 +64,15 @@ int AudioStream::initFormat() {
if (format_.format.audio.samples == 0) { if (format_.format.audio.samples == 0) {
format_.format.audio.samples = codecCtx_->sample_rate; format_.format.audio.samples = codecCtx_->sample_rate;
} }
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(57, 28, 100)
if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->ch_layout.nb_channels;
}
#else
if (format_.format.audio.channels == 0) { if (format_.format.audio.channels == 0) {
format_.format.audio.channels = codecCtx_->channels; format_.format.audio.channels = codecCtx_->channels;
} }
#endif
if (format_.format.audio.format == AV_SAMPLE_FMT_NONE) { if (format_.format.audio.format == AV_SAMPLE_FMT_NONE) {
format_.format.audio.format = codecCtx_->sample_fmt; format_.format.audio.format = codecCtx_->sample_fmt;
} }
...@@ -68,6 +84,7 @@ int AudioStream::initFormat() { ...@@ -68,6 +84,7 @@ int AudioStream::initFormat() {
: -1; : -1;
} }
// copies audio sample bytes via swr_convert call in audio_sampler.cpp
int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) { int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) {
if (!sampler_) { if (!sampler_) {
sampler_ = std::make_unique<AudioSampler>(codecCtx_); sampler_ = std::make_unique<AudioSampler>(codecCtx_);
...@@ -95,6 +112,8 @@ int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) { ...@@ -95,6 +112,8 @@ int AudioStream::copyFrameBytes(ByteStorage* out, bool flush) {
<< ", channels: " << format_.format.audio.channels << ", channels: " << format_.format.audio.channels
<< ", format: " << format_.format.audio.format; << ", format: " << format_.format.audio.format;
} }
// calls to a sampler that converts the audio samples and copies them to the
// out buffer via ffmpeg::swr_convert
return sampler_->sample(flush ? nullptr : frame_, out); return sampler_->sample(flush ? nullptr : frame_, out);
} }
......
#include "decoder.h" #include "decoder.h"
#include <c10/util/Logging.h> #include <c10/util/Logging.h>
#include <libavutil/avutil.h>
#include <future> #include <future>
#include <iostream> #include <iostream>
#include <mutex> #include <mutex>
...@@ -17,25 +18,6 @@ constexpr size_t kIoBufferSize = 96 * 1024; ...@@ -17,25 +18,6 @@ constexpr size_t kIoBufferSize = 96 * 1024;
constexpr size_t kIoPaddingSize = AV_INPUT_BUFFER_PADDING_SIZE; constexpr size_t kIoPaddingSize = AV_INPUT_BUFFER_PADDING_SIZE;
constexpr size_t kLogBufferSize = 1024; constexpr size_t kLogBufferSize = 1024;
int ffmpeg_lock(void** mutex, enum AVLockOp op) {
std::mutex** handle = (std::mutex**)mutex;
switch (op) {
case AV_LOCK_CREATE:
*handle = new std::mutex();
break;
case AV_LOCK_OBTAIN:
(*handle)->lock();
break;
case AV_LOCK_RELEASE:
(*handle)->unlock();
break;
case AV_LOCK_DESTROY:
delete *handle;
break;
}
return 0;
}
bool mapFfmpegType(AVMediaType media, MediaType* type) { bool mapFfmpegType(AVMediaType media, MediaType* type) {
switch (media) { switch (media) {
case AVMEDIA_TYPE_AUDIO: case AVMEDIA_TYPE_AUDIO:
...@@ -196,11 +178,11 @@ int64_t Decoder::seekCallback(int64_t offset, int whence) { ...@@ -196,11 +178,11 @@ int64_t Decoder::seekCallback(int64_t offset, int whence) {
void Decoder::initOnce() { void Decoder::initOnce() {
static std::once_flag flagInit; static std::once_flag flagInit;
std::call_once(flagInit, []() { std::call_once(flagInit, []() {
#if LIBAVUTIL_VERSION_MAJOR < 56 // Before FFMPEG 4.0
av_register_all(); av_register_all();
avcodec_register_all(); avcodec_register_all();
#endif
avformat_network_init(); avformat_network_init();
// register ffmpeg lock manager
av_lockmgr_register(&ffmpeg_lock);
av_log_set_callback(Decoder::logFunction); av_log_set_callback(Decoder::logFunction);
av_log_set_level(AV_LOG_ERROR); av_log_set_level(AV_LOG_ERROR);
VLOG(1) << "Registered ffmpeg libs"; VLOG(1) << "Registered ffmpeg libs";
...@@ -215,6 +197,12 @@ Decoder::~Decoder() { ...@@ -215,6 +197,12 @@ Decoder::~Decoder() {
cleanUp(); cleanUp();
} }
// Initialise the format context that holds information about the container and
// fill it with minimal information about the format (codecs are not opened
// here). Function reads in information about the streams from the container
// into inputCtx and then passes it to decoder::openStreams. Finally, if seek is
// specified within the decoder parameters, it seeks into the correct frame
// (note, the seek defined here is "precise" seek).
bool Decoder::init( bool Decoder::init(
const DecoderParameters& params, const DecoderParameters& params,
DecoderInCallback&& in, DecoderInCallback&& in,
...@@ -268,7 +256,7 @@ bool Decoder::init( ...@@ -268,7 +256,7 @@ bool Decoder::init(
break; break;
} }
fmt = av_find_input_format(fmtName); fmt = (AVInputFormat*)av_find_input_format(fmtName);
} }
const size_t avioCtxBufferSize = kIoBufferSize; const size_t avioCtxBufferSize = kIoBufferSize;
...@@ -324,6 +312,8 @@ bool Decoder::init( ...@@ -324,6 +312,8 @@ bool Decoder::init(
} }
} }
av_dict_set_int(&options, "probesize", params_.probeSize, 0);
interrupted_ = false; interrupted_ = false;
// ffmpeg avformat_open_input call can hang if media source doesn't respond // ffmpeg avformat_open_input call can hang if media source doesn't respond
...@@ -381,7 +371,7 @@ bool Decoder::init( ...@@ -381,7 +371,7 @@ bool Decoder::init(
cleanUp(); cleanUp();
return false; return false;
} }
// SyncDecoder inherits Decoder which would override onInit.
onInit(); onInit();
if (params.startOffset != 0) { if (params.startOffset != 0) {
...@@ -396,11 +386,17 @@ bool Decoder::init( ...@@ -396,11 +386,17 @@ bool Decoder::init(
return true; return true;
} }
// open appropriate CODEC for every type of stream and move it to the class
// variable `streams_` and make sure it is in range for decoding
bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) { bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) {
for (int i = 0; i < inputCtx_->nb_streams; i++) { for (unsigned int i = 0; i < inputCtx_->nb_streams; i++) {
// - find the corespondent format at params_.formats set // - find the corespondent format at params_.formats set
MediaFormat format; MediaFormat format;
#if LIBAVUTIL_VERSION_MAJOR < 56 // Before FFMPEG 4.0
const auto media = inputCtx_->streams[i]->codec->codec_type; const auto media = inputCtx_->streams[i]->codec->codec_type;
#else // FFMPEG 4.0+
const auto media = inputCtx_->streams[i]->codecpar->codec_type;
#endif
if (!mapFfmpegType(media, &format.type)) { if (!mapFfmpegType(media, &format.type)) {
VLOG(1) << "Stream media: " << media << " at index " << i VLOG(1) << "Stream media: " << media << " at index " << i
<< " gets ignored, unknown type"; << " gets ignored, unknown type";
...@@ -424,20 +420,20 @@ bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) { ...@@ -424,20 +420,20 @@ bool Decoder::openStreams(std::vector<DecoderMetadata>* metadata) {
if (it->stream == -2 || // all streams of this type are welcome if (it->stream == -2 || // all streams of this type are welcome
(!stream && (it->stream == -1 || it->stream == i))) { // new stream (!stream && (it->stream == -1 || it->stream == i))) { // new stream
VLOG(1) << "Stream type: " << format.type << " found, at index: " << i; VLOG(1) << "Stream type: " << format.type << " found, at index: " << i;
auto stream = createStream( auto stream_2 = createStream(
format.type, format.type,
inputCtx_, inputCtx_,
i, i,
params_.convertPtsToWallTime, params_.convertPtsToWallTime,
it->format, it->format,
params_.loggingUuid); params_.loggingUuid);
CHECK(stream); CHECK(stream_2);
if (stream->openCodec(metadata) < 0) { if (stream_2->openCodec(metadata, params_.numThreads) < 0) {
LOG(ERROR) << "uuid=" << params_.loggingUuid LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " open codec failed, stream_idx=" << i; << " open codec failed, stream_idx=" << i;
return false; return false;
} }
streams_.emplace(i, std::move(stream)); streams_.emplace(i, std::move(stream_2));
inRange_.set(i, true); inRange_.set(i, true);
} }
} }
...@@ -478,6 +474,10 @@ void Decoder::cleanUp() { ...@@ -478,6 +474,10 @@ void Decoder::cleanUp() {
seekableBuffer_.shutdown(); seekableBuffer_.shutdown();
} }
// function does actual work, derived class calls it in working thread
// periodically. On success method returns 0, ENODATA on EOF, ETIMEDOUT if
// no frames got decoded in the specified timeout time, AVERROR_BUFFER_TOO_SMALL
// when unable to allocate packet and error on unrecoverable error
int Decoder::getFrame(size_t workingTimeInMs) { int Decoder::getFrame(size_t workingTimeInMs) {
if (inRange_.none()) { if (inRange_.none()) {
return ENODATA; return ENODATA;
...@@ -486,10 +486,16 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -486,10 +486,16 @@ int Decoder::getFrame(size_t workingTimeInMs) {
// once decode() method gets called and grab some bytes // once decode() method gets called and grab some bytes
// run this method again // run this method again
// init package // init package
AVPacket avPacket; // update 03/22: moving memory management to ffmpeg
av_init_packet(&avPacket); AVPacket* avPacket;
avPacket.data = nullptr; avPacket = av_packet_alloc();
avPacket.size = 0; if (avPacket == nullptr) {
LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " decoder as not able to allocate the packet.";
return AVERROR_BUFFER_TOO_SMALL;
}
avPacket->data = nullptr;
avPacket->size = 0;
auto end = std::chrono::steady_clock::now() + auto end = std::chrono::steady_clock::now() +
std::chrono::milliseconds(workingTimeInMs); std::chrono::milliseconds(workingTimeInMs);
...@@ -501,28 +507,44 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -501,28 +507,44 @@ int Decoder::getFrame(size_t workingTimeInMs) {
int result = 0; int result = 0;
size_t decodingErrors = 0; size_t decodingErrors = 0;
bool decodedFrame = false; bool decodedFrame = false;
while (!interrupted_ && inRange_.any() && !decodedFrame && watcher()) { while (!interrupted_ && inRange_.any() && !decodedFrame) {
result = av_read_frame(inputCtx_, &avPacket); if (watcher() == false) {
LOG(ERROR) << "uuid=" << params_.loggingUuid << " hit ETIMEDOUT";
result = ETIMEDOUT;
break;
}
result = av_read_frame(inputCtx_, avPacket);
if (result == AVERROR(EAGAIN)) { if (result == AVERROR(EAGAIN)) {
VLOG(4) << "Decoder is busy..."; VLOG(4) << "Decoder is busy...";
std::this_thread::yield(); std::this_thread::yield();
result = 0; // reset error, EAGAIN is not an error at all result = 0; // reset error, EAGAIN is not an error at all
// reset the packet to default settings
av_packet_unref(avPacket);
continue; continue;
} else if (result == AVERROR_EOF) { } else if (result == AVERROR_EOF) {
flushStreams(); flushStreams();
VLOG(1) << "End of stream"; VLOG(1) << "End of stream";
result = ENODATA; result = ENODATA;
break; break;
} else if (
result == AVERROR(EPERM) && params_.skipOperationNotPermittedPackets) {
// reset error, lets skip packets with EPERM
result = 0;
// reset the packet to default settings
av_packet_unref(avPacket);
continue;
} else if (result < 0) { } else if (result < 0) {
flushStreams(); flushStreams();
LOG(ERROR) << "Error detected: " << Util::generateErrorDesc(result); LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " error detected: " << Util::generateErrorDesc(result);
break; break;
} }
// get stream // get stream; if stream cannot be found reset the packet to
auto stream = findByIndex(avPacket.stream_index); // default settings
auto stream = findByIndex(avPacket->stream_index);
if (stream == nullptr || !inRange_.test(stream->getIndex())) { if (stream == nullptr || !inRange_.test(stream->getIndex())) {
av_packet_unref(&avPacket); av_packet_unref(avPacket);
continue; continue;
} }
...@@ -533,9 +555,10 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -533,9 +555,10 @@ int Decoder::getFrame(size_t workingTimeInMs) {
bool gotFrame = false; bool gotFrame = false;
bool hasMsg = false; bool hasMsg = false;
// packet either got consumed completely or not at all // packet either got consumed completely or not at all
if ((result = processPacket(stream, &avPacket, &gotFrame, &hasMsg)) < 0) { if ((result = processPacket(
stream, avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
LOG(ERROR) << "uuid=" << params_.loggingUuid LOG(ERROR) << "uuid=" << params_.loggingUuid
<< " processPacket failed with code=" << result; << " processPacket failed with code: " << result;
break; break;
} }
...@@ -566,20 +589,18 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -566,20 +589,18 @@ int Decoder::getFrame(size_t workingTimeInMs) {
result = 0; result = 0;
av_packet_unref(&avPacket); av_packet_unref(avPacket);
} }
av_packet_unref(&avPacket); av_packet_free(&avPacket);
VLOG(2) << "Interrupted loop" VLOG(2) << "Interrupted loop"
<< ", interrupted_ " << interrupted_ << ", inRange_.any() " << ", interrupted_ " << interrupted_ << ", inRange_.any() "
<< inRange_.any() << ", decodedFrame " << decodedFrame << ", result " << inRange_.any() << ", decodedFrame " << decodedFrame << ", result "
<< result; << result;
// loop can be terminated, either by: // loop can be terminated, either by:
// 1. explcitly iterrupted // 1. explicitly interrupted
// 2. terminated by workable timeout // 3. unrecoverable error or ENODATA (end of stream) or ETIMEDOUT (timeout)
// 3. unrecoverable error or ENODATA (end of stream)
// 4. decoded frames pts are out of the specified range // 4. decoded frames pts are out of the specified range
// 5. success decoded frame // 5. success decoded frame
if (interrupted_) { if (interrupted_) {
...@@ -594,11 +615,13 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -594,11 +615,13 @@ int Decoder::getFrame(size_t workingTimeInMs) {
return 0; return 0;
} }
// find stream by stream index
Stream* Decoder::findByIndex(int streamIndex) const { Stream* Decoder::findByIndex(int streamIndex) const {
auto it = streams_.find(streamIndex); auto it = streams_.find(streamIndex);
return it != streams_.end() ? it->second.get() : nullptr; return it != streams_.end() ? it->second.get() : nullptr;
} }
// find stream by type; note finds only the first stream of a given type
Stream* Decoder::findByType(const MediaFormat& format) const { Stream* Decoder::findByType(const MediaFormat& format) const {
for (auto& stream : streams_) { for (auto& stream : streams_) {
if (stream.second->getMediaFormat().type == format.type) { if (stream.second->getMediaFormat().type == format.type) {
...@@ -608,11 +631,14 @@ Stream* Decoder::findByType(const MediaFormat& format) const { ...@@ -608,11 +631,14 @@ Stream* Decoder::findByType(const MediaFormat& format) const {
return nullptr; return nullptr;
} }
// given the stream and packet, decode the frame buffers into the
// DecoderOutputMessage data structure via stream::decodePacket function.
int Decoder::processPacket( int Decoder::processPacket(
Stream* stream, Stream* stream,
AVPacket* packet, AVPacket* packet,
bool* gotFrame, bool* gotFrame,
bool* hasMsg) { bool* hasMsg,
bool fastSeek) {
// decode package // decode package
int result; int result;
DecoderOutputMessage msg; DecoderOutputMessage msg;
...@@ -625,7 +651,15 @@ int Decoder::processPacket( ...@@ -625,7 +651,15 @@ int Decoder::processPacket(
bool endInRange = bool endInRange =
params_.endOffset <= 0 || msg.header.pts <= params_.endOffset; params_.endOffset <= 0 || msg.header.pts <= params_.endOffset;
inRange_.set(stream->getIndex(), endInRange); inRange_.set(stream->getIndex(), endInRange);
if (endInRange && msg.header.pts >= params_.startOffset) { // if fastseek is enabled, we're returning the first
// frame that we decode after (potential) seek.
// By default, we perform accurate seek to the closest
// following frame
bool startCondition = true;
if (!fastSeek) {
startCondition = msg.header.pts >= params_.startOffset;
}
if (endInRange && startCondition) {
*hasMsg = true; *hasMsg = true;
push(std::move(msg)); push(std::move(msg));
} }
......
...@@ -59,11 +59,11 @@ class Decoder : public MediaDecoder { ...@@ -59,11 +59,11 @@ class Decoder : public MediaDecoder {
private: private:
// mark below function for a proper invocation // mark below function for a proper invocation
virtual bool enableLogLevel(int level) const; bool enableLogLevel(int level) const;
virtual void logCallback(int level, const std::string& message); void logCallback(int level, const std::string& message);
virtual int readCallback(uint8_t* buf, int size); int readCallback(uint8_t* buf, int size);
virtual int64_t seekCallback(int64_t offset, int whence); int64_t seekCallback(int64_t offset, int whence);
virtual int shutdownCallback(); int shutdownCallback();
bool openStreams(std::vector<DecoderMetadata>* metadata); bool openStreams(std::vector<DecoderMetadata>* metadata);
Stream* findByIndex(int streamIndex) const; Stream* findByIndex(int streamIndex) const;
...@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder { ...@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder {
Stream* stream, Stream* stream,
AVPacket* packet, AVPacket* packet,
bool* gotFrame, bool* gotFrame,
bool* hasMsg); bool* hasMsg,
bool fastSeek = false);
void flushStreams(); void flushStreams();
void cleanUp(); void cleanUp();
......
...@@ -165,7 +165,7 @@ struct MediaFormat { ...@@ -165,7 +165,7 @@ struct MediaFormat {
struct DecoderParameters { struct DecoderParameters {
// local file, remote file, http url, rtmp stream uri, etc. anything that // local file, remote file, http url, rtmp stream uri, etc. anything that
// ffmpeg can recognize // ffmpeg can recognize
std::string uri; std::string uri{std::string()};
// timeout on getting bytes for decoding // timeout on getting bytes for decoding
size_t timeoutMs{1000}; size_t timeoutMs{1000};
// logging level, default AV_LOG_PANIC // logging level, default AV_LOG_PANIC
...@@ -190,10 +190,15 @@ struct DecoderParameters { ...@@ -190,10 +190,15 @@ struct DecoderParameters {
bool listen{false}; bool listen{false};
// don't copy frame body, only header // don't copy frame body, only header
bool headerOnly{false}; bool headerOnly{false};
// enable fast seek (seek only to keyframes)
bool fastSeek{false};
// interrupt init method on timeout // interrupt init method on timeout
bool preventStaleness{true}; bool preventStaleness{true};
// seek tolerated accuracy (us) // seek tolerated accuracy (us)
double seekAccuracy{1000000.0}; double seekAccuracy{1000000.0};
// Allow multithreaded decoding for numThreads > 1;
// 0 numThreads=0 sets up sensible defaults
int numThreads{1};
// what media types should be processed, default none // what media types should be processed, default none
std::set<MediaFormat> formats; std::set<MediaFormat> formats;
...@@ -205,6 +210,15 @@ struct DecoderParameters { ...@@ -205,6 +210,15 @@ struct DecoderParameters {
std::string tlsCertFile; std::string tlsCertFile;
std::string tlsKeyFile; std::string tlsKeyFile;
// Skip packets that fail with EPERM errors and continue decoding.
bool skipOperationNotPermittedPackets{false};
// probing size in bytes, i.e. the size of the data to analyze to get stream
// information. A higher value will enable detecting more information in case
// it is dispersed into the stream, but will increase latency. Must be an
// integer not lesser than 32. It is 5000000 by default.
int64_t probeSize{5000000};
}; };
struct DecoderHeader { struct DecoderHeader {
...@@ -287,7 +301,7 @@ struct DecoderMetadata { ...@@ -287,7 +301,7 @@ struct DecoderMetadata {
}; };
/** /**
* Abstract class for decoding media bytes * Abstract class for decoding media bytes
* It has two diffrent modes. Internal media bytes retrieval for given uri and * It has two different modes. Internal media bytes retrieval for given uri and
* external media bytes provider in case of memory streams * external media bytes provider in case of memory streams
*/ */
class MediaDecoder { class MediaDecoder {
......
GPU Decoder
===========
The GPU decoder depends on FFmpeg for demuxing, uses NVDECODE APIs from the NVIDIA Video Codec SDK, and uses CUDA for processing on the GPU. To use it, follow these steps:
* Download the latest `nvidia-video-codec-sdk <https://developer.nvidia.com/nvidia-video-codec-sdk/download>`_
* Extract the zipped file.
* Set the TORCHVISION_INCLUDE environment variable to the location of the video codec headers (`nvcuvid.h` and `cuviddec.h`), which are under the `Interface` directory.
* Set the TORCHVISION_LIBRARY environment variable to the location of the video codec library (`libnvcuvid.so`), which is under the `Lib/linux/stubs/x86_64` directory.
* Install the latest FFmpeg from the `conda-forge` channel.
.. code:: bash
conda install -c conda-forge ffmpeg
* Set the CUDA_HOME environment variable to the CUDA root directory.
* Build torchvision from source:
.. code:: bash
python setup.py install
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment