#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Binary to convert pytorch detectron2go model to a predictor, which contains model(s) in
deployable format (such as torchscript, caffe2, ...)
"""

import copy
import logging
import typing

import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.config import temp_defrost
from d2go.export.api import convert_and_export_predictor
from d2go.setup import (
    basic_argument_parser,
    prepare_for_launch,
    setup_after_launch,
)
from mobile_cv.common.misc.py import post_mortem_if_fail


logger = logging.getLogger("d2go.tools.export")


def main(
    cfg,
    output_dir,
    runner,
    # binary specific optional arguments
    predictor_types: typing.List[str],
    device: str = "cpu",
    compare_accuracy: bool = False,
    skip_if_fail: bool = False,
):
    """Export a pytorch model as one predictor per requested predictor type.

    Args:
        cfg: experiment config. A deep copy is made and used for everything
            below, including the ``MODEL.DEVICE`` override.
        output_dir: forwarded to ``setup_after_launch`` and to
            ``convert_and_export_predictor``.
        runner: object providing ``build_model`` and
            ``build_detection_test_loader``.
        predictor_types: predictor types to export; they are attempted in order.
        device: value merged into ``MODEL.DEVICE`` before the model is built.
        compare_accuracy: not supported; passing True raises immediately,
            before any setup work is done.
        skip_if_fail: if True, a failing export is logged and the next type is
            attempted; if False the failure is re-raised.

    Returns:
        A dict with two keys: ``"predictor_paths"`` maps each successfully
        exported type to the path returned by ``convert_and_export_predictor``;
        ``"accuracy_comparison"`` is always an empty dict.

    Raises:
        NotImplementedError: if ``compare_accuracy`` is True.
        Exception: whatever the export raised, when ``skip_if_fail`` is False.
    """
    if compare_accuracy:
        # NOTE: once supported, the metrics of all exported models (and of the
        # original pytorch model) are meant to be returned under
        # ret["accuracy_comparison"].
        raise NotImplementedError(
            "compare_accuracy functionality isn't currently supported."
        )

    cfg = copy.deepcopy(cfg)
    setup_after_launch(cfg, output_dir, runner)

    with temp_defrost(cfg):
        cfg.merge_from_list(["MODEL.DEVICE", device])
    model = runner.build_model(cfg, eval_only=True)

    # NOTE: train dataset is used to avoid leakage since the data might be used for
    # running calibration for quantization. test_loader is used to make sure it follows
    # the inference behaviour (augmentation will not be applied).
    datasets = list(cfg.DATASETS.TRAIN)
    data_loader = runner.build_detection_test_loader(cfg, datasets)

    logger.info("Running the pytorch model and print FLOPS ...")
    first_batch = next(iter(data_loader))
    input_args = (first_batch,)
    flops_utils.print_model_flops(model, input_args)

    predictor_paths: typing.Dict[str, str] = {}
    for typ in predictor_types:
        # convert_and_export_predictor might alter the model, copy before calling it
        pytorch_model = copy.deepcopy(model)
        try:
            predictor_path = convert_and_export_predictor(
                cfg,
                pytorch_model,
                typ,
                output_dir,
                data_loader,
            )
            logger.info(f"Predictor type {typ} has been exported to {predictor_path}")
            predictor_paths[typ] = predictor_path
        except Exception as e:
            logger.exception(f"Export {typ} predictor failed: {e}")
            if not skip_if_fail:
                # Bare raise re-raises the active exception unchanged.
                raise

    ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}}

    return ret


@post_mortem_if_fail()
def run_with_cmdline_args(args):
    """Run ``main`` using values taken from parsed command-line arguments.

    Args:
        args: parsed arguments; must carry ``predictor_types``, ``device``,
            ``compare_accuracy`` and ``skip_if_fail`` in addition to whatever
            ``prepare_for_launch`` reads.

    Returns:
        The value returned by ``main``.
    """
    cfg, output_dir, runner = prepare_for_launch(args)
    return main(
        cfg,
        output_dir,
        runner,
        # binary specific optional arguments
        predictor_types=args.predictor_types,
        device=args.device,
        compare_accuracy=args.compare_accuracy,
        skip_if_fail=args.skip_if_fail,
    )


def get_parser():
    """Build the argument parser for the export binary.

    Starts from ``basic_argument_parser(distributed=False)`` and adds the
    export-specific options ``--predictor-types``, ``--device``,
    ``--compare-accuracy`` and ``--skip-if-fail``.

    Returns:
        The parser with the export options registered.
    """
    parser = basic_argument_parser(distributed=False)
    parser.add_argument(
        "--predictor-types",
        type=str,
        nargs="+",
        help="List of strings specify the types of predictors to export",
    )
    parser.add_argument(
        "--device", default="cpu", help="the device to export the model on"
    )
    parser.add_argument(
        "--compare-accuracy",
        action="store_true",
        # The export entry point in this file rejects this flag with
        # NotImplementedError, so the help text says so.
        help="If true, all exported models and the original pytorch model will be"
        " evaluated on cfg.DATASETS.TEST (not currently supported: the export"
        " aborts with NotImplementedError)",
    )
    parser.add_argument(
        "--skip-if-fail",
        action="store_true",
        help="If set, suppress the exception for failed exporting and continue to"
        " export the next type of model",
    )
    return parser

def cli():
    """Command-line entry point: parse the arguments, then run the export."""
    parsed_args = get_parser().parse_args()
    run_with_cmdline_args(parsed_args)

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()