test_fp16_optimizer.py 879 Bytes
Newer Older
Michael Carilli's avatar
Michael Carilli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import unittest

import functools as ft
import itertools as it

import torch
from apex.fp16_utils import FP16_Optimizer

@unittest.skipUnless(
    torch.cuda.is_available(),
    "TestFP16Optimizer fixtures are allocated on a CUDA device",
)
class TestFP16Optimizer(unittest.TestCase):
    """Scaffold for FP16_Optimizer tests built around a half-precision Linear model.

    ``setUp`` places every fixture on the CUDA device, so the whole case is
    skipped (rather than erroring inside ``setUp``) when CUDA is unavailable.
    The test methods below are placeholders that do not assert anything yet.
    """

    def setUp(self):
        """Build the per-test fixtures.

        Sets ``self.N``, ``self.D_in`` and ``self.D_out`` (64, 1024, 16),
        ``self.x`` with shape (N, D_in) and ``self.y`` with shape (N, D_out),
        both float16 CUDA tensors of random normal values, and ``self.model``,
        a ``torch.nn.Linear(D_in, D_out)`` moved to CUDA and cast to half.
        """
        N, D_in, D_out = 64, 1024, 16
        self.N = N
        self.D_in = D_in
        self.D_out = D_out
        self.x = torch.randn((N, D_in), dtype=torch.float16, device='cuda')
        self.y = torch.randn((N, D_out), dtype=torch.float16, device='cuda')
        self.model = torch.nn.Linear(D_in, D_out).cuda().half()

    # No tearDown override: the inherited unittest.TestCase.tearDown is a no-op.

    def test_minimal(self):
        # Placeholder: no assertions implemented yet.
        pass

    def test_minimal_static(self):
        # Placeholder: no assertions implemented yet.
        pass

    def test_minimal_dynamic(self):
        # Placeholder: no assertions implemented yet.
        pass

    def test_closure(self):
        # Placeholder: no assertions implemented yet.
        pass

    def test_closure_dynamic(self):
        # Placeholder: no assertions implemented yet.
        pass

    def test_save_load(self):
        # Placeholder: no assertions implemented yet.
        pass

# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()