# test_cpp_models.py (author per repository history: Shahriar)
import os
import tempfile
import unittest

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import models, transforms, _C_tests


def process_model(model, tensor, func, name):
    """Trace ``model`` with TorchScript and check a C++ runner reproduces its output.

    The model is switched to eval mode, traced on ``tensor`` and the traced module
    is saved to a file inside a private temporary directory. ``func`` is then
    called with that file's path and ``tensor``, and its result is compared with
    the eager Python output using ``torch.allclose``.

    Args:
        model: the ``torch.nn.Module`` under test.
        tensor: example input used both for tracing and for the comparison.
        func: callable taking ``(path, tensor)`` and returning an output tensor.
        name: model name used in the failure message.

    Raises:
        AssertionError: if the Python and ``func`` outputs are not all-close.
    """
    model.eval()
    traced_script_module = torch.jit.trace(model, tensor)

    # Save into a private temporary directory instead of a fixed "model.pt" in the
    # current working directory: that file was never removed, and two runs sharing
    # a working directory would overwrite each other's model.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, 'model.pt')
        traced_script_module.save(model_path)

        # Call the module rather than .forward() so any registered hooks run, and
        # skip autograd bookkeeping because only output values are compared.
        with torch.no_grad():
            py_output = model(tensor)
        # func must run while the saved file still exists.
        cpp_output = func(model_path, tensor)

    assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models'


def read_image1():
    """Load the Grace Hopper test asset as a 1x3x224x224 tensor.

    The JPEG is read from the ``assets`` directory next to this file, resized to
    224x224, converted with ``F.to_tensor`` and given a leading batch dimension.
    """
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'assets', 'grace_hopper_517x606.jpg')
    # The context manager closes the file handle that Image.open keeps open.
    # Converting to RGB guarantees the 3 channels the view below relies on.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB').resize((224, 224))
    x = F.to_tensor(image)
    return x.view(1, 3, 224, 224)


def read_image2():
    """Load the Grace Hopper test asset as a 2x3x299x299 batch.

    The JPEG is read from the ``assets`` directory next to this file, resized to
    299x299 and converted with ``F.to_tensor``. The single image is then
    duplicated along dimension 0 to form a batch of two.
    """
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'assets', 'grace_hopper_517x606.jpg')
    # The context manager closes the file handle that Image.open keeps open.
    # Converting to RGB guarantees the 3 channels the view below relies on.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB').resize((299, 299))
    x = F.to_tensor(image)
    x = x.view(1, 3, 299, 299)
    return torch.cat([x, x], 0)


