import os
import sys
import unittest

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import models

try:
    from torchvision import _C_tests
except ImportError:
    _C_tests = None


def process_model(model, tensor, func, name):
    """Check that ``func`` reproduces the Python forward pass of a traced model.

    The model is switched to eval mode, traced with ``tensor`` via
    ``torch.jit.trace`` and saved to disk. ``func`` is then called with the
    path of the saved module and ``tensor``, and its result is compared with
    ``model.forward(tensor)`` using ``torch.allclose``.

    Args:
        model: module to trace and evaluate.
        tensor: example input, used both for tracing and for the comparison.
        func: callable taking ``(path, tensor)`` and returning the output to compare.
        name: label inserted into the failure message.

    Raises:
        AssertionError: if the two outputs are not close.
    """
    import tempfile  # local import: only this helper needs it

    model.eval()
    traced_script_module = torch.jit.trace(model, tensor)

    # Save into a throwaway directory so no "model.pt" is left behind in the
    # working directory and concurrent runs cannot overwrite each other's file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pt")
        traced_script_module.save(model_path)

        py_output = model.forward(tensor)
        cpp_output = func(model_path, tensor)

    assert torch.allclose(py_output, cpp_output), "Output mismatch of " + name + " models"


def read_image1():
    """Load the Grace Hopper JPEG asset as a tensor viewed as ``(1, 3, 224, 224)``.

    The image under ``assets/encode_jpeg`` next to this file is resized to
    224x224, converted with ``F.pil_to_tensor`` and then passed through
    ``F.convert_image_dtype`` with its default dtype argument.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    # The context manager releases the file handle deterministically; the
    # resized image is a separate object and stays usable afterwards.
    with Image.open(image_path) as image:
        resized = image.resize((224, 224))
    x = F.pil_to_tensor(resized)
    x = F.convert_image_dtype(x)
    # NOTE(review): the view assumes the asset decodes to exactly 3 channels.
    return x.view(1, 3, 224, 224)


def read_image2():
    """Load the Grace Hopper JPEG asset as a two-image batch of 299x299 inputs.

    The image under ``assets/encode_jpeg`` next to this file is resized to
    299x299, converted with ``F.pil_to_tensor`` and ``F.convert_image_dtype``
    (default dtype argument), viewed as ``(1, 3, 299, 299)`` and concatenated
    with itself along dimension 0.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
    )
    # The context manager releases the file handle deterministically; the
    # resized image is a separate object and stays usable afterwards.
    with Image.open(image_path) as image:
        resized = image.resize((299, 299))
    x = F.pil_to_tensor(resized)
    x = F.convert_image_dtype(x)
    # NOTE(review): the view assumes the asset decodes to exactly 3 channels.
    x = x.view(1, 3, 299, 299)
    return torch.cat([x, x], 0)


# NOTE: the trailing ``or True`` makes the condition always true, so this suite
# is currently skipped on every platform, not only on OS X.
@unittest.skipIf(
    sys.platform == "darwin" or True,
    "C++ models are broken on OS X at the moment, and there's a BC breakage on main; "
    "see https://github.com/pytorch/vision/issues/1191",
)
class Tester(unittest.TestCase):
    """Compare Python model outputs with the ``_C_tests`` forward functions.

    Each test builds a model from ``torchvision.models`` and hands it to
    ``process_model`` together with the matching ``_C_tests.forward_*``
    callable and a label for the failure message.
    """

    # Shared input for all tests except Inception v3. Populated in setUpClass
    # so that importing this module does not read the asset from disk while
    # the suite is skipped.
    image = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.image = read_image1()

    def test_alexnet(self):
        process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet")

    def test_vgg11(self):
        process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, "VGG11")

    def test_vgg13(self):
        process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13")

    def test_vgg16(self):
        process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16")

    def test_vgg19(self):
        process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19")

    def test_vgg11_bn(self):
        process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN")

    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN")

    def test_vgg16_bn(self):
        process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN")

    def test_vgg19_bn(self):
        process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN")

    def test_resnet18(self):
        process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18")

    def test_resnet34(self):
        process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34")

    def test_resnet50(self):
        process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50")

    def test_resnet101(self):
        process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101")

    def test_resnet152(self):
        process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152")

    def test_resnext50_32x4d(self):
        process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")

    def test_resnext101_32x8d(self):
        process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d")

    def test_wide_resnet50_2(self):
        process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, "WideResNet50_2")

    def test_wide_resnet101_2(self):
        process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")

    def test_squeezenet1_0(self):
        process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0")

    def test_squeezenet1_1(self):
        process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1")

    def test_densenet121(self):
        process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121")

    def test_densenet169(self):
        process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169")

    def test_densenet201(self):
        process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201")

    def test_densenet161(self):
        process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161")

    def test_mobilenet_v2(self):
        process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet")

    def test_googlenet(self):
        process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet")

    def test_mnasnet0_5(self):
        process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")

    def test_mnasnet0_75(self):
        process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")

    def test_mnasnet1_0(self):
        process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")

    def test_mnasnet1_3(self):
        process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")

    def test_inception_v3(self):
        # This test uses the 299x299 two-image batch instead of the shared
        # 224x224 input; a local keeps the shared attribute untouched.
        image = read_image2()
        process_model(models.inception_v3(), image, _C_tests.forward_inceptionv3, "Inceptionv3")


# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()