test_scale.py 3.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

# The fused kernel lives in the optional ``amp_C`` extension module.  When the
# import fails, ``disabled`` is set so that code consulting the flag can skip
# itself; ``scale_check_overflow`` is only bound when the import succeeds.
try:
    import amp_C
except ImportError as err:
    print("amp_C fused kernel unavailable, disabling TestScale.  ImportError was ", err)
    disabled = True
else:
    scale_check_overflow = amp_C.scale_check_overflow
    disabled = False


class TestScale(unittest.TestCase):
    """Tests for the fused ``scale_check_overflow`` kernel.

    Each test method is decorated with ``unittest.skipIf(disabled, ...)``, so
    it is skipped when the module-level ``disabled`` flag is true.
    """

    def setUp(self):
        """Create the scale factor and the CUDA tensors used by every test."""
        self.scale = 128.0
        self.nx = 999
        self.ny = 888

        # Single-element int tensor handed to the kernel as its overflow flag.
        self.overflow_buf = torch.cuda.IntTensor([0])

        shape = (self.ny, self.nx)
        self.fp16 = torch.ones(shape, device='cuda', dtype=torch.float16)
        self.fp32 = torch.ones(shape, device='cuda', dtype=torch.float32)

        # 1x1 references; torch.allclose broadcasts them against full tensors.
        self.fp16_ref = torch.ones((1, 1), device='cuda', dtype=torch.float16)
        self.fp32_ref = torch.ones((1, 1), device='cuda', dtype=torch.float32)

        # NOTE(review): common_init comes from the local ``utils`` module; its
        # effect is not visible from this file.
        common_init(self)

    def tearDown(self):
        """No per-test cleanup is performed."""

    def _reset_buffers(self, input, output):
        """Put the overflow flag and both tensors into a known state.

        The overflow flag is zeroed and ``input`` is filled with 1.0.  When
        ``output`` is a different tensor it is filled with 3.0 (presumably so
        that an output the kernel never wrote is distinguishable from a
        correct result of 1.0).
        """
        self.overflow_buf.zero_()
        input.fill_(1.0)
        if input is not output:
            output.fill_(3.0)

    def downscale_test(self, input, output, ref):
        """Scale ``input`` up, downscale it via the kernel, and check the result.

        Args:
            input: tensor passed to the kernel as its source; modified in place.
            output: tensor passed to the kernel as its destination (may be the
                same object as ``input``).
            ref: tensor that ``output`` must be close to afterwards.
        """
        self._reset_buffers(input, output)
        input.mul_(self.scale)
        scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
        self.assertTrue(torch.allclose(output, ref),
                        msg="downscaled output does not match the reference")
        # assertEqual reports the actual flag value if the check fails.
        self.assertEqual(self.overflow_buf.item(), 0,
                         msg="overflow flag was set for finite input")

    def find_inf_test(self, input, output, ref, x, y, val):
        """Plant ``val`` at ``input[x, y]`` and check the overflow flag is set.

        Args:
            input: tensor passed to the kernel as its source; modified in place.
            output: tensor passed to the kernel as its destination.
            ref: unused; kept so the signature parallels ``downscale_test``.
            x: index along the first dimension of ``input``.
            y: index along the second dimension of ``input``.
            val: value to plant (the tests in this class pass inf or nan).
        """
        self._reset_buffers(input, output)
        input[x,y] = val
        scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
        self.assertNotEqual(self.overflow_buf.item(), 0,
                            msg="overflow flag was not set")

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior.  Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale_test(self.fp16, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp16_to_fp32(self):
        """Downscale an fp16 input into a separate fp32 output."""
        self.downscale_test(self.fp16, self.fp32, self.fp32_ref)

    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale_test(self.fp32, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp32_to_fp32(self):
        """Downscale an fp32 tensor in place (input and output are the same)."""
        self.downscale_test(self.fp32, self.fp32, self.fp32_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp16_to_fp32_find_inf_nan(self):
        """Check inf/nan detection at the first, middle and last element (fp16 input)."""
        cases = (
            (0, 0, float('nan')),
            (self.ny//2, self.nx//2, float('inf')),
            (self.ny-1, self.nx-1, float('nan')),
        )
        for row, col, bad_value in cases:
            self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, row, col, bad_value)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp32_to_fp32_find_inf_nan(self):
        """Check inf/nan detection at the first, middle and last element (fp32, in place)."""
        cases = (
            (0, 0, float('inf')),
            (self.ny//2, self.nx//2, float('nan')),
            (self.ny-1, self.nx-1, float('inf')),
        )
        for row, col, bad_value in cases:
            self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, row, col, bad_value)


# Run the unittest command-line runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()