Unverified Commit 08a11c17 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add a jupyter notebook demo (#1158)

parent 63b9d104
...@@ -62,13 +62,13 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \ ...@@ -62,13 +62,13 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
We provide a webcam demo to illustrate the results. We provide a webcam demo to illustrate the results.
```shell ```shell
python tools/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}] python demo/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}]
``` ```
Examples: Examples:
```shell ```shell
python tools/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth
``` ```
...@@ -103,6 +103,8 @@ for frame in video: ...@@ -103,6 +103,8 @@ for frame in video:
show_result(frame, result, model.CLASSES, wait_time=1) show_result(frame, result, model.CLASSES, wait_time=1)
``` ```
A notebook demo can be found in [demo/inference_demo.ipynb](demo/inference_demo.ipynb).
## Train a model ## Train a model
......
This diff is collapsed.
from .env import get_root_logger, init_dist, set_random_seed from .env import get_root_logger, init_dist, set_random_seed
from .inference import inference_detector, init_detector, show_result from .inference import (inference_detector, init_detector, show_result,
show_result_pyplot)
from .train import train_detector from .train import train_detector
__all__ = [ __all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector', 'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'init_detector', 'inference_detector', 'show_result' 'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
] ]
import warnings import warnings
import matplotlib.pyplot as plt
import mmcv import mmcv
import numpy as np import numpy as np
import pycocotools.mask as maskUtils import pycocotools.mask as maskUtils
...@@ -105,6 +106,7 @@ def show_result(img, ...@@ -105,6 +106,7 @@ def show_result(img,
class_names, class_names,
score_thr=0.3, score_thr=0.3,
wait_time=0, wait_time=0,
show=True,
out_file=None): out_file=None):
"""Visualize the detection results on the image. """Visualize the detection results on the image.
...@@ -115,11 +117,17 @@ def show_result(img, ...@@ -115,11 +117,17 @@ def show_result(img,
class_names (list[str] or tuple[str]): A list of class names. class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks. score_thr (float): The threshold to visualize the bboxes and masks.
wait_time (int): Value of waitKey param. wait_time (int): Value of waitKey param.
show (bool, optional): Whether to show the image with opencv or not.
out_file (str, optional): If specified, the visualization result will out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window. be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
""" """
assert isinstance(class_names, (tuple, list)) assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img) img = mmcv.imread(img)
img = img.copy()
if isinstance(result, tuple): if isinstance(result, tuple):
bbox_result, segm_result = result bbox_result, segm_result = result
else: else:
...@@ -140,11 +148,36 @@ def show_result(img, ...@@ -140,11 +148,36 @@ def show_result(img,
] ]
labels = np.concatenate(labels) labels = np.concatenate(labels)
mmcv.imshow_det_bboxes( mmcv.imshow_det_bboxes(
img.copy(), img,
bboxes, bboxes,
labels, labels,
class_names=class_names, class_names=class_names,
score_thr=score_thr, score_thr=score_thr,
show=out_file is None, show=show,
wait_time=wait_time, wait_time=wait_time,
out_file=out_file) out_file=out_file)
if not (show or out_file):
return img
def show_result_pyplot(img,
result,
class_names,
score_thr=0.3,
fig_size=(15, 10)):
"""Visualize the detection results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
"""
img = show_result(
img, result, class_names, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment