torchvision.py 4.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torchvision
import torchvision.models as tm
from packaging import version

from ..registry import ModelAttribute, model_zoo

def data_gen_fn():
    """Create the default model input.

    Returns:
        A dict with a single key ``"x"`` holding a random tensor of shape
        ``(4, 3, 224, 224)`` drawn with ``torch.rand``.
    """
    return {"x": torch.rand(4, 3, 224, 224)}


def output_transform_fn(x):
    """Wrap a model output in a dict under the key ``"output"``.

    Args:
        x: The value returned by the model; it is stored unchanged.

    Returns:
        ``{"output": x}``.
    """
    return {"output": x}

# special data gen fn
def inception_v3_data_gen_fn():
    """Create the input used for the Inception v3 entry.

    Returns:
        A dict with a single key ``"x"`` holding a random tensor of shape
        ``(4, 3, 299, 299)`` drawn with ``torch.rand``.
    """
    return {"x": torch.rand(4, 3, 299, 299)}


# special model fn
def swin_s():
    """Build a Swin Transformer without pretrained weights and with stochastic depth disabled.

    The construction is adapted from ``torchvision.models.swin_transformer.swin_small``
    (per the original author's note), with ``stochastic_depth_prob`` forced to 0.

    NOTE(review): the function is named ``swin_s`` but validates its weights
    against ``Swin_T_Weights`` -- confirm which variant is intended.

    Returns:
        Whatever ``_swin_transformer`` returns for the configuration below.
    """
    # Imported inside the function rather than at module level; presumably so that
    # importing this file does not fail on torchvision builds lacking this module.
    from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer

    # No pretrained weights are requested; ``None`` is still passed through
    # ``verify`` exactly as before.
    checked_weights = Swin_T_Weights.verify(None)

    config = dict(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        # it is originally 0.2, but we set it to 0 to make it deterministic
        stochastic_depth_prob=0,
        weights=checked_weights,
        progress=True,
    )
    return _swin_transformer(**config)


# special output transform fn
37
38
39
40
41
42
43
44
45
def google_net_output_transform_fn(x):
    """Normalize a GoogLeNet forward result into an ``{"output": ...}`` dict.

    Args:
        x: The value returned by the model.

    Returns:
        ``{"output": sum(x)}`` when ``x`` is a ``torchvision.models.GoogLeNetOutputs``
        instance, otherwise ``{"output": x}``.
    """
    if isinstance(x, torchvision.models.GoogLeNetOutputs):
        # Collapse the multi-field output into a single value.
        return dict(output=sum(x))
    return dict(output=x)
def swin_s_output_output_transform_fn(x):
    """Normalize a Swin forward result into a dict of named outputs.

    Args:
        x: The value returned by the model.

    Returns:
        For a tuple, a dict mapping ``"output0"``, ``"output1"``, ... to the
        tuple's elements in order; for anything else, ``{"output": x}``.
    """
    if not isinstance(x, tuple):
        return dict(output=x)
    # One numbered key per tuple element.
    return {f"output{idx}": val for idx, val in enumerate(x)}
def inception_v3_output_transform_fn(x):
    """Normalize an Inception v3 forward result into an ``{"output": ...}`` dict.

    Args:
        x: The value returned by the model.

    Returns:
        ``{"output": sum(x)}`` when ``x`` is a ``torchvision.models.InceptionOutputs``
        instance, otherwise ``{"output": x}``.
    """
    if isinstance(x, torchvision.models.InceptionOutputs):
        # Collapse the multi-field output into a single value.
        return dict(output=sum(x))
    return dict(output=x)
46

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Models registered unconditionally (no torchvision version check).
# Each row is (registry name, model constructor, keyword overrides); rows with
# an empty override dict use the default data generator and output transform.
_default_io = dict(data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn)
_baseline_models = (
    ("torchvision_alexnet", tm.alexnet, {}),
    ("torchvision_densenet121", tm.densenet121, {}),
    (
        "torchvision_efficientnet_b0",
        tm.efficientnet_b0,
        dict(model_attribute=ModelAttribute(has_stochastic_depth_prob=True)),
    ),
    ("torchvision_googlenet", tm.googlenet, dict(output_transform_fn=google_net_output_transform_fn)),
    (
        "torchvision_inception_v3",
        tm.inception_v3,
        dict(data_gen_fn=inception_v3_data_gen_fn, output_transform_fn=inception_v3_output_transform_fn),
    ),
    ("torchvision_mobilenet_v2", tm.mobilenet_v2, {}),
    ("torchvision_mobilenet_v3_small", tm.mobilenet_v3_small, {}),
    ("torchvision_mnasnet0_5", tm.mnasnet0_5, {}),
    ("torchvision_resnet18", tm.resnet18, {}),
    ("torchvision_regnet_x_16gf", tm.regnet_x_16gf, {}),
    ("torchvision_resnext50_32x4d", tm.resnext50_32x4d, {}),
    ("torchvision_shufflenet_v2_x0_5", tm.shufflenet_v2_x0_5, {}),
    ("torchvision_squeezenet1_0", tm.squeezenet1_0, {}),
)
for _zoo_name, _builder, _overrides in _baseline_models:
    # Overrides replace the defaults key-by-key; extra keys are appended after them.
    model_zoo.register(name=_zoo_name, model_fn=_builder, **{**_default_io, **_overrides})
120

121
122
123
124
125
126
127
128
129
# VGG-11 and Wide-ResNet-50-2 both use the default data generator and output transform.
for _zoo_name, _builder in (
    ("torchvision_vgg11", tm.vgg11),
    ("torchvision_wide_resnet50_2", tm.wide_resnet50_2),
):
    model_zoo.register(
        name=_zoo_name,
        model_fn=_builder,
        data_gen_fn=data_gen_fn,
        output_transform_fn=output_transform_fn,
    )
130

131
132
133
134
135
136
137
138
139
140
141
142
143
144
# These entries are only registered when the installed torchvision is 0.12.0 or newer.
if version.parse(torchvision.__version__) >= version.parse("0.12.0"):
    for _entry in (
        dict(
            name="torchvision_vit_b_16",
            model_fn=tm.vit_b_16,
            data_gen_fn=data_gen_fn,
            output_transform_fn=output_transform_fn,
        ),
        dict(
            name="torchvision_convnext_base",
            model_fn=tm.convnext_base,
            data_gen_fn=data_gen_fn,
            output_transform_fn=output_transform_fn,
            model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
        ),
    ):
        model_zoo.register(**_entry)
145

146
# These entries are only registered when the installed torchvision is 0.13.0 or newer.
if version.parse(torchvision.__version__) >= version.parse("0.13.0"):
    for _entry in (
        dict(
            name="torchvision_swin_s",
            model_fn=swin_s,
            data_gen_fn=data_gen_fn,
            output_transform_fn=swin_s_output_output_transform_fn,
        ),
        dict(
            name="torchvision_efficientnet_v2_s",
            model_fn=tm.efficientnet_v2_s,
            data_gen_fn=data_gen_fn,
            output_transform_fn=output_transform_fn,
            model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
        ),
    ):
        model_zoo.register(**_entry)