export_coreml.py
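
"""Export a RepViT model to Core ML.

Builds the model through timm, optionally loads a training checkpoint,
prepares it for inference with utils.replace_batchnorm, traces it with
torch.jit.trace, and converts the trace with coremltools. The result is
saved under coreml/.
"""
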
import os
from argparse import ArgumentParser

import torch
from timm import create_model

import model  # imported for its side effect of registering the repvit_* architectures with timm
import utils

parser = ArgumentParser()

parser.add_argument('--model', default='repvit_m1_1', type=str)
parser.add_argument('--resolution', default=224, type=int)
parser.add_argument('--ckpt', default=None, type=str)

if __name__ == "__main__":
    args = parser.parse_args()

    # Build the RepViT model and load a checkpoint if one was given.
    model = create_model(args.model, distillation=True)
    if args.ckpt:
        checkpoint = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    # Fold BatchNorm layers into the preceding layers for inference.
    utils.replace_batchnorm(model)
    model.eval()

    # Trace the model with random data.
    resolution = args.resolution
    example_input = torch.rand(1, 3, resolution, resolution)
    traced_model = torch.jit.trace(model, example_input)
    out = traced_model(example_input)

    import coremltools as ct

    # Convert the traced model to a Core ML model with the Unified Conversion API,
    # declaring the input as an image of the chosen resolution.
    mlmodel = ct.convert(
        traced_model,
        inputs=[ct.ImageType(shape=example_input.shape)]
    )

    # Save the converted model, creating the output directory if needed.
    os.makedirs("coreml", exist_ok=True)
    mlmodel.save(f"coreml/{args.model}_{resolution}.mlmodel")
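
# Example usage (a sketch; the checkpoint path below is a placeholder):
#
#   python export_coreml.py --model repvit_m1_1 --resolution 224 \
#       --ckpt checkpoints/repvit_m1_1.pth
#
# The exported .mlmodel can then be inspected with coremltools, e.g.:
#
#   import coremltools as ct
#   mlmodel = ct.models.MLModel("coreml/repvit_m1_1_224.mlmodel")
#   print(mlmodel.get_spec().description)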