Commit 9729ca54 authored by Rinat Shigapov, committed by Kai Chen

Async inference interface (#1647)

* async inference support

* implemented concurrent decorator

* fixes for Python versions < 3.7

* async methods depend on python version

* revert changes in forward method

* async_test -> async_simple_test, debug logging is done via logger.debug

* add async test

* add asynctest to test requirements

* async tests are run in Python 3.7

* check CUDA, add docs

* fix device

* run test only if CUDA is available

* fix linting

* custom operators can run on a non-default stream

* set current stream in kernel launch configuration

* example fixes

* add async/sync interface comparison benchmark

* fix linting
parent cd0d37cc
@@ -74,6 +74,7 @@ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
### High-level APIs for testing images
#### Synchronous interface
Here is an example of building the model and testing given images.
```python
@@ -103,6 +104,48 @@ for frame in video:
A notebook demo can be found in [demo/inference_demo.ipynb](../demo/inference_demo.ipynb).
#### Asynchronous interface - supported for Python 3.7+
The async interface allows the CPU not to block on GPU-bound inference code and enables better CPU/GPU utilization for single-threaded applications. Inference can be done concurrently, either between different input data samples or between different models of an inference pipeline.
See `tests/async_benchmark.py` to compare the speed of synchronous and asynchronous interfaces.
```python
import asyncio
import torch
from mmdet.apis import init_detector, async_inference_detector, show_result
from mmdet.utils.contextmanagers import concurrent


async def main():
    config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
    checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
    device = 'cuda:0'
    model = init_detector(config_file, checkpoint=checkpoint_file, device=device)

    # queue is used for concurrent inference of multiple images
    streamqueue = asyncio.Queue()
    # queue size defines concurrency level
    streamqueue_size = 3

    for _ in range(streamqueue_size):
        streamqueue.put_nowait(torch.cuda.Stream(device=device))

    # test a single image and show the results
    img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once

    async with concurrent(streamqueue):
        result = await async_inference_detector(model, img)

    # visualize the results in a new window
    show_result(img, result, model.CLASSES)
    # or save the visualization results to image files
    show_result(img, result, model.CLASSES, out_file='result.jpg')


asyncio.run(main())
```
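
The same stream-queue pattern extends to several images at once. Below is a minimal sketch, assuming `model` and `streamqueue` from the example above are in scope and that `demo1.jpg`/`demo2.jpg` are placeholder paths; `tests/async_benchmark.py` uses the same pattern:

```python
async def detect(img):
    # borrow a CUDA stream from the pool for the duration of one inference
    async with concurrent(streamqueue):
        return await async_inference_detector(model, img)

async def detect_many(imgs):
    # schedule all images and gather their results concurrently
    return await asyncio.gather(*(detect(img) for img in imgs))

results = asyncio.run(detect_many(['demo1.jpg', 'demo2.jpg']))
```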
## Train a model
......
from .env import get_root_logger, init_dist, set_random_seed
from .inference import (inference_detector, init_detector, show_result,
                        show_result_pyplot)
from .inference import (async_inference_detector, inference_detector,
                        init_detector, show_result, show_result_pyplot)
from .train import train_detector

__all__ = [
    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
    'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
    'async_inference_detector', 'init_dist', 'get_root_logger',
    'set_random_seed', 'train_detector', 'init_detector', 'inference_detector',
    'show_result', 'show_result_pyplot'
]
@@ -84,7 +84,34 @@ def inference_detector(model, img):
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result


async def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result
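

# Usage sketch (illustrative, not part of this module): the coroutine is
# awaited from an event loop, optionally inside `concurrent(streamqueue)` to
# run on a non-default CUDA stream:
#
#     result = asyncio.run(async_inference_detector(model, 'test.jpg'))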
......
@@ -65,7 +65,6 @@ class RPNHead(AnchorHead):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
@@ -74,6 +73,7 @@ class RPNHead(AnchorHead):
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
......
@@ -61,6 +61,10 @@ class BaseDetector(nn.Module):
        """
        pass

    @abstractmethod
    async def async_simple_test(self, img, img_meta, **kwargs):
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        pass
@@ -74,6 +78,26 @@ class BaseDetector(nn.Module):
            logger = logging.getLogger()
            logger.info('load model from: {}'.format(pretrained))

    async def aforward_test(self, *, img, img_meta, **kwargs):
        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(img)
        if num_augs != len(img_meta):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    len(img), len(img_meta)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        imgs_per_gpu = img[0].size(0)
        assert imgs_per_gpu == 1

        if num_augs == 1:
            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
        else:
            raise NotImplementedError
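
    # `img` and `img_meta` arrive as lists with one entry per test-time
    # augmentation (built by collate/scatter in async_inference_detector);
    # only the single-augmentation, one-image-per-GPU case is handled above.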
    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
......
import logging
import sys

import torch

from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
                        merge_aug_masks, merge_aug_proposals, multiclass_nms)

logger = logging.getLogger(__name__)

if sys.version_info >= (3, 7):
    from mmdet.utils.contextmanagers import completed


class RPNTestMixin(object):

    if sys.version_info >= (3, 7):

        async def async_test_rpn(self, x, img_meta, rpn_test_cfg):
            sleep_interval = rpn_test_cfg.pop("async_sleep_interval", 0.025)
            async with completed(
                    __name__, "rpn_head_forward",
                    sleep_interval=sleep_interval):
                rpn_outs = self.rpn_head(x)

            proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
            return proposal_list

    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
@@ -37,6 +59,41 @@ class RPNTestMixin(object):
class BBoxTestMixin(object):

    if sys.version_info >= (3, 7):

        async def async_test_bboxes(self,
                                    x,
                                    img_meta,
                                    proposals,
                                    rcnn_test_cfg,
                                    rescale=False,
                                    bbox_semaphore=None,
                                    global_lock=None):
            """Async test only det bboxes without augmentation."""
            rois = bbox2roi(proposals)
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            if self.with_shared_head:
                roi_feats = self.shared_head(roi_feats)

            sleep_interval = rcnn_test_cfg.get("async_sleep_interval", 0.017)

            async with completed(
                    __name__, "bbox_head_forward",
                    sleep_interval=sleep_interval):
                cls_score, bbox_pred = self.bbox_head(roi_feats)

            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=rescale,
                cfg=rcnn_test_cfg)
            return det_bboxes, det_labels

    def simple_test_bboxes(self,
                           x,
                           img_meta,
@@ -102,6 +159,46 @@ class BBoxTestMixin(object):
class MaskTestMixin(object):

    if sys.version_info >= (3, 7):

        async def async_test_mask(self,
                                  x,
                                  img_meta,
                                  det_bboxes,
                                  det_labels,
                                  rescale=False,
                                  mask_test_cfg=None):
            # image shape of the first image in the batch (only one)
            ori_shape = img_meta[0]['ori_shape']
            scale_factor = img_meta[0]['scale_factor']
            if det_bboxes.shape[0] == 0:
                segm_result = [[]
                               for _ in range(self.mask_head.num_classes - 1)]
            else:
                _bboxes = (
                    det_bboxes[:, :4] *
                    scale_factor if rescale else det_bboxes)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self.mask_roi_extractor(
                    x[:len(self.mask_roi_extractor.featmap_strides)],
                    mask_rois)

                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)

                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
                    sleep_interval = mask_test_cfg['async_sleep_interval']
                else:
                    sleep_interval = 0.035
                async with completed(
                        __name__,
                        "mask_head_forward",
                        sleep_interval=sleep_interval):
                    mask_pred = self.mask_head(mask_feats)

                segm_result = self.mask_head.get_seg_masks(
                    mask_pred, _bboxes, det_labels, self.test_cfg.rcnn,
                    ori_shape, scale_factor, rescale)
            return segm_result

    def simple_test_mask(self,
                         x,
                         img_meta,
......
@@ -260,14 +260,49 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
        return losses

    async def async_simple_test(self,
                                img,
                                img_meta,
                                proposals=None,
                                rescale=False):
        """Async test without augmentation."""
        assert self.with_bbox, "Bbox head must be implemented."
        x = self.extract_feat(img)

        if proposals is None:
            proposal_list = await self.async_test_rpn(x, img_meta,
                                                      self.test_cfg.rpn)
        else:
            proposal_list = proposals

        det_bboxes, det_labels = await self.async_test_bboxes(
            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
        bbox_results = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head.num_classes)

        if not self.with_mask:
            return bbox_results
        else:
            segm_results = await self.async_test_mask(
                x,
                img_meta,
                det_bboxes,
                det_labels,
                rescale=rescale,
                mask_test_cfg=self.test_cfg.get('mask'))
            return bbox_results, segm_results

    def simple_test(self, img, img_meta, proposals=None, rescale=False):
        """Test without augmentation."""
        assert self.with_bbox, "Bbox head must be implemented."
        x = self.extract_feat(img)

        proposal_list = self.simple_test_rpn(
            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
        if proposals is None:
            proposal_list = self.simple_test_rpn(x, img_meta,
                                                 self.test_cfg.rpn)
        else:
            proposal_list = proposals

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
......
@@ -61,6 +61,7 @@
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
@@ -261,7 +262,7 @@ void deformable_im2col(
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
@@ -355,7 +356,7 @@ void deformable_col2im(
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
@@ -454,7 +455,7 @@ void deformable_col2im_coord(
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
@@ -784,7 +785,7 @@ void modulated_deformable_im2col_cuda(
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
@@ -816,7 +817,7 @@ void modulated_deformable_col2im_cuda(
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
@@ -851,7 +852,7 @@ void modulated_deformable_col2im_coord_cuda(
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
......
@@ -296,7 +296,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
@@ -349,7 +349,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
const scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
@@ -361,4 +361,4 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
\ No newline at end of file
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -63,7 +64,8 @@ int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
const int64_t *mask_w_idx_ = mask_w_idx.data<int64_t>();
scalar_t *top_data_ = top_data.data<scalar_t>();
MaskedIm2colForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()
>>>(
output_size, bottom_data_, height, width, kernel_h, kernel_w,
pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
}));
@@ -103,7 +105,7 @@ int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
scalar_t *top_data_ = top_data.data<scalar_t>();
MaskedCol2imForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, bottom_data_, height, width, channels, mask_h_idx_,
mask_w_idx_, mask_cnt, top_data_);
}));
......
@@ -96,16 +96,19 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
nms_kernel<<<blocks, threads>>>(boxes_num,
nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);
std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost));
THCudaCheck(cudaMemcpyAsync(
&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost,
at::cuda::getCurrentCUDAStream()
));
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
......
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <cmath>
#include <vector>
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -131,7 +132,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
scalar_t *top_data = output.data<scalar_t>();
ROIAlignForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
sample_num, channels, height, width, pooled_height,
pooled_width, top_data);
@@ -272,7 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
}
ROIAlignBackward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, top_diff, rois_data, spatial_scale, sample_num,
channels, height, width, pooled_height, pooled_width,
bottom_diff);
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -93,7 +94,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
int *argmax_data = argmax.data<int>();
ROIPoolForward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, bottom_data, rois_data, scalar_t(spatial_scale),
channels, height, width, pooled_h, pooled_w, top_data,
argmax_data);
@@ -146,7 +147,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
}
ROIPoolBackward<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
output_size, top_diff, rois_data, argmax_data,
scalar_t(spatial_scale), channels, height, width, pooled_h,
pooled_w, bottom_diff);
......
@@ -120,7 +120,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
SigmoidFocalLossForward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
losses_size, logits.contiguous().data<scalar_t>(),
targets.contiguous().data<int64_t>(), num_classes, gamma, alpha,
num_samples, losses.data<scalar_t>());
@@ -159,7 +159,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
d_logits_size, logits.contiguous().data<scalar_t>(),
targets.contiguous().data<int64_t>(),
d_losses.contiguous().data<scalar_t>(), num_classes, gamma, alpha,
......
# coding: utf-8
import asyncio
import contextlib
import logging
import os
import time
from typing import List
import torch
logger = logging.getLogger(__name__)
DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
@contextlib.asynccontextmanager
async def completed(trace_name="",
                    name="",
                    sleep_interval=0.05,
                    streams: List[torch.cuda.Stream] = None):
    """
    Async context manager that waits for work to complete on
    given CUDA streams.
    """
    if not torch.cuda.is_available():
        yield
        return

    stream_before_context_switch = torch.cuda.current_stream()
    if not streams:
        streams = [stream_before_context_switch]
    else:
        streams = [s if s else stream_before_context_switch for s in streams]

    end_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        start = torch.cuda.Event(enable_timing=True)
        stream_before_context_switch.record_event(start)
        cpu_start = time.monotonic()
    logger.debug("%s %s starting, streams: %s", trace_name, name, streams)
    grad_enabled_before = torch.is_grad_enabled()
    try:
        yield
    finally:
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()
        for i, stream in enumerate(streams):
            event = end_events[i]
            stream.record_event(event)

        grad_enabled_after = torch.is_grad_enabled()

        # observed change of torch.is_grad_enabled() during concurrent run of
        # async_test_bboxes code
        assert grad_enabled_before == grad_enabled_after, \
            "Unexpected is_grad_enabled() value change"

        are_done = [e.query() for e in end_events]
        logger.debug("%s %s completed: %s streams: %s", trace_name, name,
                     are_done, streams)
        with torch.cuda.stream(stream_before_context_switch):
            while not all(are_done):
                await asyncio.sleep(sleep_interval)
                are_done = [e.query() for e in end_events]
                logger.debug("%s %s completed: %s streams: %s", trace_name,
                             name, are_done, streams)

        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ""
            for i, stream in enumerate(streams):
                elapsed_time = start.elapsed_time(end_events[i])
                stream_times_ms += " {} {:.2f} ms".format(
                    stream, elapsed_time)
            logger.info("%s %s cpu_time %.2f ms%s", trace_name, name,
                        cpu_time, stream_times_ms)
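

# Usage sketch for `completed` (illustrative, not part of this module;
# assumes a CUDA device, a loaded `model` and an input tensor `x`): awaiting
# the context manager lets other coroutines run while the kernels finish.
#
# async def forward_without_blocking(model, x):
#     async with completed(__name__, 'forward'):
#         y = model(x)
#     return y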
@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
                     trace_name="concurrent",
                     name="stream"):
    """Run code concurrently in different streams.

    :param streamqueue: asyncio.Queue instance.

    Queue tasks define the pool of streams used for concurrent execution.
    """
    if not torch.cuda.is_available():
        yield
        return

    initial_stream = torch.cuda.current_stream()

    with torch.cuda.stream(initial_stream):
        stream = await streamqueue.get()
        assert isinstance(stream, torch.cuda.Stream)

        try:
            with torch.cuda.stream(stream):
                logger.debug("%s %s is starting, stream: %s", trace_name,
                             name, stream)
                yield
                current = torch.cuda.current_stream()
                assert current == stream
                logger.debug("%s %s has finished, stream: %s", trace_name,
                             name, stream)
        finally:
            streamqueue.task_done()
            streamqueue.put_nowait(stream)
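

# Usage sketch for `concurrent` (illustrative, not part of this module): a
# pool of CUDA streams held in an asyncio.Queue bounds the number of
# in-flight inferences; see the docs and tests/async_benchmark.py.
#
# streamqueue = asyncio.Queue()
# for _ in range(3):  # queue size defines the concurrency level
#     streamqueue.put_nowait(torch.cuda.Stream(device='cuda:0'))
#
# async with concurrent(streamqueue):
#     result = await async_inference_detector(model, img)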
import contextlib
import sys
import time
import torch
if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of
        code suitable for async implementation.
        """
        if (not enabled) or not torch.cuda.is_available():
            yield
            return
        stream = stream if stream else torch.cuda.current_stream()
        end_stream = end_stream if end_stream else stream
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000
            gpu_time = start.elapsed_time(end)
            msg = "{} {} cpu_time {:.2f} ms ".format(trace_name, name,
                                                     cpu_time)
            msg += "gpu_time {:.2f} ms stream {}".format(gpu_time, stream)
            print(msg, end_stream)
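

# Usage sketch for `profile_time` (illustrative; `model` and `data` are
# assumed to exist): wrap a region to print its CPU time and the GPU time
# recorded on the current stream.
#
# with profile_time('detector', 'forward'):
#     result = model(return_loss=False, rescale=True, **data)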
@@ -162,7 +162,7 @@ if __name__ == '__main__':
        ],
        license='Apache License 2.0',
        setup_requires=['pytest-runner', 'cython', 'numpy'],
        tests_require=['pytest', 'xdoctest'],
        tests_require=['pytest', 'xdoctest', 'asynctest'],
        install_requires=get_requirements(),
        ext_modules=[
            make_cuda_ext(
......
# coding: utf-8
import asyncio
import os
import shutil
import urllib.request

import mmcv
import torch

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result)
from mmdet.utils.contextmanagers import concurrent
from mmdet.utils.profiling import profile_time

async def main():
    """
    Benchmark between async and synchronous inference interfaces.

    Sample runs for 20 demo images on K80 GPU, model - mask_rcnn_r50_fpn_1x:

    async        sync
    7981.79 ms   9660.82 ms
    8074.52 ms   9660.94 ms
    7976.44 ms   9406.83 ms

    The async variant takes about 0.83-0.85 of the time of the synchronous
    interface.
    """
    project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

    config_file = os.path.join(project_dir, 'configs/mask_rcnn_r50_fpn_1x.py')
    checkpoint_file = os.path.join(
        project_dir, 'checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
    if not os.path.exists(checkpoint_file):
        url = ('https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection'
               '/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
        print('Downloading {} ...'.format(url))
        local_filename, _ = urllib.request.urlretrieve(url)
        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
        shutil.move(local_filename, checkpoint_file)
        print('Saved as {}'.format(checkpoint_file))
    else:
        print('Using existing checkpoint {}'.format(checkpoint_file))

    device = 'cuda:0'
    model = init_detector(
        config_file, checkpoint=checkpoint_file, device=device)

    # queue is used for concurrent inference of multiple images
    streamqueue = asyncio.Queue()
    # queue size defines concurrency level
    streamqueue_size = 4

    for _ in range(streamqueue_size):
        streamqueue.put_nowait(torch.cuda.Stream(device=device))

    # test a single image and show the results
    img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg'))

    # warmup
    await async_inference_detector(model, img)

    async def detect(img):
        async with concurrent(streamqueue):
            return await async_inference_detector(model, img)

    num_of_images = 20
    with profile_time('benchmark', 'async'):
        tasks = [
            asyncio.create_task(detect(img)) for _ in range(num_of_images)
        ]
        async_results = await asyncio.gather(*tasks)

    with torch.cuda.stream(torch.cuda.default_stream()):
        with profile_time('benchmark', 'sync'):
            sync_results = [
                inference_detector(model, img) for _ in range(num_of_images)
            ]

    result_dir = os.path.join(project_dir, 'demo')
    show_result(
        img,
        async_results[0],
        model.CLASSES,
        score_thr=0.5,
        show=False,
        out_file=os.path.join(result_dir, 'result_async.jpg'))
    show_result(
        img,
        sync_results[0],
        model.CLASSES,
        score_thr=0.5,
        show=False,
        out_file=os.path.join(result_dir, 'result_sync.jpg'))


if __name__ == '__main__':
    asyncio.run(main())
@@ -4,3 +4,4 @@ yapf
pytest-cov
codecov
xdoctest >= 0.10.0
asynctest