# test_fp16_optimizer.py
import unittest

import functools as ft
import itertools as it

import torch
from apex.fp16_utils import FP16_Optimizer
import unittest

import functools as ft
import itertools as it

import torch
from apex.fp16_utils import FP16_Optimizer

# Currently no-ops (tested via examples).
# FP16_Optimizer to be deprecated and moved under unified Amp API.
class TestFP16Optimizer(unittest.TestCase):
    def setUp(self):
        N, D_in, D_out = 64, 1024, 16
        self.N = N
        self.D_in = D_in
        self.D_out = D_out
        self.x = torch.randn((N, D_in), dtype=torch.float16, device='cuda')
        self.y = torch.randn((N, D_out), dtype=torch.float16, device='cuda')
        self.model = torch.nn.Linear(D_in, D_out).cuda().half()

    # def tearDown(self):
    #     pass

    def test_minimal(self):
        pass

    def test_minimal_static(self):
        pass

    def test_minimal_dynamic(self):
        pass

    def test_closure(self):
        pass

    def test_closure_dynamic(self):
        pass

    def test_save_load(self):
        pass

# Run the suite with unittest's CLI runner when executed as a script.
if __name__ == '__main__':
    unittest.main()