import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
  import amp_C
  from amp_C import multi_tensor_scale
  from apex.multi_tensor_apply import MultiTensorApply
  disabled = False
except ImportError as err:
  print("amp_C fused kernels unavailable, disabling TestMultiTensorApply.  ImportError was ", err)
  disabled = True
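
# A minimal, illustrative sketch of the calling convention these tests exercise
# (not part of the test suite; assumes CUDA and the amp_C extension are
# available).  MultiTensorApply(chunk_size) builds an applier that launches the
# fused kernel over lists of tensors; for multi_tensor_scale the lists are
# [inputs, outputs], the trailing argument is the scale factor, and the int
# overflow buffer becomes nonzero if an inf/nan is encountered.
def _example_multi_tensor_scale(scale=4.0):
    applier = MultiTensorApply(2048*32)
    overflow_buf = torch.cuda.IntTensor(1).zero_()
    ins = [torch.cuda.FloatTensor(16).fill_(scale),
           torch.cuda.FloatTensor(32).fill_(scale)]
    outs = [torch.empty_like(t) for t in ins]
    # Scale every tensor in ins by 1/scale, writing the results into outs.
    applier(multi_tensor_scale, overflow_buf, [ins, outs], 1./scale)
    return outs, overflow_buf.item()  # each element is 1.0; overflow flag stays 0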


class TestMultiTensorScale(unittest.TestCase):

    def setUp(self):
        self.scale = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([1.0])

        common_init(self)

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
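        """Fill `repeat_tensors` pairs of tensors of sizes `sizea`/`sizeb` with
        self.scale, scale them by 1./self.scale through `applier`, and verify
        that every output equals 1.0 and the overflow flag stays clear."""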
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
        self.assertTrue(self.overflow_buf.item() == 0)
 
    def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
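        """After a clean scaling pass, inject `val` (an inf or nan) at index `ind`
        of input tensor `t`, rerun the applier, and verify that the overflow flag
        is raised."""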
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.overflow_buf.zero_()
        in_list[t][ind] = val
        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
        self.assertTrue(self.overflow_buf.item())

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior.  Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale(self.fp16, self.fp16, self.fp16_ref)
    # 
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale(self.fp32, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32), 
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
          for applier in appliers:
            for repeat in repeat_tensors:
              for in_type in (torch.float32, torch.float16):
                for out_type in (torch.float32, torch.float16):
                  for inplace in (True, False):
                    if inplace and (out_type is not in_type):
                      continue
                    self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                  0, 0, float('nan'), inplace=inplace)
                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                  2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                  2*(repeat//2), sizea//2, float('inf'), inplace=inplace)



if __name__ == '__main__':
    unittest.main()