exporter.py 4 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Binary to convert a PyTorch D2Go (detectron2go) model into a predictor, which contains
the model(s) in a deployable format (such as TorchScript, Caffe2, ...).
"""

import copy
import logging
11
import sys
facebook-github-bot's avatar
facebook-github-bot committed
12
13
14
import typing

import mobile_cv.lut.lib.pt.flops_utils as flops_utils
15
from d2go.config import temp_defrost
16
from d2go.export.exporter import convert_and_export_predictor
17
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
facebook-github-bot's avatar
facebook-github-bot committed
18
19
20
21
22
23
24
25
26
27
28
29
from mobile_cv.common.misc.py import post_mortem_if_fail


# Module-level logger for this export binary; named explicitly instead of via __name__.
logger = logging.getLogger("d2go.tools.export")


def main(
    cfg,
    output_dir,
    runner,
    # binary specific optional arguments
    predictor_types: typing.List[str],
    device: str = "cpu",
    compare_accuracy: bool = False,
    skip_if_fail: bool = False,
):
    """Export the model described by ``cfg`` as one predictor per requested type.

    Args:
        cfg: config node. It is deep-copied before use, so the caller's config is
            not mutated; ``MODEL.DEVICE`` of the copy is overridden with ``device``.
        output_dir: directory passed to ``setup_after_launch`` and to
            ``convert_and_export_predictor`` for every export.
        runner: object providing ``build_model`` and
            ``build_detection_test_loader``.
        predictor_types: predictor types to export; one export is attempted per
            entry, in order.
        device: value written into ``MODEL.DEVICE`` before building the model.
        compare_accuracy: not supported yet; passing True raises immediately.
        skip_if_fail: if True, a failed export is logged and the next type is
            attempted; otherwise the exception is re-raised.

    Returns:
        dict with ``"predictor_paths"`` (predictor type -> path returned by
        ``convert_and_export_predictor``, for successful exports only) and
        ``"accuracy_comparison"`` (currently always an empty dict).

    Raises:
        NotImplementedError: if ``compare_accuracy`` is True.
        ValueError: if the data loader built from ``cfg.DATASETS.TRAIN`` yields
            no batches.
    """
    if compare_accuracy:
        # TODO: once supported, fill ret["accuracy_comparison"] with a dict of
        # metrics for all exported models (and the original pytorch model).
        raise NotImplementedError(
            "compare_accuracy functionality isn't currently supported."
        )

    cfg = copy.deepcopy(cfg)
    setup_after_launch(cfg, output_dir, runner)

    with temp_defrost(cfg):
        cfg.merge_from_list(["MODEL.DEVICE", device])

    model = runner.build_model(cfg, eval_only=True)

    # NOTE: train dataset is used to avoid leakage since the data might be used for
    # running calibration for quantization. test_loader is used to make sure it follows
    # the inference behaviour (augmentation will not be applied).
    datasets = list(cfg.DATASETS.TRAIN)
    data_loader = runner.build_detection_test_loader(cfg, datasets)

    logger.info("Running the pytorch model and print FLOPS ...")
    try:
        first_batch = next(iter(data_loader))
    except StopIteration:
        # A bare StopIteration would give no hint about which config is at fault.
        raise ValueError(
            f"Data loader built from cfg.DATASETS.TRAIN={datasets} yielded no"
            " batches; cannot run the model to compute FLOPS or export it."
        ) from None
    input_args = (first_batch,)
    flops_utils.print_model_flops(model, input_args)

    predictor_paths: typing.Dict[str, str] = {}
    for typ in predictor_types:
        # convert_and_export_predictor might alter the model, copy before calling it
        pytorch_model = copy.deepcopy(model)
        try:
            predictor_path = convert_and_export_predictor(
                cfg,
                pytorch_model,
                typ,
                output_dir,
                data_loader,
            )
        except Exception as e:
            logger.exception(f"Export {typ} predictor failed: {e}")
            if not skip_if_fail:
                # Bare raise re-raises the active exception unchanged.
                raise
        else:
            logger.info(f"Predictor type {typ} has been exported to {predictor_path}")
            predictor_paths[typ] = predictor_path

    ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}}

    return ret


@post_mortem_if_fail()
def run_with_cmdline_args(args):
    """Prepare the launch context from parsed CLI ``args`` and run ``main``.

    Returns whatever ``main`` returns.
    """
    launch_cfg, launch_output_dir, launch_runner = prepare_for_launch(args)
    # binary specific optional arguments, forwarded by keyword
    export_options = dict(
        predictor_types=args.predictor_types,
        device=args.device,
        compare_accuracy=args.compare_accuracy,
        skip_if_fail=args.skip_if_fail,
    )
    return main(launch_cfg, launch_output_dir, launch_runner, **export_options)


def get_parser():
    """Build the command-line parser for the export binary.

    Extends ``basic_argument_parser(distributed=False)`` with the export-specific
    options consumed by ``run_with_cmdline_args``.
    """
    parser = basic_argument_parser(distributed=False)
    parser.add_argument(
        "--predictor-types",
        type=str,
        nargs="+",
        # main() iterates this list. Without it the run would crash with a
        # TypeError only after the model was built, so fail fast here instead.
        required=True,
        help="List of strings specify the types of predictors to export",
    )
    parser.add_argument(
        "--device", default="cpu", help="the device to export the model on"
    )
    parser.add_argument(
        "--compare-accuracy",
        action="store_true",
        help="If true, all exported models and the original pytorch model will be"
        " evaluated on cfg.DATASETS.TEST",
    )
    parser.add_argument(
        "--skip-if-fail",
        action="store_true",
        default=False,
        help="If set, suppress the exception for failed exporting and continue to"
        " export the next type of model",
    )
    return parser

124

Tsahi Glik's avatar
Tsahi Glik committed
125
126
def cli(args=None):
    """Command-line entry point: parse ``args`` (default ``sys.argv[1:]``) and export."""
    if args is None:
        args = sys.argv[1:]
    parser = get_parser()
    parsed_args = parser.parse_args(args)
    run_with_cmdline_args(parsed_args)
facebook-github-bot's avatar
facebook-github-bot committed
128

129

facebook-github-bot's avatar
facebook-github-bot committed
130
if __name__ == "__main__":
    # Run the exporter when this file is executed as a script.
    cli()