# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from typing import Optional

import numpy as np
import torch


class TestCaseMixin(unittest.TestCase):
    """
    unittest.TestCase with extra assertions for tensors and arrays:
    checks on whether tensors share their underlying storage, and a
    closeness check which also verifies that the shapes agree.
    """

    @staticmethod
    def _storage_ptr(tensor) -> int:
        """
        Return the address of the start of the storage backing tensor.

        Tensor.storage() is deprecated in recent PyTorch releases in favour
        of Tensor.untyped_storage(), so the latter is used when it exists
        and storage() is kept as the fallback for older releases.
        """
        untyped_storage = getattr(tensor, "untyped_storage", None)
        if untyped_storage is None:
            return tensor.storage().data_ptr()
        return untyped_storage().data_ptr()

    @staticmethod
    def _max_abs_diff(input, other):
        """
        Return the largest elementwise absolute difference between input
        and other, as a Python scalar, for use in failure messages.
        """
        if torch.is_tensor(input):
            # Adding 0.0 promotes bool / integer tensors to floating point;
            # torch does not support `-` on bool tensors.
            return ((input + 0.0) - (other + 0.0)).abs().max().item()
        # np.asarray lets lists and Python scalars through, as np.allclose does.
        diff = np.abs((np.asarray(input) + 0.0) - (np.asarray(other) + 0.0))
        return np.max(diff).item()

    def assertSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in distinct locations.

        The comparison is on the base address of each tensor's storage, so
        two views into one storage are not considered separate.
        """
        self.assertNotEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertNotSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in the same locations.

        The comparison is on the base address of each tensor's storage, so
        it holds for any two views into one storage.
        """
        self.assertEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertAllSeparate(self, tensor_list) -> None:
        """
        Verify that all tensors in tensor_list have their data in
        distinct locations.

        Args:
            tensor_list: iterable of tensors.
        """
        ptrs = [self._storage_ptr(tensor) for tensor in tensor_list]
        # A repeated pointer appears more often in the list than in the set,
        # which makes the element counts differ.
        self.assertCountEqual(ptrs, set(ptrs))

    def assertClose(
        self,
        input,
        other,
        *,
        rtol: float = 1e-05,
        atol: float = 1e-08,
        equal_nan: bool = False,
        msg: Optional[str] = None,
    ) -> None:
        """
        Verify that two tensors or arrays are the same shape and close.

        Args:
            input, other: two tensors or two arrays.
            rtol, atol, equal_nan: as for torch.allclose.
            msg: message in case the assertion is violated.

        Raises:
            AssertionError (self.failureException): if the shapes differ or
                the values are not close. In the latter case the message
                contains the maximum absolute difference, preceded by msg
                if one was given.

        Note:
            Optional arguments here are all keyword-only, to avoid confusion
            with msg arguments on other assert functions.
        """
        self.assertEqual(np.shape(input), np.shape(other))

        # Dispatch on the type of input: torch for tensors, numpy otherwise.
        if torch.is_tensor(input):
            close = torch.allclose(
                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        else:
            close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)

        if close:
            return

        # Report how far apart the values are instead of a bare
        # "False is not true".
        err = f"Not close. Max diff {self._max_abs_diff(input, other)}."
        self.fail(err if msg is None else f"{msg} {err}")