Commit d194082c authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Make crop scriptable (#1379)

* Make crop torchscriptable

Relevant #1375

* Invert x and y axis

* fix lint

* Add crop test

* revert deletion of space in functional

* add import random

* add dimension in doc

* add import

* fix flake8

* change to self.assert*

* convert to uint8

* assertTrue

* lint
parent b0f88dff
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as F
import unittest import unittest
import torch import random
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -19,6 +22,20 @@ class Tester(unittest.TestCase): ...@@ -19,6 +22,20 @@ class Tester(unittest.TestCase):
self.assertEqual(hflipped_img.shape, img_tensor.shape) self.assertEqual(hflipped_img.shape, img_tensor.shape)
self.assertTrue(torch.equal(img_tensor, hflipped_img_again)) self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
def test_crop(self):
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
top = random.randint(0, 15)
left = random.randint(0, 15)
height = random.randint(1, 16 - top)
width = random.randint(1, 16 - left)
img_cropped = F_t.crop(img_tensor, top, left, height, width)
img_PIL = transforms.ToPILImage()(img_tensor)
img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
"functional_tensor crop not working")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -31,3 +31,20 @@ def hflip(img_tensor): ...@@ -31,3 +31,20 @@ def hflip(img_tensor):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor is not a torch image.')
return img_tensor.flip(-1) return img_tensor.flip(-1)
def crop(img, top, left, height, width):
"""Crop the given Image Tensor.
Args:
img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
Returns:
Tensor: Cropped image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')
return img[..., top:top + height, left:left + width]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment