import os
import sys
import tempfile
import unittest

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import models

try:
    from torchvision import _C_tests
except ImportError:
    _C_tests = None


def process_model(model, tensor, func, name):
    """Check that a traced model produces the same output in Python and via ``func``.

    The model is put in eval mode, traced with ``tensor`` as the example input
    and serialized to a ``model.pt`` file inside a private temporary directory.
    The eager Python output is then compared with the output of ``func``.

    Args:
        model: the ``torch.nn.Module`` under test.
        tensor: example input, used both for tracing and for the comparison.
        func: callable ``func(path, tensor)`` that loads the serialized module
            from ``path`` and returns its output (e.g. one of the
            ``_C_tests.forward_*`` bindings).
        name: model label used in the failure message.

    Raises:
        AssertionError: if the two outputs are not ``torch.allclose``.
    """
    model.eval()
    traced_script_module = torch.jit.trace(model, tensor)

    # Serialize into a throwaway directory instead of the current working
    # directory: no stray file is left behind and concurrent runs cannot
    # overwrite each other's model.pt.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pt")
        traced_script_module.save(model_path)

        # Gradients are not needed for the comparison; no_grad saves memory on large models.
        with torch.no_grad():
            py_output = model(tensor)
        cpp_output = func(model_path, tensor)

    assert torch.allclose(py_output, cpp_output), f"Output mismatch of {name} models"


def read_image1():
    """Load the Grace Hopper test JPEG as a ``(1, 3, 224, 224)`` tensor.

    The image is read from ``assets/encode_jpeg`` next to this file, resized
    to 224x224, turned into a tensor with ``F.pil_to_tensor`` and converted
    with ``F.convert_image_dtype``.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    # Context manager releases the file handle deterministically; converting
    # to RGB guarantees the three channels the batch shape below relies on.
    with Image.open(image_path) as src:
        image = src.convert("RGB").resize((224, 224))

    x = F.convert_image_dtype(F.pil_to_tensor(image))
    # (3, 224, 224) -> (1, 3, 224, 224): add the batch dimension.
    return x.unsqueeze(0)


def read_image2():
    """Load the Grace Hopper test JPEG as a ``(2, 3, 299, 299)`` batch.

    The image is read from ``assets/encode_jpeg`` next to this file, resized
    to 299x299, converted with ``F.pil_to_tensor`` / ``F.convert_image_dtype``,
    and duplicated along the batch dimension.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    # Context manager releases the file handle deterministically; converting
    # to RGB guarantees the three channels the batch shape below relies on.
    with Image.open(image_path) as src:
        image = src.convert("RGB").resize((299, 299))

    x = F.convert_image_dtype(F.pil_to_tensor(image)).unsqueeze(0)
    # Two identical samples -> batch of shape (2, 3, 299, 299).
    return torch.cat([x, x], 0)


@unittest.skipIf(_C_tests is None, "torchvision was built without the _C_tests extension")
@unittest.skipIf(
    sys.platform == "darwin" or True,
    "C++ models are broken on OS X at the moment, and there's a BC breakage on main; "
    "see https://github.com/pytorch/vision/issues/1191",
)
class Tester(unittest.TestCase):
    """Compare torchvision models' Python outputs with the ``_C_tests`` C++ bindings.

    Each test builds a model and hands it to ``process_model`` together with
    the matching ``_C_tests.forward_*`` function.
    """

    # Forwarded to every model builder as ``pretrained``.
    pretrained = False
    # Shared input image; loaded in setUpClass rather than at import time, so
    # nothing is read from disk when the class is skipped.
    image = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.image = read_image1()

    def _check(self, builder, cpp_forward, name, image=None):
        """Build a model with ``builder`` and compare it against ``cpp_forward``.

        ``image`` overrides the shared class image when given.
        """
        model = builder(pretrained=self.pretrained)
        process_model(model, self.image if image is None else image, cpp_forward, name)

    def test_alexnet(self):
        self._check(models.alexnet, _C_tests.forward_alexnet, "Alexnet")

    def test_vgg11(self):
        self._check(models.vgg11, _C_tests.forward_vgg11, "VGG11")

    def test_vgg13(self):
        self._check(models.vgg13, _C_tests.forward_vgg13, "VGG13")

    def test_vgg16(self):
        self._check(models.vgg16, _C_tests.forward_vgg16, "VGG16")

    def test_vgg19(self):
        self._check(models.vgg19, _C_tests.forward_vgg19, "VGG19")

    def test_vgg11_bn(self):
        self._check(models.vgg11_bn, _C_tests.forward_vgg11bn, "VGG11BN")

    def test_vgg13_bn(self):
        self._check(models.vgg13_bn, _C_tests.forward_vgg13bn, "VGG13BN")

    def test_vgg16_bn(self):
        self._check(models.vgg16_bn, _C_tests.forward_vgg16bn, "VGG16BN")

    def test_vgg19_bn(self):
        self._check(models.vgg19_bn, _C_tests.forward_vgg19bn, "VGG19BN")

    def test_resnet18(self):
        self._check(models.resnet18, _C_tests.forward_resnet18, "Resnet18")

    def test_resnet34(self):
        self._check(models.resnet34, _C_tests.forward_resnet34, "Resnet34")

    def test_resnet50(self):
        self._check(models.resnet50, _C_tests.forward_resnet50, "Resnet50")

    def test_resnet101(self):
        self._check(models.resnet101, _C_tests.forward_resnet101, "Resnet101")

    def test_resnet152(self):
        self._check(models.resnet152, _C_tests.forward_resnet152, "Resnet152")

    def test_resnext50_32x4d(self):
        self._check(models.resnext50_32x4d, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")

    def test_resnext101_32x8d(self):
        self._check(models.resnext101_32x8d, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d")

    def test_wide_resnet50_2(self):
        self._check(models.wide_resnet50_2, _C_tests.forward_wide_resnet50_2, "WideResNet50_2")

    def test_wide_resnet101_2(self):
        self._check(models.wide_resnet101_2, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")

    def test_squeezenet1_0(self):
        self._check(models.squeezenet1_0, _C_tests.forward_squeezenet1_0, "Squeezenet1.0")

    def test_squeezenet1_1(self):
        self._check(models.squeezenet1_1, _C_tests.forward_squeezenet1_1, "Squeezenet1.1")

    def test_densenet121(self):
        self._check(models.densenet121, _C_tests.forward_densenet121, "Densenet121")

    def test_densenet169(self):
        self._check(models.densenet169, _C_tests.forward_densenet169, "Densenet169")

    def test_densenet201(self):
        self._check(models.densenet201, _C_tests.forward_densenet201, "Densenet201")

    def test_densenet161(self):
        self._check(models.densenet161, _C_tests.forward_densenet161, "Densenet161")

    def test_mobilenet_v2(self):
        self._check(models.mobilenet_v2, _C_tests.forward_mobilenetv2, "MobileNet")

    def test_googlenet(self):
        self._check(models.googlenet, _C_tests.forward_googlenet, "GoogLeNet")

    def test_mnasnet0_5(self):
        self._check(models.mnasnet0_5, _C_tests.forward_mnasnet0_5, "MNASNet0_5")

    def test_mnasnet0_75(self):
        self._check(models.mnasnet0_75, _C_tests.forward_mnasnet0_75, "MNASNet0_75")

    def test_mnasnet1_0(self):
        self._check(models.mnasnet1_0, _C_tests.forward_mnasnet1_0, "MNASNet1_0")

    def test_mnasnet1_3(self):
        self._check(models.mnasnet1_3, _C_tests.forward_mnasnet1_3, "MNASNet1_3")

    def test_inception_v3(self):
        # Inception v3 is fed the 299x299 batch from read_image2 instead of the
        # shared 224x224 image; a local keeps the shared image untouched.
        self._check(models.inception_v3, _C_tests.forward_inceptionv3, "Inceptionv3", image=read_image2())


if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()