class Tester(unittest.TestCase):
    """Compare eager Python outputs of torchvision models with their C++ counterparts.

    Each test hands a freshly built model to ``process_model``, which traces and
    saves it, runs it through the matching ``_C_tests.forward_*`` binding and
    asserts the two outputs are all-close.
    """

    # Passed as the first positional argument to every model constructor below.
    pretrained = False
    # Shared 1x3x224x224 input, read once when the class body is executed.
    image = read_image1()

    def _check(self, factory, cpp_forward, name, image=None):
        """Build ``factory(self.pretrained)`` and compare it against ``cpp_forward``.

        Args:
            factory: a ``torchvision.models`` constructor.
            cpp_forward: the ``_C_tests`` function that runs the saved model.
            name: label used in the assertion message.
            image: input tensor; defaults to the shared ``self.image``.
        """
        if image is None:
            image = self.image
        process_model(factory(self.pretrained), image, cpp_forward, name)

    def test_alexnet(self):
        self._check(models.alexnet, _C_tests.forward_alexnet, 'Alexnet')

    def test_vgg11(self):
        self._check(models.vgg11, _C_tests.forward_vgg11, 'VGG11')

    def test_vgg13(self):
        self._check(models.vgg13, _C_tests.forward_vgg13, 'VGG13')

    def test_vgg16(self):
        self._check(models.vgg16, _C_tests.forward_vgg16, 'VGG16')

    def test_vgg19(self):
        self._check(models.vgg19, _C_tests.forward_vgg19, 'VGG19')

    def test_vgg11_bn(self):
        self._check(models.vgg11_bn, _C_tests.forward_vgg11bn, 'VGG11BN')

    def test_vgg13_bn(self):
        self._check(models.vgg13_bn, _C_tests.forward_vgg13bn, 'VGG13BN')

    def test_vgg16_bn(self):
        self._check(models.vgg16_bn, _C_tests.forward_vgg16bn, 'VGG16BN')

    def test_vgg19_bn(self):
        self._check(models.vgg19_bn, _C_tests.forward_vgg19bn, 'VGG19BN')

    def test_resnet18(self):
        self._check(models.resnet18, _C_tests.forward_resnet18, 'Resnet18')

    def test_resnet34(self):
        self._check(models.resnet34, _C_tests.forward_resnet34, 'Resnet34')

    def test_resnet50(self):
        self._check(models.resnet50, _C_tests.forward_resnet50, 'Resnet50')

    def test_resnet101(self):
        self._check(models.resnet101, _C_tests.forward_resnet101, 'Resnet101')

    def test_resnet152(self):
        self._check(models.resnet152, _C_tests.forward_resnet152, 'Resnet152')

    def test_resnext50_32x4d(self):
        # Now honours self.pretrained like every other test.
        self._check(models.resnext50_32x4d, _C_tests.forward_resnext50_32x4d,
                    'ResNext50_32x4d')

    def test_resnext101_32x8d(self):
        # Now honours self.pretrained like every other test.
        self._check(models.resnext101_32x8d, _C_tests.forward_resnext101_32x8d,
                    'ResNext101_32x8d')

    def test_squeezenet1_0(self):
        self._check(models.squeezenet1_0, _C_tests.forward_squeezenet1_0, 'Squeezenet1.0')

    def test_squeezenet1_1(self):
        self._check(models.squeezenet1_1, _C_tests.forward_squeezenet1_1, 'Squeezenet1.1')

    def test_densenet121(self):
        self._check(models.densenet121, _C_tests.forward_densenet121, 'Densenet121')

    def test_densenet169(self):
        self._check(models.densenet169, _C_tests.forward_densenet169, 'Densenet169')

    def test_densenet201(self):
        self._check(models.densenet201, _C_tests.forward_densenet201, 'Densenet201')

    def test_densenet161(self):
        self._check(models.densenet161, _C_tests.forward_densenet161, 'Densenet161')

    def test_mobilenet_v2(self):
        self._check(models.mobilenet_v2, _C_tests.forward_mobilenetv2, 'MobileNet')

    def test_googlenet(self):
        self._check(models.googlenet, _C_tests.forward_googlenet, 'GoogLeNet')

    def test_mnasnet0_5(self):
        self._check(models.mnasnet0_5, _C_tests.forward_mnasnet0_5, 'MNASNet0_5')

    def test_mnasnet0_75(self):
        self._check(models.mnasnet0_75, _C_tests.forward_mnasnet0_75, 'MNASNet0_75')

    def test_mnasnet1_0(self):
        self._check(models.mnasnet1_0, _C_tests.forward_mnasnet1_0, 'MNASNet1_0')

    def test_mnasnet1_3(self):
        self._check(models.mnasnet1_3, _C_tests.forward_mnasnet1_3, 'MNASNet1_3')

    def test_inception_v3(self):
        # Inception v3 is fed the 2x3x299x299 batch from read_image2(). Pass it
        # explicitly rather than reassigning self.image, which shadowed the
        # shared class-level input.
        self._check(models.inception_v3, _C_tests.forward_inceptionv3, 'Inceptionv3',
                    image=read_image2())


if __name__ == '__main__':
    # Run the unittest test cases defined in this module when executed as a script.
    unittest.main()