import unittest

import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT

class _TestPromotion(unittest.TestCase):
    def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False):
        if lp_type == torch.half:
            dtypes = DTYPES
        elif lp_type == torch.bfloat16:
            dtypes = DTYPES2
        else:
            raise RuntimeError("Creating test class with invalid low_precision type. \
                                Supported types are torch.half and torch.bfloat16")
        type_pairs = it.product(dtypes, dtypes)
Carl Case's avatar
Carl Case committed
22
23
        for fn, (xtype, ytype) in it.product(fns, type_pairs):
            x = torch.randn(input_shape, dtype=xtype).requires_grad_()
Carl Case's avatar
Carl Case committed
24
25
26
27
            x_leaf = x
            if x_inplace:
                # We need a non-leaf to call in place on
                x = x.clone()
Carl Case's avatar
Carl Case committed
28
29
            y = torch.randn(input_shape, dtype=ytype)
            out = fn(x, y)
Carl Case's avatar
Carl Case committed
30
31
32
            if x_inplace:
                # In place: always match xtype
                self.assertEqual(out.type(), x.type())
Carl Case's avatar
Carl Case committed
33
            else:
Carl Case's avatar
Carl Case committed
34
35
36
37
                # Out of place: match widest type
                if xtype == torch.float or ytype == torch.float:
                    self.assertEqual(out.type(), FLOAT)
                else:
rohithkrn's avatar
rohithkrn committed
38
                    self.assertEqual(out.type(), MATCH_INPUT[lp_type])
Carl Case's avatar
Carl Case committed
39
            out.float().sum().backward()
Carl Case's avatar
Carl Case committed
40
            self.assertEqual(x_leaf.grad.dtype, xtype)
Carl Case's avatar
Carl Case committed
41

rohithkrn's avatar
rohithkrn committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
    def _test_cat_matches_widest(self, lp_type):
        shape = self.b
        ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        x_lp = torch.randn(shape, dtype=lp_type)
        out = torch.cat(ys + [x_lp])
        self.assertEqual(out.type(), MATCH_INPUT[lp_type])

    def _test_inplace_exp_is_error_for_lp(self, lp_type):
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=lp_type)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

class TestPromotionHalf(_TestPromotion):
    """Promotion tests with amp patching tensors to torch.half."""

    def setUp(self):
        # Activate amp with half as the low-precision patch type; common_init
        # populates shared fixtures (e.g. self.b) used by the helpers.
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_atan2_matches_widest(self):
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,), torch.half)

    def test_mul_matches_widest(self):
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,), torch.half)

    def test_cat_matches_widest(self):
        self._test_cat_matches_widest(torch.half)

    def test_inplace_exp_is_error_for_half(self):
        self._test_inplace_exp_is_error_for_lp(torch.half)

    def test_inplace_add_matches_self(self):
        # In-place add must keep x's dtype regardless of y's dtype.
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True)

class TestPromotionBFloat16(_TestPromotion):
    """Promotion tests with amp patching tensors to torch.bfloat16."""

    def setUp(self):
        # Activate amp with bfloat16 as the low-precision patch type;
        # common_init populates shared fixtures (e.g. self.b).
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_mul_matches_widest(self):
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,), torch.bfloat16)

    def test_cat_matches_widest(self):
        self._test_cat_matches_widest(torch.bfloat16)

    def test_inplace_exp_is_error_for_bfloat16(self):
        self._test_inplace_exp_is_error_for_lp(torch.bfloat16)

    def test_inplace_add_matches_self(self):
        # In-place add must keep x's dtype regardless of y's dtype.
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True)

# Allow running this test module directly: `python test_promotion.py`.
if __name__ == '__main__':
    unittest.main()