Unverified commit 94c2d862 authored by ChaimZhu, committed by GitHub

[Feature] Add mmdet3d2torchserve tool (#977)



* add torchserve support

* fully support torchserve

* Delete the empty file

* fix typos

* add docstrings and doc

* add config.properties

* fix typos and docstrings

* change pipeline name
Co-authored-by: Tai-Wang <tab_wang@outlook.com>
parent cbc2491f
ARG PYTORCH="1.6.0"
ARG CUDA="10.1"
ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.3.8"
ARG MMSEGMENTATION="0.14.1"
ARG MMDET="2.14.0"
ARG MMDET3D="0.17.1"
ENV PYTHONUNBUFFERED TRUE
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
ca-certificates \
g++ \
openjdk-11-jre-headless \
# MMDet3D Requirements
ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/opt/conda/bin:$PATH"
ENV FORCE_CUDA="1"
# TORCHSERVE
RUN pip install torchserve torch-model-archiver
# MMLAB
ARG PYTORCH
ARG CUDA
RUN ["/bin/bash", "-c", "pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu${CUDA//./}/torch${PYTORCH}/index.html"]
RUN pip install mmdet==${MMDET}
RUN pip install mmsegmentation==${MMSEGMENTATION}
RUN pip install mmdet3d==${MMDET3D}
RUN useradd -m model-server \
&& mkdir -p /home/model-server/tmp
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
RUN chmod +x /usr/local/bin/entrypoint.sh \
&& chown -R model-server /home/model-server
COPY config.properties /home/model-server/config.properties
RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store
EXPOSE 8080 8081 8082
USER model-server
WORKDIR /home/model-server
ENV TEMP=/home/model-server/tmp
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["serve"]
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
model_store=/home/model-server/model-store
load_models=all
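The three addresses above expose TorchServe's inference, management, and metrics REST endpoints. As a quick sanity check once the server is up, you can list the registered models through the management API; a minimal sketch, assuming the container ports are published on localhost and the `requests` package is installed:

```python
import requests

# Query TorchServe's management API (port 8081 in config.properties above).
resp = requests.get('http://127.0.0.1:8081/models')
resp.raise_for_status()
print(resp.json())  # e.g. {'models': [{'modelName': 'second', ...}]}
```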
#!/bin/bash
set -e
if [[ "$1" = "serve" ]]; then
shift 1
torchserve --start --ts-config /home/model-server/config.properties
else
eval "$@"
fi
# prevent docker exit
tail -f /dev/null
@@ -122,6 +122,64 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task
&emsp;
# Model Serving
**Note**: This tool is still experimental; for now, only SECOND can be served with [`TorchServe`](https://pytorch.org/serve/). We will support more models in the future.
In order to serve an `MMDetection3D` model with [`TorchServe`](https://pytorch.org/serve/), you can follow these steps:
## 1. Convert the model from MMDetection3D to TorchServe
```shell
python tools/deployment/mmdet3d2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
--output-folder ${MODEL_STORE} \
--model-name ${MODEL_NAME}
```
**Note**: `${MODEL_STORE}` needs to be an absolute path to a folder.
## 2. Build `mmdet3d-serve` docker image
```shell
docker build -t mmdet3d-serve:latest docker/serve/
```
## 3. Run `mmdet3d-serve`
Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).
In order to run it on the GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument in order to run on the CPU.
Example:
```shell
docker run --rm \
--cpus 8 \
--gpus device=0 \
-p8080:8080 -p8081:8081 -p8082:8082 \
--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
mmdet3d-serve:latest
```
[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md/) about the Inference (8080), Management (8081) and Metrics (8082) APIs.
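For a quick smoke test against the inference API, you can POST a raw point cloud binary directly. A minimal sketch in Python, mirroring what `test_torchserver.py` does in the next step; it assumes a model registered as `second` and a KITTI-format `.bin` file (the file path is illustrative):

```python
import requests

# Inference API (port 8080); the model name matches ${MODEL_NAME} above.
url = 'http://127.0.0.1:8080/predictions/second'
# KITTI point clouds are raw float32 buffers with 4 values per point.
with open('demo/data/kitti/kitti_000008.bin', 'rb') as f:
    response = requests.post(url, f)
# The handler replies with a list of {'3dbbox': ..., 'score': ...} dicts.
print(response.json())
```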
## 4. Test deployment
You can use `test_torchserver.py` to test the deployment and compare the results of TorchServe and PyTorch.
```shell
python tools/deployment/test_torchserver.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${MODEL_NAME}
[--inference-addr ${INFERENCE_ADDR}] [--device ${DEVICE}] [--score-thr ${SCORE_THR}]
```
Example:
```shell
python tools/deployment/test_torchserver.py demo/data/kitti/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth second
```
&emsp;
# Model Complexity
You can use `tools/analysis_tools/get_flops.py` in MMDetection3D, a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), to compute the FLOPs and params of a given model.
......
@@ -123,6 +123,64 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task
&emsp;
# Model Deployment
**Note**: This tool is still experimental; for now, only SECOND can be served with [`TorchServe`](https://pytorch.org/serve/). We will support more models in the future.
In order to serve an `MMDetection3D` model with [`TorchServe`](https://pytorch.org/serve/), you can follow these steps:
## 1. Convert the model from MMDetection3D to TorchServe
```shell
python tools/deployment/mmdet3d2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
--output-folder ${MODEL_STORE} \
--model-name ${MODEL_NAME}
```
**Note**: `${MODEL_STORE}` needs to be an absolute path to a folder.
## 2. Build the `mmdet3d-serve` docker image
```shell
docker build -t mmdet3d-serve:latest docker/serve/
```
## 3. Run `mmdet3d-serve`
Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).
In order to run on the GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument to run on the CPU.
Example:
```shell
docker run --rm \
--cpus 8 \
--gpus device=0 \
-p8080:8080 -p8081:8081 -p8082:8082 \
--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
mmdet3d-serve:latest
```
[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md/) about the Inference (8080), Management (8081) and Metrics (8082) APIs.
## 4. Test deployment
You can use `test_torchserver.py` to test the deployment and compare the results of TorchServe and PyTorch.
```shell
python tools/deployment/test_torchserver.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${MODEL_NAME}
[--inference-addr ${INFERENCE_ADDR}] [--device ${DEVICE}] [--score-thr ${SCORE_THR}]
```
Example:
```shell
python tools/deployment/test_torchserver.py demo/data/kitti/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth second
```
&emsp;
# Model Complexity
You can use `tools/analysis_tools/get_flops.py` in MMDetection3D, a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), to compute the FLOPs and params of a given model.
......
@@ -82,10 +82,19 @@ def inference_detector(model, pcd):
""" """
cfg = model.cfg cfg = model.cfg
device = next(model.parameters()).device # model device device = next(model.parameters()).device # model device
if not isinstance(pcd, str):
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadPointsFromDict'
# build the data pipeline # build the data pipeline
test_pipeline = deepcopy(cfg.data.test.pipeline) test_pipeline = deepcopy(cfg.data.test.pipeline)
test_pipeline = Compose(test_pipeline) test_pipeline = Compose(test_pipeline)
box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d) box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
if isinstance(pcd, str):
# load from point clouds file
data = dict( data = dict(
pts_filename=pcd, pts_filename=pcd,
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
...@@ -102,6 +111,24 @@ def inference_detector(model, pcd): ...@@ -102,6 +111,24 @@ def inference_detector(model, pcd):
bbox_fields=[], bbox_fields=[],
mask_fields=[], mask_fields=[],
seg_fields=[]) seg_fields=[])
else:
# load from http
data = dict(
points=pcd,
box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d,
# for ScanNet demo we need axis_align_matrix
ann_info=dict(axis_align_matrix=np.eye(4)),
sweeps=[],
# set timestamp = 0
timestamp=[0],
img_fields=[],
bbox3d_fields=[],
pts_mask_fields=[],
pts_seg_fields=[],
bbox_fields=[],
mask_fields=[],
seg_fields=[])
data = test_pipeline(data) data = test_pipeline(data)
data = collate([data], samples_per_gpu=1) data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda: if next(model.parameters()).is_cuda:
......
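With the new `LoadPointsFromDict` branch, `inference_detector` also accepts an in-memory points object instead of a file path. A minimal sketch of that code path, assuming the SECOND config from the docs above and a locally available checkpoint (the checkpoint path is a placeholder):

```python
import numpy as np

from mmdet3d.apis import inference_detector, init_model
from mmdet3d.core.points import get_points_type

model = init_model('configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py',
                   'checkpoints/second.pth',  # placeholder checkpoint path
                   device='cuda:0')
# Wrap a raw (N, 4) float32 array as LiDARPoints and pass it directly,
# skipping LoadPointsFromFile.
points = np.fromfile('demo/data/kitti/kitti_000008.bin',
                     dtype=np.float32).reshape(-1, 4)
points = get_points_type('LIDAR')(points, points_dim=4)
result, data = inference_detector(model, points)
```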
@@ -12,11 +12,12 @@ from .nuscenes_mono_dataset import NuScenesMonoDataset
from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
                        GlobalRotScaleTrans, IndoorPatchPointSample,
                        IndoorPointSample, LoadAnnotations3D,
                        LoadPointsFromDict, LoadPointsFromFile,
                        LoadPointsFromMultiSweeps, NormalizePointsColor,
                        ObjectNameFilter, ObjectNoise, ObjectRangeFilter,
                        ObjectSample, PointSample, PointShuffle,
                        PointsRangeFilter, RandomDropPointsColor, RandomFlip3D,
                        RandomJitterPoints, RandomShiftScale,
                        VoxelBasedPointSampler)
# yapf: enable
from .s3dis_dataset import S3DISDataset, S3DISSegDataset
@@ -38,5 +39,6 @@ __all__ = [
    'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
    'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
    'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
    'ObjectNameFilter', 'AffineResize', 'RandomShiftScale',
    'LoadPointsFromDict'
]
@@ -3,9 +3,9 @@ from mmdet.datasets.pipelines import Compose
from .dbsampler import DataBaseSampler
from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
                      LoadMultiViewImageFromFiles, LoadPointsFromDict,
                      LoadPointsFromFile, LoadPointsFromMultiSweeps,
                      NormalizePointsColor, PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
# yapf: disable
from .transforms_3d import (AffineResize, BackgroundPointsFilter,
@@ -27,5 +27,6 @@ __all__ = [
    'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter',
    'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
    'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
    'RandomJitterPoints', 'AffineResize', 'RandomShiftScale',
    'LoadPointsFromDict'
]
@@ -338,7 +338,7 @@ class NormalizePointsColor(object):
class LoadPointsFromFile(object):
    """Load Points From File.

    Load points from file.

    Args:
        coord_type (str): The type of coordinates of points cloud.
@@ -460,6 +460,15 @@ class LoadPointsFromFile(object):
        return repr_str


@PIPELINES.register_module()
class LoadPointsFromDict(LoadPointsFromFile):
    """Load Points From Dict."""

    def __call__(self, results):
        assert 'points' in results
        return results


@PIPELINES.register_module()
class LoadAnnotations3D(LoadAnnotations):
    """Load Annotations3D.
......
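`LoadPointsFromDict` only asserts that the points are already present in the results dict, so it can stand in for `LoadPointsFromFile` as the first pipeline stage when the points arrive in memory, as `inference_detector` does above. A minimal config sketch; the constructor arguments are inherited from `LoadPointsFromFile` and the values shown are illustrative KITTI settings:

```python
# First stage swapped from LoadPointsFromFile to LoadPointsFromDict;
# the remaining test-time stages stay unchanged.
pipeline_cfg = dict(
    type='LoadPointsFromDict',
    coord_type='LIDAR',
    load_dim=4,
    use_dim=4)
```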
@@ -8,7 +8,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet,mmseg,mmdet3d
known_third_party = cv2,imageio,indoor3d_util,load_scannet_data,lyft_dataset_sdk,m2r,matplotlib,mmcv,nuimages,numba,numpy,nuscenes,pandas,plyfile,pycocotools,pyquaternion,pytest,pytorch_sphinx_theme,recommonmark,requests,scannet_utils,scipy,seaborn,shapely,skimage,tensorflow,terminaltables,torch,trimesh,ts,waymo_open_dataset
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
......
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None
def mmdet3d2torchserve(
config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
force: bool = False,
):
"""Converts MMDetection3D model (config + checkpoint) to TorchServe `.mar`.
Args:
config_file (str):
In MMDetection3D config format.
The contents vary for each task repository.
checkpoint_file (str):
In MMDetection3D checkpoint format.
The contents vary for each task repository.
output_folder (str):
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name (str):
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version (str, optional):
Model's version. Default: '1.0'.
force (bool, optional):
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
Default: False.
"""
mmcv.mkdir_or_exist(output_folder)
config = mmcv.Config.fromfile(config_file)
with TemporaryDirectory() as tmpdir:
config.dump(f'{tmpdir}/config.py')
args = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'serialized_file': checkpoint_file,
'handler': f'{Path(__file__).parent}/mmdet3d_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)
def parse_args():
parser = ArgumentParser(
        description='Convert MMDetection3D models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')
mmdet3d2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.force)
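A minimal sketch of calling the conversion function programmatically rather than via the CLI; it assumes you run from the repository root with `tools/deployment` added to `sys.path` (the folder is not an installed package), and the checkpoint path is a placeholder:

```python
import sys

sys.path.append('tools/deployment')  # plain folder, not an installed package
from mmdet3d2torchserve import mmdet3d2torchserve

# Writes /abs/path/to/model-store/second.mar, ready to be served.
mmdet3d2torchserve(
    config_file='configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py',
    checkpoint_file='checkpoints/second.pth',  # placeholder path
    output_folder='/abs/path/to/model-store',  # must be absolute
    model_name='second',
    force=True)
```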
# Copyright (c) OpenMMLab. All rights reserved.
import base64
import numpy as np
import os
import torch
from ts.torch_handler.base_handler import BaseHandler
from mmdet3d.apis import inference_detector, init_model
from mmdet3d.core.points import get_points_type
class MMdet3dHandler(BaseHandler):
"""MMDetection3D Handler used in TorchServe.
Handler to load models in MMDetection3D, and it will process data to get
predicted results. For now, it only supports SECOND.
"""
threshold = 0.5
load_dim = 4
use_dim = [0, 1, 2, 3]
coord_type = 'LIDAR'
attribute_dims = None
def initialize(self, context):
"""Initialize function loads the model in MMDetection3D.
Args:
            context (context): JSON object containing information
                pertaining to the model artifact parameters.
"""
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
str(properties.get('gpu_id')) if torch.cuda.
is_available() else self.map_location)
self.manifest = context.manifest
model_dir = properties.get('model_dir')
serialized_file = self.manifest['model']['serializedFile']
checkpoint = os.path.join(model_dir, serialized_file)
self.config_file = os.path.join(model_dir, 'config.py')
self.model = init_model(self.config_file, checkpoint, self.device)
self.initialized = True
def preprocess(self, data):
"""Preprocess function converts data into LiDARPoints class.
Args:
data (List): Input data from the request.
Returns:
`LiDARPoints` : The preprocess function returns the input
point cloud data as LiDARPoints class.
"""
for row in data:
# Compat layer: normally the envelope should just return the data
# directly, but older versions of Torchserve didn't have envelope.
pts = row.get('data') or row.get('body')
if isinstance(pts, str):
pts = base64.b64decode(pts)
points = np.frombuffer(pts, dtype=np.float32)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
points_class = get_points_type(self.coord_type)
points = points_class(
points,
points_dim=points.shape[-1],
attribute_dims=self.attribute_dims)
return points
def inference(self, data):
"""Inference Function.
This function is used to make a prediction call on the
given input request.
Args:
data (`LiDARPoints`): LiDARPoints class passed to make
the inference request.
Returns:
List(dict) : The predicted result is returned in this function.
"""
results, _ = inference_detector(self.model, data)
return results
def postprocess(self, data):
"""Postprocess function.
This function makes use of the output from the inference and
converts it into a torchserve supported response output.
Args:
data (List[dict]): The data received from the prediction
output of the model.
Returns:
List: The post process function returns a list of the predicted
output.
"""
output = []
for pts_index, result in enumerate(data):
output.append([])
if 'pts_bbox' in result.keys():
pred_bboxes = result['pts_bbox']['boxes_3d'].tensor.numpy()
pred_scores = result['pts_bbox']['scores_3d'].numpy()
else:
pred_bboxes = result['boxes_3d'].tensor.numpy()
pred_scores = result['scores_3d'].numpy()
index = pred_scores > self.threshold
bbox_coords = pred_bboxes[index].tolist()
score = pred_scores[index].tolist()
output[pts_index].append({'3dbbox': bbox_coords, 'score': score})
return output
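End to end, the handler consumes raw float32 point bytes and returns one entry per request item. A minimal sketch of the response shape a client can expect; the box values are illustrative, not real predictions:

```python
# Illustrative body returned by POST /predictions/{model_name}: one list per
# input item, each holding boxes above the score threshold and their scores.
sample_response = [[{
    # each box: x, y, z, dx, dy, dz, yaw (LiDAR box convention)
    '3dbbox': [[6.2, -3.9, -1.7, 3.9, 1.6, 1.5, -1.6]],
    'score': [0.91],
}]]
assert sample_response[0][0]['score'][0] > 0.5  # handler's default threshold
```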
import numpy as np
import requests
from argparse import ArgumentParser
from mmdet3d.apis import inference_detector, init_model
def parse_args():
parser = ArgumentParser()
parser.add_argument('pcd', help='Point cloud file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.5, help='3d bbox score threshold')
args = parser.parse_args()
return args
def parse_result(result):
    bbox = result[0]['3dbbox']
    return np.array(bbox)
def main(args):
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# test a single point cloud file
model_result, _ = inference_detector(model, args.pcd)
    # filter out the 3d bboxes whose scores are below the threshold
if 'pts_bbox' in model_result[0].keys():
pred_bboxes = model_result[0]['pts_bbox']['boxes_3d'].tensor.numpy()
pred_scores = model_result[0]['pts_bbox']['scores_3d'].numpy()
else:
pred_bboxes = model_result[0]['boxes_3d'].tensor.numpy()
pred_scores = model_result[0]['scores_3d'].numpy()
    model_result = pred_bboxes[pred_scores > args.score_thr]
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
with open(args.pcd, 'rb') as points:
response = requests.post(url, points)
server_result = parse_result(response.json())
assert np.allclose(model_result, server_result)
if __name__ == '__main__':
args = parse_args()
main(args)