#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import contextlib
import logging
import os
from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set

import torch
from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
from detectron2.config.instantiate import dump_dataclass, instantiate
from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.iter_utils import recursive_iterate
from mobile_cv.predictor.model_wrappers import load_model
from torch import nn
from torch._C import MobileOptimizerType
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile


logger = logging.getLogger(__name__)


class MobileOptimizationConfig(NamedTuple):
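    """Options forwarded to torch.utils.mobile_optimizer.optimize_for_mobile.

    `backend` selects the lite-interpreter target (e.g. "CPU", "Vulkan", "Metal").
    """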
    optimization_blocklist: Optional[Set[MobileOptimizerType]] = None
    preserved_methods: Optional[List[AnyStr]] = None
    backend: str = "CPU"


def trace_and_save_torchscript(
    model: nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    output_path: str,
    mobile_optimization: Optional[MobileOptimizationConfig] = None,
    _extra_files: Optional[Dict[str, bytes]] = None,
):
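    """Trace `model` with example `inputs` and save TorchScript files under
    `output_path`: `model.jit` (the traced module) and `data.pth` (the example
    inputs). If `mobile_optimization` is given, additionally save the
    lite-interpreter models `mobile_optimized.ptl` and
    `mobile_optimized_bundled.ptl` (the latter with zeroed bundled inputs).
    """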
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    PathManager.mkdirs(output_path)
    if _extra_files is None:
        _extra_files = {}

    with torch.no_grad():
        script_model = torch.jit.trace(model, inputs)

    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:

        @contextlib.contextmanager
        def _synced_local_file(rel_path):
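            # Yield a local temp path to write to, then upload it to the
            # (possibly remote) output_path via PathManager on exit.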
            remote_file = os.path.join(output_path, rel_path)
            local_file = os.path.join(tmp_dir, rel_path)
            yield local_file
            PathManager.copy_from_local(local_file, remote_file, overwrite=True)

        with _synced_local_file("model.jit") as model_file:
            torch.jit.save(script_model, model_file, _extra_files=_extra_files)

        with _synced_local_file("data.pth") as data_file:
            torch.save(inputs, data_file)

        if mobile_optimization is not None:
            logger.info("Applying optimize_for_mobile ...")
            liteopt_model = optimize_for_mobile(
                script_model,
                optimization_blocklist=mobile_optimization.optimization_blocklist,
                preserved_methods=mobile_optimization.preserved_methods,
                backend=mobile_optimization.backend,
            )
            with _synced_local_file("mobile_optimized.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
            # liteopt_model(*inputs)  # sanity check
            op_names = torch.jit.export_opnames(liteopt_model)
            logger.info(
                "Operator names from lite interpreter:\n{}".format("\n".join(op_names))
            )

            logger.info("Applying augment_model_with_bundled_inputs ...")
            # make all tensors zero-like to save storage
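            # (recursive_iterate yields leaves; `send` swaps in a replacement
            # for the current leaf, and `iters.value` is the rebuilt structure.)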
            iters = recursive_iterate(inputs)
            for x in iters:
                if isinstance(x, torch.Tensor):
                    iters.send(torch.zeros_like(x))
            inputs = iters.value
            augment_model_with_bundled_inputs(liteopt_model, [inputs])
            liteopt_model.run_on_bundled_input(0)  # sanity check
            with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
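
# Example usage (a minimal sketch; `MyModel` and the input shape are
# hypothetical placeholders):
#
#   model = MyModel().eval()
#   inputs = (torch.randn(1, 3, 224, 224),)
#   trace_and_save_torchscript(
#       model, inputs, "/tmp/exported", mobile_optimization=MobileOptimizationConfig()
#   )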


def tracing_adapter_wrap_export(old_f):
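    """Wrap an `export` classmethod so the model is traced through detectron2's
    TracingAdapter (handling non-tensor inputs/outputs) and the input/output
    schemas are recorded in the returned load kwargs."""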
    def new_f(cls, model, input_args, *args, **kwargs):
        adapter = TracingAdapter(model, input_args)
        load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
        inputs_schema = dump_dataclass(adapter.inputs_schema)
        outputs_schema = dump_dataclass(adapter.outputs_schema)
        assert "inputs_schema" not in load_kwargs
        assert "outputs_schema" not in load_kwargs
        load_kwargs.update(
            {"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
        )
        return load_kwargs

    return new_f


def tracing_adapter_wrap_load(old_f):
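    """Wrap a `load` classmethod so the schemas saved at export time are
    re-instantiated and the loaded traced model is wrapped to accept and
    return the original (unflattened) structures."""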
    def new_f(cls, save_path, **load_kwargs):
        assert "inputs_schema" in load_kwargs
        assert "outputs_schema" in load_kwargs
        inputs_schema = instantiate(load_kwargs.pop("inputs_schema"))
        outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
        traced_model = old_f(cls, save_path, **load_kwargs)

        class TracingAdapterModelWrapper(nn.Module):
            def __init__(self, traced_model, inputs_schema, outputs_schema):
                super().__init__()
                self.traced_model = traced_model
                self.inputs_schema = inputs_schema
                self.outputs_schema = outputs_schema

            def forward(self, *input_args):
                flattened_inputs, _ = flatten_to_tuple(input_args)
                flattened_outputs = self.traced_model(*flattened_inputs)
                return self.outputs_schema(flattened_outputs)

        return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)

    return new_f


@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8")
class DefaultTorchscriptExport(ModelExportMethod):
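    """Export via plain torch.jit.trace; assumes the model's inputs and
    outputs are traceable as-is."""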
    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        if export_method is not None:
            # update export_kwargs based on export_method
            assert isinstance(export_method, str)
            if "_mobile" in export_method:
                if "mobile_optimization" in export_kwargs:
                    logger.warning(
                        "`mobile_optimization` is already specified; keeping it as-is"
                    )
                else:
                    export_kwargs["mobile_optimization"] = MobileOptimizationConfig()

        trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
        return {}

    @classmethod
    def load(cls, save_path, **load_kwargs):
        return load_model(save_path, "torchscript")


@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
class D2TorchscriptTracingExport(DefaultTorchscriptExport):
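    """Like DefaultTorchscriptExport, but traces through TracingAdapter so
    detectron2-style models (e.g. taking `List[Dict[str, Tensor]]`) can be
    exported; `patch_builtin_len` works around `len()` calls during tracing."""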
    @classmethod
    @tracing_adapter_wrap_export
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        with patch_builtin_len():
            return super().export(
                model, input_args, save_path, export_method, **export_kwargs
            )

    @classmethod
    @tracing_adapter_wrap_load
    def load(cls, save_path, **load_kwargs):
        return super().load(save_path, **load_kwargs)
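
# Example usage (a minimal sketch; assumes the registry exposes an
# fvcore-style `get()`, and `model`/`batched_inputs` are hypothetical):
#
#   export_cls = ModelExportMethodRegistry.get("torchscript@tracing")
#   load_kwargs = export_cls.export(
#       model, (batched_inputs,), "/tmp/export", "torchscript@tracing"
#   )
#   wrapped_model = export_cls.load("/tmp/export", **load_kwargs)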