test_data_transforms_tensor.py 908 Bytes
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import unittest

import numpy as np
import torch
from d2go.data.transforms import tensor as tensor_aug
from detectron2.data.transforms.augmentation import AugmentationList


Sam Tsai's avatar
Sam Tsai committed
13
class TestDataTransformsTensor(unittest.TestCase):
    """Tests for the tensor-aware augmentations in ``d2go.data.transforms.tensor``."""

    def test_tensor_aug(self):
        """Data augmentation that allows torch.Tensor as input."""

        # Shape (3, 8, 6); presumably CHW layout — TODO confirm what
        # Tensor2Array expects. All-ones values make equality checks trivial.
        img = torch.ones(3, 8, 6)
        # Tensor -> array followed by array -> tensor: the composition is
        # expected to round-trip the image unchanged.
        augs = [tensor_aug.Tensor2Array(), tensor_aug.Array2Tensor()]

        inputs = tensor_aug.AugInput(image=img)
        # Applies the augmentations to `inputs` (the test then reads the
        # result back from `inputs.image`) and keeps the applied transforms.
        transforms = AugmentationList(augs)(inputs)
        self.assertArrayEqual(img, inputs.image)

        # The forward round trip is an identity, so applying its inverse to
        # the original image must also leave it unchanged.
        out_img = transforms.inverse().apply_image(img)
        self.assertArrayEqual(img, out_img)

    def assertArrayEqual(self, a1, a2):
        """Assert that ``a1`` and ``a2`` have the same shape and elements.

        Comparison uses ``np.array_equal``, so any inputs ``np.asarray`` can
        convert (numpy arrays, CPU torch tensors, nested lists) are accepted.
        On failure, both operands are included in the message to make the
        mismatch visible instead of a bare "False is not true".
        """
        self.assertTrue(
            np.array_equal(a1, a2),
            msg=f"arrays differ:\n{a1!r}\nvs\n{a2!r}",
        )