common_testing.py 1.96 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

Roman Shapovalov's avatar
Roman Shapovalov committed
3
from typing import Optional
facebook-github-bot's avatar
facebook-github-bot committed
4
5

import unittest
6
7

import numpy as np
facebook-github-bot's avatar
facebook-github-bot committed
8
9
10
11
12
13
14
15
import torch


class TestCaseMixin(unittest.TestCase):
    """
    unittest.TestCase subclass providing extra assertions for torch tensors
    and numpy arrays: storage-sharing checks and shape-aware closeness checks.
    """

    @staticmethod
    def _storage_ptr(tensor) -> int:
        """
        Return the data pointer of the storage underlying `tensor`.

        Tensor.untyped_storage() is preferred when it exists because
        Tensor.storage() is deprecated in recent PyTorch releases and emits a
        warning there; older releases only provide Tensor.storage(). Both
        report the same address for the start of the storage.
        """
        get_storage = getattr(tensor, "untyped_storage", None) or tensor.storage
        return get_storage().data_ptr()

    def assertSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in distinct locations.
        """
        self.assertNotEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertNotSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in the same locations.
        """
        self.assertEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertAllSeparate(self, tensor_list) -> None:
        """
        Verify that all tensors in tensor_list have their data in
        distinct locations.
        """
        ptrs = [self._storage_ptr(t) for t in tensor_list]
        # The set keeps one copy of each pointer, so assertCountEqual fails
        # as soon as any storage pointer occurs more than once.
        self.assertCountEqual(ptrs, set(ptrs))

    @staticmethod
    def _max_abs_diff(input, other) -> Optional[float]:
        """
        Best-effort max(|input - other|), used only to enrich failure messages.

        NaN differences (e.g. NaN vs NaN) are ignored unless every difference
        is NaN. Returns None for empty inputs or when the values cannot be
        converted/subtracted (e.g. dtypes numpy cannot represent). Because it
        only decorates an assertion message, it never raises for such inputs.
        """
        try:
            if torch.is_tensor(input):
                a = input.detach().cpu().numpy()
            else:
                a = np.asarray(input)
            if torch.is_tensor(other):
                b = other.detach().cpu().numpy()
            else:
                b = np.asarray(other)
            if a.size == 0:
                return None
            if not (np.iscomplexobj(a) or np.iscomplexobj(b)):
                # Promote so bool inputs can be subtracted and unsigned
                # integers do not wrap around.
                a = a.astype(np.float64)
                b = b.astype(np.float64)
            diff = np.abs(a - b)
            non_nan = diff[~np.isnan(diff)]
            return float(non_nan.max()) if non_nan.size else float("nan")
        except (TypeError, ValueError, RuntimeError, AttributeError):
            return None

    def assertClose(
        self,
        input,
        other,
        *,
        rtol: float = 1e-05,
        atol: float = 1e-08,
        equal_nan: bool = False,
        msg: Optional[str] = None,
    ) -> None:
        """
        Verify that two tensors or arrays are the same shape and close.
        Args:
            input, other: two tensors or two arrays.
            rtol, atol, equal_nan: as for torch.allclose.
            msg: message in case the assertion is violated. If omitted and
                the values are not close, a message reporting rtol, atol and
                the maximum absolute difference is generated instead.

        Note:
            Optional arguments here are all keyword-only, to avoid confusion
            with msg arguments on other assert functions.
        """

        # Pass msg through so a shape mismatch also reports the caller's text.
        self.assertEqual(np.shape(input), np.shape(other), msg)

        if torch.is_tensor(input):
            close = torch.allclose(
                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        else:
            close = np.allclose(
                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
            )

        if not close and msg is None:
            # Diagnostics are computed only on failure, so passing calls pay
            # no extra cost.
            msg = f"Not close (rtol={rtol}, atol={atol})"
            max_diff = self._max_abs_diff(input, other)
            if max_diff is not None:
                msg += f": max absolute difference {max_diff:.6g}"
        self.assertTrue(close, msg)