exporter.py 4.13 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Binary to convert pytorch detectron2go model to a predictor, which contains model(s) in
deployable format (such as torchscript, caffe2, ...)
"""

import copy
import logging
11
import sys
12
from typing import Dict, List, Type, Union
facebook-github-bot's avatar
facebook-github-bot committed
13
14

import mobile_cv.lut.lib.pt.flops_utils as flops_utils
15
from d2go.config import CfgNode, temp_defrost
16
from d2go.export.exporter import convert_and_export_predictor
17
from d2go.runner import BaseRunner
18
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
facebook-github-bot's avatar
facebook-github-bot committed
19
20
21
22
23
24
25
from mobile_cv.common.misc.py import post_mortem_if_fail


# Module-level logger used by this export tool; named explicitly (rather than
# via __name__) so records appear under "d2go.tools.export".
logger: logging.Logger = logging.getLogger("d2go.tools.export")


def main(
    cfg: CfgNode,
    output_dir: str,
    runner_class: Union[str, Type[BaseRunner]],
    # binary specific optional arguments
    predictor_types: List[str],
    device: str = "cpu",
    compare_accuracy: bool = False,
    skip_if_fail: bool = False,
):
    """Export a detectron2go model as one predictor per requested type.

    Args:
        cfg: config describing the model. It is deep-copied first, so the
            caller's object is not modified.
        output_dir: directory passed to ``setup_after_launch`` and used as the
            export destination for ``convert_and_export_predictor``.
        runner_class: runner class (or its name) passed to
            ``setup_after_launch``, which returns the runner used to build the
            model and the data loader.
        predictor_types: predictor types to export; each one triggers a
            separate ``convert_and_export_predictor`` call on a fresh copy of
            the model.
        device: value written to ``cfg.MODEL.DEVICE`` before the model is built.
        compare_accuracy: not supported yet; ``True`` raises
            ``NotImplementedError``.
        skip_if_fail: if True, a failed export is logged and the remaining
            types are still attempted; otherwise the exception is re-raised.

    Returns:
        dict with ``"predictor_paths"`` (predictor type -> exported path, for
        every type that exported successfully) and ``"accuracy_comparison"``
        (currently always an empty dict).

    Raises:
        NotImplementedError: if ``compare_accuracy`` is True.
        RuntimeError: if the data loader built from ``cfg.DATASETS.TRAIN``
            yields no batch (one batch is needed to print FLOPS).
    """
    if compare_accuracy:
        # TODO: once supported, fill ret["accuracy_comparison"] with the
        # metrics of all exported models (and the original pytorch model).
        raise NotImplementedError(
            "compare_accuracy functionality isn't currently supported."
        )

    cfg = copy.deepcopy(cfg)
    runner = setup_after_launch(cfg, output_dir, runner_class)

    with temp_defrost(cfg):
        cfg.merge_from_list(["MODEL.DEVICE", device])

    model = runner.build_model(cfg, eval_only=True)

    # NOTE: train dataset is used to avoid leakage since the data might be used for
    # running calibration for quantization. test_loader is used to make sure it follows
    # the inference behaviour (augmentation will not be applied).
    datasets = list(cfg.DATASETS.TRAIN)
    data_loader = runner.build_detection_test_loader(cfg, datasets)

    logger.info("Running the pytorch model and print FLOPS ...")
    try:
        first_batch = next(iter(data_loader))
    except StopIteration:
        # A bare StopIteration escaping here carries no message and is easy to
        # misread; report the actual problem instead.
        raise RuntimeError(
            f"The data loader built from cfg.DATASETS.TRAIN={datasets} yielded "
            "no batches; at least one batch is required to export the model."
        ) from None
    input_args = (first_batch,)
    flops_utils.print_model_flops(model, input_args)

    predictor_paths: Dict[str, str] = {}
    for typ in predictor_types:
        # convert_and_export_predictor might alter the model, copy before calling it
        pytorch_model = copy.deepcopy(model)
        try:
            predictor_path = convert_and_export_predictor(
                cfg,
                pytorch_model,
                typ,
                output_dir,
                data_loader,
            )
        except Exception as e:
            # logger.exception also records the traceback of the failure.
            logger.exception(f"Export {typ} predictor failed: {e}")
            if not skip_if_fail:
                # Bare raise re-raises the active exception unchanged.
                raise
        else:
            logger.info(f"Predictor type {typ} has been exported to {predictor_path}")
            predictor_paths[typ] = predictor_path

    ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}}

    return ret


@post_mortem_if_fail()
def run_with_cmdline_args(args):
    """Turn parsed command-line ``args`` into a call to :func:`main`.

    ``prepare_for_launch`` supplies the config, output directory and runner
    name; the export-specific flags are forwarded as keyword arguments.
    Returns whatever :func:`main` returns.
    """
    cfg, output_dir, runner_name = prepare_for_launch(args)
    # binary specific optional arguments, read in the same order as before
    export_options = dict(
        predictor_types=args.predictor_types,
        device=args.device,
        compare_accuracy=args.compare_accuracy,
        skip_if_fail=args.skip_if_fail,
    )
    return main(cfg, output_dir, runner_name, **export_options)


def get_parser():
    """Build the command-line parser for the export binary.

    Starts from ``basic_argument_parser(distributed=False)`` and adds the
    export-specific options read by ``run_with_cmdline_args``.

    Returns:
        The configured ``argparse``-style parser.
    """
    parser = basic_argument_parser(distributed=False)
    parser.add_argument(
        "--predictor-types",
        type=str,
        nargs="+",
        # Required: main() iterates over this list. Without the flag argparse
        # yields None, which previously crashed with a TypeError only after the
        # model and data loader had already been built.
        required=True,
        help="List of strings specify the types of predictors to export",
    )
    parser.add_argument(
        "--device", default="cpu", help="the device to export the model on"
    )
    parser.add_argument(
        "--compare-accuracy",
        action="store_true",
        # NOTE(review): main() currently raises NotImplementedError when this
        # flag is set.
        help="If true, all exported models and the original pytorch model will be"
        " evaluated on cfg.DATASETS.TEST",
    )
    parser.add_argument(
        "--skip-if-fail",
        action="store_true",
        default=False,
        help="If set, suppress the exception for failed exporting and continue to"
        " export the next type of model",
    )
    return parser

125

Tsahi Glik's avatar
Tsahi Glik committed
126
127
def cli(args=None):
    """Command-line entry point.

    Parses ``args`` (``sys.argv[1:]`` when omitted) with :func:`get_parser`
    and runs the export. Returns None.
    """
    if args is None:
        args = sys.argv[1:]
    parsed_args = get_parser().parse_args(args)
    run_with_cmdline_args(parsed_args)
facebook-github-bot's avatar
facebook-github-bot committed
129

130

facebook-github-bot's avatar
facebook-github-bot committed
131
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when it is imported;
    # args=None makes cli() fall back to sys.argv[1:].
    cli(args=None)