onnx2vacc_quant_dataset.py 4.23 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import ast
import logging
import os.path as osp
from copy import deepcopy
from typing import Optional, Sequence

import h5py
import tqdm
from mmengine import Config

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_root_logger, load_config


def get_tensor_func(model, input_data):
    """Preprocess one batch with the model's data preprocessor.

    Args:
        model: An object exposing a callable ``data_preprocessor`` attribute.
        input_data: The batch to preprocess (here, one dataloader item).

    Returns:
        The value stored under the ``'inputs'`` key of the preprocessor's
        output.
    """
    batch = model.data_preprocessor(input_data)
    return batch['inputs']


def process_model_config(model_cfg: Config,
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config for static calibration-data export.

    Every ``Resize`` transform in ``test_pipeline`` is switched to a fixed,
    non-ratio-preserving resize to ``input_shape``. ``YOLOv5KeepRatioResize``
    and ``LetterResize`` transforms also get their ``scale`` overridden.
    ``LoadAnnotations`` transforms are dropped from the pipeline.

    Args:
        model_cfg (Config): The model config. It is not modified; a deep copy
            is processed and returned.
        input_shape (list[int], optional): A list of two integers in
            (width, height) format specifying the input shape. If None, the
            resize transforms are left untouched. Default: None.

    Returns:
        Config: the model config after processing.
    """
    # Deep copy: mmengine's ``Config.copy()`` is shallow, so mutating the
    # pipeline entries in place would leak into the caller's config.
    cfg = deepcopy(model_cfg)

    pipeline = cfg.test_pipeline

    if input_shape is not None:
        scale = tuple(input_shape)
        for transform in pipeline:
            # for static exporting
            if transform.type == 'Resize':
                transform.keep_ratio = False
                transform.scale = scale
            if transform.type in ('YOLOv5KeepRatioResize', 'LetterResize'):
                transform.scale = scale

    cfg.test_pipeline = [
        transform for transform in pipeline
        if transform.type != 'LoadAnnotations'
    ]
    return cfg


def get_quant(deploy_cfg: Config,
              model_cfg: Config,
              shape_dict: dict,
              checkpoint_path: str,
              work_dir: str,
              device: str = 'cpu',
              dataset_type: str = 'val'):
    """Dump preprocessed model inputs into an HDF5 calibration dataset.

    Each sample of ``model_cfg['<dataset_type>_dataloader']`` (batch size
    forced to 1) is passed through the model's data preprocessor. The result
    is stored as a gzip-compressed (level 4) dataset named by its index under
    the group ``calib_data/input`` of ``<work_dir>/calib_data.h5``.

    Args:
        deploy_cfg (Config): The deploy config.
        model_cfg (Config): The model config.
        shape_dict (dict): Maps input names to model input shapes. Only the
            first entry is used; its indices 2 and 3 are read as height and
            width (i.e. an (N, C, H, W) layout is assumed).
        checkpoint_path (str): The PyTorch checkpoint to load.
        work_dir (str): Directory in which ``calib_data.h5`` is written.
        device (str): Device used to build the model. Default: 'cpu'.
        dataset_type (str): Prefix of the dataloader config to sample from.
            Default: 'val'.

    Returns:
        str: The path of the written ``calib_data.h5`` file.

    Raises:
        ValueError: If ``shape_dict`` is empty, or if a preprocessed sample
            is larger than the model input in height or width.
    """
    if not shape_dict:
        raise ValueError('shape_dict must contain at least one input shape')
    model_shape = next(iter(shape_dict.values()))
    model_h, model_w = model_shape[2], model_shape[3]
    # process_model_config expects (width, height)
    model_cfg = process_model_config(model_cfg, (model_w, model_h))

    task_processor = build_task_processor(model_cfg, deploy_cfg, device)
    model = task_processor.build_pytorch_model(checkpoint_path)
    calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
    calib_dataloader['batch_size'] = 1

    dataloader = task_processor.build_dataloader(calib_dataloader)
    output_quant_dataset_path = osp.join(work_dir, 'calib_data.h5')

    with h5py.File(output_quant_dataset_path, mode='w') as h5_file:
        calib_data_group = h5_file.create_group('calib_data')
        input_data_group = calib_data_group.create_group('input')

        for data_id, input_data in enumerate(tqdm.tqdm(dataloader)):
            # Bring the tensor to host memory first: ``.numpy()`` raises on
            # tensors that live on a non-CPU device.
            input_data = get_tensor_func(model,
                                         input_data).detach().cpu().numpy()
            calib_data_shape = input_data.shape
            if (calib_data_shape[2] > model_h
                    or calib_data_shape[3] > model_w):
                raise ValueError(
                    f'vacc backend model shape is {tuple(model_shape[2:])}, '
                    f'the calib_data shape {calib_data_shape[2:]} is bigger')

            input_data_group.create_dataset(
                str(data_id),
                shape=input_data.shape,
                compression='gzip',
                compression_opts=4,
                data=input_data)

    return output_quant_dataset_path


def parse_args():
    """Parse the command-line arguments of the quant-dataset generator.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='Generate vacc quant dataset from ONNX.')
    parser.add_argument(
        '--deploy-cfg', required=True, help='Input deploy config path')
    parser.add_argument(
        '--model-cfg', required=True, help='Input model config path')
    parser.add_argument(
        '--shape-dict',
        required=True,
        help='Input model shape as a Python dict literal, e.g. '
        '"{\'input\': [1, 3, 640, 640]}"')
    parser.add_argument('--checkpoint-path', help='checkpoint path')
    parser.add_argument(
        '--work-dir', required=True, help='Output quant dataset dir')
    parser.add_argument(
        '--device',
        default='cpu',
        help='Device used to build the model and preprocess the data')
    parser.add_argument(
        '--dataset-type',
        default='val',
        help='Which "<type>_dataloader" of the model config to sample from')

    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    return parser.parse_args()


def main():
    """Entry point: parse the CLI and write calib_data.h5 to --work-dir."""
    args = parse_args()
    logger = get_root_logger(log_level=args.log_level)

    deploy_cfg, model_cfg = load_config(args.deploy_cfg, args.model_cfg)
    # ``literal_eval`` only accepts Python literals; ``eval`` would execute
    # arbitrary code passed on the command line.
    shape_dict = ast.literal_eval(args.shape_dict)
    if not isinstance(shape_dict, dict):
        raise ValueError(
            f'--shape-dict must be a dict literal, got {args.shape_dict!r}')

    get_quant(
        deploy_cfg,
        model_cfg,
        shape_dict,
        args.checkpoint_path,
        args.work_dir,
        device=args.device,
        dataset_type=args.dataset_type)
    logger.info('onnx2vacc_quant_dataset success.')


if __name__ == '__main__':
    main()