import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F

from utils import (common_init, HALF, FLOAT,
                   ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT)
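
# These tests exercise apex.amp's parameter cast cache: under mixed precision,
# weights fed to reduced-precision ops are cast once and the result is cached.
# The cache must stay correct across train -> eval -> train transitions and
# across in-place weight updates.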

def get_reference_grad(i, w, ops):
    # Cloning into fresh fp32 tensors guarantees the reference computation
    # bypasses amp's cast cache: only torch.nn.Parameters are ever cached,
    # and these clones are plain tensors.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad
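
# Each module below targets one of amp's cast categories: mm is a whitelist
# op (runs in reduced precision), pow is a blacklist op (runs in fp32), and
# elementwise multiplication is a promote op (widest input type wins). Each
# ops() uses the weight twice, so the second use can be served from the cache.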

class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return ((input*weight)*weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)

class TestCache(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def train_eval_train_test(self, module, t, opt_level):
        model = module(t).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

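        # amp.initialize() normally expects an fp32 model; this private flag
        # lets the fp16/bf16-weight variants through for testing.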
        _amp_state.allow_incoming_model_not_fp32 = True
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0)
        _amp_state.allow_incoming_model_not_fp32 = False
        
        def training_step():
            for param in model.parameters():
                param.grad = None
        
            loss = model(self.x).sum()
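
            # Force a known loss scale so the scaled backward pass, and hence
            # the gradient comparison below, is deterministic.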
            _amp_state.loss_scalers[0]._loss_scale = 4.0
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        
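            # Only the single weight Parameter should have received a grad,
            # and the grad should match the weight's own type.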
            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())
        
            reference_grad = get_reference_grad(self.x, model.weight, model.ops)
        
            # The allclose calls are currently identical, so the branching is
            # unnecessary, but it's kept in case we want different tolerances
            # for the fp16/bf16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

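            # Update the weight in-place through .data, bypassing autograd.
            # On the next iteration the cache must hand back a freshly cast
            # copy rather than the stale one, or the checks above will fail.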
            model.weight.data -= 1.
        
        # Simulates first epoch
        training_step()
        
        # Simulates eval
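        # (runs under no_grad, so any cast performed here must not leave the
        # cache holding an entry without autograd history)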
        with torch.no_grad():
            loss = model(self.x).sum()
        
        # Simulates resuming training after eval
        training_step()

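        # Deactivate amp's monkey-patching so the next test can call
        # amp.initialize() again on a fresh model.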
        _amp_state.handle._deactivate()
   
    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16, "O1")

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32, "O1")

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16, "O1")

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32, "O1")

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16, "O1")

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32, "O1")

    # opt_level = O4: repeat the tests above with bfloat16 as the
    # reduced-precision type. The fp32-weight variants get distinct _o4
    # names so they don't shadow their O1 counterparts above.
    def test_whitelist_module_bfp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")

    @unittest.skip("This test started failing after a PyTorch commit somewhere between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01; the same error is observed on CUDA. See https://github.com/ROCmSoftwarePlatform/apex/issues/62")
    def test_whitelist_module_fp32_weight_o4(self):
        self.train_eval_train_test(WhitelistModule, torch.float32, "O4")

    def test_blacklist_module_bfp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4")

    def test_blacklist_module_fp32_weight_o4(self):
        self.train_eval_train_test(BlacklistModule, torch.float32, "O4")

    def test_promote_module_bfp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4")

    def test_promote_module_fp32_weight_o4(self):
        self.train_eval_train_test(PromoteModule, torch.float32, "O4")


if __name__ == '__main__':
    unittest.main()