Commit 64917bcc authored by Zhicheng Yan, committed by Francisco Massa

Video transforms (#1353)

* video transforms

* [video transforms]in ToTensorVideo, divide value by 255.0

* [video transforms] fix a bug

* fix linting

* Make changes backwards-compatible
parent a15ff20f
test/test_transforms_video.py

from __future__ import division
import torch
import torchvision.transforms as transforms
import unittest
import random
import numpy as np

try:
    from scipy import stats
except ImportError:
    stats = None
class Tester(unittest.TestCase):

    def test_random_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomCropVideo((oheight, owidth)).__repr__()
    def test_random_resized_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomResizedCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
    def test_center_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        # Zero out exactly the region that the center crop will extract.
        clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
        clipNarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertEqual(result.sum().item(), 0, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum1 = result.sum()
        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum1.item() > 1, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum2 = result.sum()
        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum2.item() > 1, msg)
        self.assertTrue(sum2.item() > sum1.item(), msg)
    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize_video(self):
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            numFrames = random.randint(4, 128)
            height = random.randint(32, 256)
            width = random.randint(32, 256)
            mean = random.random()
            std = random.random()
            clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
            mean = [clip[c].mean().item() for c in range(channels)]
            std = [clip[c].std().item() for c in range(channels)]
            normalized = transforms.NormalizeVideo(mean, std)(clip)
            assert samples_from_standard_normal(normalized)
        random.setstate(random_state)

        # Checking the optional in-place behaviour
        tensor = torch.rand((3, 128, 16, 16))
        tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
        assert torch.equal(tensor, tensor_inplace)

        transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()
    def test_to_tensor_video(self):
        numFrames, height, width = 64, 4, 4
        trans = transforms.ToTensorVideo()

        with self.assertRaises(TypeError):
            # Non-tensor input is rejected; float tensors fail the uint8 dtype check.
            trans(np.random.rand(numFrames, height, width, 1).tolist())
            trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))

        with self.assertRaises(ValueError):
            # Anything that is not a 4D tensor is rejected.
            trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
            trans(torch.ones((height, width, 3), dtype=torch.uint8))
            trans(torch.ones((width, 3), dtype=torch.uint8))
            trans(torch.ones((3), dtype=torch.uint8))

        trans.__repr__()
    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip_video(self):
        random_state = random.getstate()
        random.seed(42)
        clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
        hclip = clip.flip((-1))

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo()(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1
        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        assert p_value > 0.0001

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1
        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        assert p_value > 0.0001

        transforms.RandomHorizontalFlipVideo().__repr__()


if __name__ == '__main__':
    unittest.main()
torchvision/transforms/__init__.py

from .transforms import *
from .transforms_video import *
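Both wildcard imports re-export their public names into the torchvision.transforms namespace. This is what keeps the change backwards-compatible: the existing image transforms are untouched, and the new video transforms are reachable from the same module, as the tests above assume. A minimal sketch of the resulting namespace (assuming this commit is installed):

import torchvision.transforms as transforms

# Image transforms (from .transforms) and video transforms
# (from .transforms_video) now share a single namespace.
print(transforms.RandomCrop)       # pre-existing image transform
print(transforms.RandomCropVideo)  # new video transform, exported via __all__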
torchvision/transforms/functional_video.py

import torch


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True
def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    assert len(clip.size()) == 4, "clip should be a 4D tensor"
    return clip[..., i:i + h, j:j + w]


def resize(clip, target_size, interpolation_mode):
    assert len(target_size) == 2, "target size should be tuple (height, width)"
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode
    )
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i, j), i.e. coordinates of the upper left corner.
        j (int): j in (i, j), i.e. coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)
def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float().permute(3, 0, 1, 2) / 255.0
def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-1))
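Taken together, these helpers form the functional layer that the class-based transforms in transforms_video.py delegate to. A quick sketch of calling them directly on a synthetic clip (the shapes and normalization statistics are arbitrary illustration values, and the import path assumes this commit's module layout):

import torch
import torchvision.transforms.functional_video as F

clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)  # (T, H, W, C) uint8
clip = F.to_tensor(clip)                                  # -> (C, T, H, W) float in [0, 1]
clip = F.resized_crop(clip, i=8, j=8, h=112, w=112, size=(112, 112))
clip = F.normalize(clip, mean=[0.43, 0.40, 0.38], std=[0.23, 0.24, 0.23])
clip = F.hflip(clip)
print(clip.shape)  # torch.Size([3, 16, 112, 112])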
torchvision/transforms/transforms.py

@@ -40,6 +40,15 @@ _pil_interpolation_to_str = {
 }
 
 
+def _get_image_size(img):
+    if F._is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 class Compose(object):
     """Composes several transforms together.
@@ -444,7 +453,7 @@ class RandomCrop(object):
         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = img.size
+        w, h = _get_image_size(img)
         th, tw = output_size
         if w == tw and h == th:
             return 0, 0, h, w
@@ -635,7 +644,8 @@ class RandomResizedCrop(object):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        area = img.size[0] * img.size[1]
+        width, height = _get_image_size(img)
+        area = height * width
 
         for attempt in range(10):
             target_area = random.uniform(*scale) * area
@@ -645,24 +655,24 @@ class RandomResizedCrop(object):
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))
 
-            if 0 < w <= img.size[0] and 0 < h <= img.size[1]:
-                i = random.randint(0, img.size[1] - h)
-                j = random.randint(0, img.size[0] - w)
+            if 0 < w <= width and 0 < h <= height:
+                i = random.randint(0, height - h)
+                j = random.randint(0, width - w)
                 return i, j, h, w
 
         # Fallback to central crop
-        in_ratio = img.size[0] / img.size[1]
+        in_ratio = float(width) / float(height)
         if (in_ratio < min(ratio)):
-            w = img.size[0]
+            w = width
             h = int(round(w / min(ratio)))
         elif (in_ratio > max(ratio)):
-            h = img.size[1]
+            h = height
             w = int(round(h * max(ratio)))
         else:  # whole image
-            w = img.size[0]
-            h = img.size[1]
-        i = (img.size[1] - h) // 2
-        j = (img.size[0] - w) // 2
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
         return i, j, h, w
 
     def __call__(self, img):
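With _get_image_size in place, RandomCrop.get_params and RandomResizedCrop.get_params accept 4D tensors as well as PIL images, which is what lets RandomCropVideo and RandomResizedCropVideo below inherit the parameter sampling unchanged. A small sketch (the clip shape is an arbitrary example):

import torch
from torchvision.transforms import RandomResizedCrop

clip = torch.rand(3, 8, 128, 171)  # (C, T, H, W) tensor, not a PIL image
# _get_image_size reads (width, height) from shape[-2:] reversed, so
# get_params samples a crop rectangle inside the 128 x 171 spatial extent.
i, j, h, w = RandomResizedCrop.get_params(clip, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
print(i, j, h, w)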
torchvision/transforms/transforms_video.py

#!/usr/bin/env python3

import numbers
import random

from torchvision.transforms import (
    RandomCrop,
    RandomResizedCrop,
)

from . import functional_video as F

__all__ = [
    "RandomCropVideo",
    "RandomResizedCropVideo",
    "CenterCropVideo",
    "NormalizeVideo",
    "ToTensorVideo",
    "RandomHorizontalFlipVideo",
]
class RandomCropVideo(RandomCrop):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (C, T, OH, OW)
        """
        i, j, h, w = self.get_params(clip, self.size)
        return F.crop(clip, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomResizedCropVideo(RandomResizedCrop):
    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            assert len(size) == 2, "size should be tuple (height, width)"
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format(
                self.size, self.interpolation_mode, self.scale, self.ratio
            )
class CenterCropVideo(object):
    def __init__(self, crop_size):
        if isinstance(crop_size, numbers.Number):
            self.crop_size = (int(crop_size), int(crop_size))
        else:
            self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
                (C, T, crop_size, crop_size)
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self):
        return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size)
class NormalizeVideo(object):
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether to do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return F.normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
            self.mean, self.std, self.inplace)
class ToTensorVideo(object):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        return F.to_tensor(clip)

    def __repr__(self):
        return self.__class__.__name__
class RandomHorizontalFlipVideo(object):
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        if random.random() < self.p:
            clip = F.hflip(clip)
        return clip

    def __repr__(self):
        return self.__class__.__name__ + "(p={0})".format(self.p)
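A typical training pipeline built from these classes, mirroring the tests above (the clip shape and the mean/std values are placeholders for illustration, not values taken from this commit):

import torch
import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.ToTensorVideo(),                     # (T, H, W, C) uint8 -> (C, T, H, W) float
    transforms.RandomResizedCropVideo((112, 112)),  # random scale/aspect crop, resized to 112x112
    transforms.RandomHorizontalFlipVideo(p=0.5),
    transforms.NormalizeVideo(mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225)),
])

clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)  # 16-frame RGB clip
out = train_transform(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112])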