#!/usr/bin/env python3

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
To run inference on the framework runtime (without Triton), use the `run_inference_on_fw.py` script.
It runs the model locally on batches obtained from the given dataloader and saves the results as npz files.
Those files are stored in the directory given by the `--output-dir` argument.

Example call:

```shell script
python ./triton/run_inference_on_fw.py \
    --input-path /models/exported/model.onnx \
    --input-type onnx \
    --dataloader triton/dataloader.py \
    --data-dir /data/imagenet \
    --batch-size 32 \
    --output-dir /results/dump_local \
    --dump-labels
```
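
The file passed via `--dataloader` must expose a dataloader factory (named by
`DATALOADER_FN_NAME`, assumed here to be `get_dataloader_fn`). Its parameters
(e.g. `data_dir`, `batch_size`) are turned into CLI arguments automatically, and
the callable it returns must yield `(ids, x, y_real)` batches, where `x` and
`y_real` map tensor names to numpy arrays. A minimal sketch, with illustrative
tensor names and shapes:

```python
import numpy as np


def get_dataloader_fn(data_dir: str, batch_size: int = 32):
    def _dataloader():
        # Replace this synthetic loop with iteration over the real dataset.
        for batch_idx in range(4):
            ids = np.arange(batch_idx * batch_size, (batch_idx + 1) * batch_size)
            x = {"input__0": np.random.rand(batch_size, 3, 224, 224).astype(np.float32)}
            y_real = {"output__0": np.random.randint(0, 1000, size=(batch_size,))}
            yield ids, x, y_real

    return _dataloader
```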
"""

import argparse
import logging
import os
from pathlib import Path

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_DEPRECATION_WARNINGS"] = "0"

from tqdm import tqdm

# method from PEP-366 to support relative import in executed modules
if __package__ is None:
    __package__ = Path(__file__).parent.name

from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import DATALOADER_FN_NAME, BaseLoader, BaseRunner, Format, load_from_file
from .deployment_toolkit.dump import NpzWriter
from .deployment_toolkit.extensions import loaders, runners

LOGGER = logging.getLogger("run_inference_on_fw")


def _verify_and_format_dump(args, ids, x, y_pred, y_real):
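    """Assemble the npz dump payload for one batch, optionally including inputs and labels."""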
    data = {"outputs": y_pred, "ids": {"ids": ids}}
    if args.dump_inputs:
        data["inputs"] = x
    if args.dump_labels:
        if not y_real:
            raise ValueError(
                "Found empty label values. Please provide labels in dataloader_fn or do not use --dump-labels argument"
            )
        data["labels"] = y_real
    return data


def _parse_and_validate_args():
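    """Build the CLI parser, let dataloader/loader/runner plugins extend it, and validate the result."""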
    supported_inputs = set(runners.supported_extensions) & set(loaders.supported_extensions)

    parser = argparse.ArgumentParser(description="Dump local inference output of a given model", allow_abbrev=False)
    parser.add_argument("--input-path", help="Path to input model", required=True)
    parser.add_argument("--input-type", help="Input model type", choices=supported_inputs, required=True)
    parser.add_argument("--dataloader", help="Path to python file containing dataloader.", required=True)
    parser.add_argument("--output-dir", help="Path to dir where output files will be stored", required=True)
    parser.add_argument("--dump-labels", help="Dump labels to output dir", action="store_true", default=False)
    parser.add_argument("--dump-inputs", help="Dump inputs to output dir", action="store_true", default=False)
    parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)

    args, *_ = parser.parse_known_args()

    get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
    ArgParserGenerator(get_dataloader_fn).update_argparser(parser)

    Loader: BaseLoader = loaders.get(args.input_type)
    ArgParserGenerator(Loader, module_path=args.input_path).update_argparser(parser)

    Runner: BaseRunner = runners.get(args.input_type)
    ArgParserGenerator(Runner).update_argparser(parser)

    args = parser.parse_args()

    types_requiring_io_params = []  # no formats currently require explicit --inputs/--outputs

    if args.input_type in types_requiring_io_params and not all([args.inputs, args.outputs]):
        parser.error(f"For {args.input_type} input provide --inputs and --outputs parameters")

    return args


def main():
    args = _parse_and_validate_args()

    log_level = logging.INFO if not args.verbose else logging.DEBUG
    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    logging.basicConfig(level=log_level, format=log_format)

    LOGGER.info(f"args:")
    for key, value in vars(args).items():
        LOGGER.info(f"    {key} = {value}")

    Loader: BaseLoader = loaders.get(args.input_type)
    Runner: BaseRunner = runners.get(args.input_type)

    loader = ArgParserGenerator(Loader, module_path=args.input_path).from_args(args)
    runner = ArgParserGenerator(Runner).from_args(args)
    LOGGER.info(f"Loading {args.input_path}")
    model = loader.load(args.input_path)
    with runner.init_inference(model=model) as runner_session, NpzWriter(args.output_dir) as writer:
        get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
        dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)
        LOGGER.info(f"Data loader initialized; Running inference")
        for ids, x, y_real in tqdm(dataloader_fn(), unit="batch", mininterval=10):
            y_pred = runner_session(x)
            data = _verify_and_format_dump(args, ids=ids, x=x, y_pred=y_pred, y_real=y_real)
            writer.write(**data)
        LOGGER.info(f"Inference finished")


if __name__ == "__main__":
    main()