# mmdet3d2torchserve.py
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmdet3d2torchserve(
        config_file: str,
        checkpoint_file: str,
        output_folder: str,
        model_name: Optional[str] = None,
        model_version: str = '1.0',
        force: bool = False,
):
    """Converts MMDetection3D model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file (str):
            In MMDetection3D config format.
            The contents vary for each task repository.
        checkpoint_file (str):
            In MMDetection3D checkpoint format.
            The contents vary for each task repository.
        output_folder (str):
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
            The folder is created (via `mmcv.mkdir_or_exist`) if missing.
        model_name (str, optional):
            If truthy, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None (or empty), `{Path(checkpoint_file).stem}` will be used.
            Default: None.
        model_version (str, optional):
            Model's version. Default: '1.0'.
        force (bool, optional):
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.
            Default: False.

    Raises:
        ImportError: If `torch-model-archiver` is not installed.
    """
    # The archiver import at module level is optional. Without this guard a
    # missing package surfaces as a confusing NameError on
    # `ModelExportUtils` further down.
    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # Write the loaded config to a standalone file; it is packaged into
        # the archive as the model file.
        config_path = f'{tmpdir}/config.py'
        config.dump(config_path)

        # The archiver expects an argparse-style namespace mirroring its own
        # CLI options.
        args = Namespace(
            model_file=config_path,
            serialized_file=checkpoint_file,
            handler=f'{Path(__file__).parent}/mmdet3d_handler.py',
            model_name=model_name or Path(checkpoint_file).stem,
            version=model_version,
            export_path=output_folder,
            force=force,
            requirements_file=None,
            extra_files=None,
            runtime='python',
            archive_format='default',
        )
        manifest = ModelExportUtils.generate_manifest_json(args)
        # Packaging must happen inside the `with`: config.py lives in tmpdir,
        # which is deleted on exit.
        package_model(args, manifest)


def parse_args(argv=None):
    """Parse command-line arguments for the MMDetection3D -> TorchServe export.

    Args:
        argv (list[str], optional): Arguments to parse. If None (default),
            argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``config``,
        ``checkpoint``, ``output_folder``, ``model_name``, ``model_version``
        and ``force``.
    """
    parser = ArgumentParser(
        description='Convert MMDetection3D models to TorchServe '
        '`.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        # Adjacent literals are concatenated, so each piece needs its own
        # trailing space.
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    # Fail fast with an actionable message before touching any files.
    # Adjacent string literals are concatenated, hence the trailing space.
    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmdet3d2torchserve(args.config, args.checkpoint, args.output_folder,
                       args.model_name, args.model_version, args.force)