Commit bc5806ba authored by Tai-Wang, committed by ChaimZhu

[Refactor] ImVoxelNet

parent f63a62b8
 model = dict(
     type='ImVoxelNet',
+    data_preprocessor=dict(
+        type='Det3DDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True,
+        pad_size_divisor=32),
     backbone=dict(
-        type='ResNet',
+        type='mmdet.ResNet',
         depth=50,
         num_stages=4,
         out_indices=(0, 1, 2, 3),
@@ -11,7 +17,7 @@ model = dict(
         init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
         style='pytorch'),
     neck=dict(
-        type='FPN',
+        type='mmdet.FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=64,
         num_outs=4),
@@ -31,14 +37,16 @@ model = dict(
         diff_rad_by_sin=True,
         bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
         loss_cls=dict(
-            type='FocalLoss',
+            type='mmdet.FocalLoss',
             use_sigmoid=True,
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
+        loss_bbox=dict(
+            type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
         loss_dir=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
+            type='mmdet.CrossEntropyLoss', use_sigmoid=False,
+            loss_weight=0.2)),
     n_voxels=[216, 248, 12],
     anchor_generator=dict(
         type='AlignedAnchor3DRangeGenerator',
@@ -46,8 +54,8 @@ model = dict(
         rotations=[.0]),
     train_cfg=dict(
         assigner=dict(
-            type='MaxIoUAssigner',
-            iou_calculator=dict(type='BboxOverlapsNearest3D'),
+            type='Max3DIoUAssigner',
+            iou_calculator=dict(type='mmdet3d.BboxOverlapsNearest3D'),
             pos_iou_thr=0.6,
             neg_iou_thr=0.45,
             min_pos_iou=0.45,
@@ -69,92 +77,119 @@ data_root = 'data/kitti/'
 class_names = ['Car']
 input_modality = dict(use_lidar=False, use_camera=True)
 point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+metainfo = dict(CLASSES=class_names)
+# file_client_args = dict(backend='disk')
+# Uncomment the following if use ceph or other file clients.
+# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
+# for more details.
+file_client_args = dict(
+    backend='petrel',
+    path_mapping=dict({
+        './data/kitti/':
+        's3://openmmlab/datasets/detection3d/kitti/',
+        'data/kitti/':
+        's3://openmmlab/datasets/detection3d/kitti/'
+    }))
 train_pipeline = [
     dict(type='LoadAnnotations3D'),
-    dict(type='LoadImageFromFile'),
+    dict(type='LoadImageFromFileMono3D', file_client_args=file_client_args),
     dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
     dict(
-        type='Resize',
-        img_scale=[(1173, 352), (1387, 416)],
-        keep_ratio=True,
-        multiscale_mode='range'),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
+        type='RandomResize', scale=[(1173, 352), (1387, 416)],
+        keep_ratio=True),
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(type='Pack3DDetInputs', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
 test_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(type='Resize', img_scale=(1280, 384), keep_ratio=True),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(
-        type='DefaultFormatBundle3D',
-        class_names=class_names,
-        with_label=False),
-    dict(type='Collect3D', keys=['img'])
+    dict(type='LoadImageFromFileMono3D', file_client_args=file_client_args),
+    dict(type='Resize', scale=(1280, 384), keep_ratio=True),
+    dict(type='Pack3DDetInputs', keys=['img'])
 ]
-data = dict(
-    samples_per_gpu=4,
-    workers_per_gpu=4,
-    train=dict(
-        type='RepeatDataset',
-        times=3,
-        dataset=dict(
-            type=dataset_type,
-            data_root=data_root,
-            ann_file=data_root + 'kitti_infos_train.pkl',
-            split='training',
-            pts_prefix='velodyne_reduced',
-            pipeline=train_pipeline,
-            modality=input_modality,
-            classes=class_names,
-            test_mode=False)),
-    val=dict(
-        type=dataset_type,
-        data_root=data_root,
-        ann_file=data_root + 'kitti_infos_val.pkl',
-        split='training',
-        pts_prefix='velodyne_reduced',
-        pipeline=test_pipeline,
-        modality=input_modality,
-        classes=class_names,
-        test_mode=True),
-    test=dict(
-        type=dataset_type,
-        data_root=data_root,
-        ann_file=data_root + 'kitti_infos_val.pkl',
-        split='training',
-        pts_prefix='velodyne_reduced',
-        pipeline=test_pipeline,
-        modality=input_modality,
-        classes=class_names,
-        test_mode=True))
-optimizer = dict(
-    type='AdamW',
-    lr=0.0001,
-    weight_decay=0.0001,
-    paramwise_cfg=dict(
-        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
-optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
-lr_config = dict(policy='step', step=[8, 11])
-total_epochs = 12
-checkpoint_config = dict(interval=1, max_keep_ckpts=1)
-log_config = dict(
-    interval=50,
-    hooks=[dict(type='TextLoggerHook'),
-           dict(type='TensorboardLoggerHook')])
-evaluation = dict(interval=1)
-dist_params = dict(backend='nccl')
-find_unused_parameters = True  # only 1 of 4 FPN outputs is used
+train_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type='RepeatDataset',
+        times=3,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            ann_file='kitti_infos_train.pkl',
+            data_prefix=dict(img='training/image_2'),
+            pipeline=train_pipeline,
+            modality=input_modality,
+            test_mode=False,
+            metainfo=metainfo)))
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='kitti_infos_val.pkl',
+        data_prefix=dict(img='training/image_2'),
+        pipeline=test_pipeline,
+        modality=input_modality,
+        test_mode=True,
+        metainfo=metainfo))
+test_dataloader = val_dataloader
+val_evaluator = dict(
+    type='KittiMetric',
+    ann_file=data_root + 'kitti_infos_val.pkl',
+    metric='bbox')
+test_evaluator = val_evaluator
+# optimizer
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
+    paramwise_cfg=dict(
+        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
+    clip_grad=dict(max_norm=35., norm_type=2))
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=12,
+        by_epoch=True,
+        milestones=[8, 11],
+        gamma=0.1)
+]
+# hooks
+default_hooks = dict(
+    timer=dict(type='IterTimerHook'),
+    logger=dict(type='LoggerHook', interval=50),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
+    sampler_seed=dict(type='DistSamplerSeedHook'),
+)
+# training schedule for 2x
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+# runtime
+default_scope = 'mmdet3d'
+env_cfg = dict(
+    cudnn_benchmark=False,
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    dist_cfg=dict(backend='nccl'),
+)
 log_level = 'INFO'
 load_from = None
-resume_from = None
-workflow = [('train', 1)]
+resume = False
+dist_params = dict(backend='nccl')
+find_unused_parameters = True  # only 1 of 4 FPN outputs is used
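Taken together, the right-hand side of this diff turns the file into a self-contained MMEngine-style config (dataloaders, evaluators, optim wrapper, scheduler, hooks and loops) in place of the old `data`/`optimizer`/`lr_config` trio. A minimal sketch (not part of this commit) of how such a config is typically consumed, assuming the complete file lives at `configs/imvoxelnet/imvoxelnet_4x8_kitti-3d-car.py` (the path the new unit test below refers to) and that the `work_dir` value is our own choice:

```python
# Illustrative only: drive the refactored config with the MMEngine runner,
# which builds the model, dataloaders, optim_wrapper, param_scheduler and
# hooks from the dicts defined above.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/imvoxelnet/imvoxelnet_4x8_kitti-3d-car.py')
cfg.work_dir = './work_dirs/imvoxelnet_kitti'  # assumed output directory

runner = Runner.from_cfg(cfg)
runner.train()  # EpochBasedTrainLoop, max_epochs=12, validates every epoch
```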
@@ -259,7 +259,7 @@ class Det3DDataset(BaseDataset):
         info['lidar2img'] = info['cam2img'] @ info['lidar2cam']
         if not self.test_mode:
-            # used in traing
+            # used in training
             info['ann_info'] = self.parse_ann_info(info)
         if self.test_mode and self.load_eval_anns:
             info['eval_ann_info'] = self.parse_ann_info(info)
......
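The `lidar2img` matrix cached above is simply the camera intrinsics composed with the LiDAR-to-camera extrinsics, and it is what the refactored `extract_feat` below passes to `point_sample` as `proj_mat`. A tiny illustrative sketch (not from the commit; the identity matrices and the sample point are made up, and both matrices are assumed to be 4x4 homogeneous):

```python
import torch

# Assumed 4x4 homogeneous matrices, as stored in the KITTI info files.
cam2img = torch.eye(4)    # camera intrinsics padded to 4x4
lidar2cam = torch.eye(4)  # LiDAR -> camera extrinsic transform
lidar2img = cam2img @ lidar2cam

# Project one LiDAR point into pixel coordinates with a perspective divide.
point = torch.tensor([10.0, 1.0, -0.5, 1.0])  # homogeneous LiDAR point
uvw = lidar2img @ point
u, v = (uvw[:2] / uvw[2]).tolist()
```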
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple, Union
 import torch
-from mmdet3d.core import bbox3d2result, build_prior_generator
+from mmdet3d.core import Det3DDataSample, InstanceList, build_prior_generator
+from mmdet3d.core.utils import ConfigType, OptConfigType, SampleList
 from mmdet3d.models.fusion_layers.point_fusion import point_sample
 from mmdet3d.registry import MODELS
 from mmdet.models.detectors import BaseDetector
@@ -9,20 +12,40 @@ from mmdet.models.detectors import BaseDetector
 @MODELS.register_module()
 class ImVoxelNet(BaseDetector):
-    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_."""
+    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.
+
+    Args:
+        backbone (:obj:`ConfigDict` or dict): The backbone config.
+        neck (:obj:`ConfigDict` or dict): The neck config.
+        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
+        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
+        n_voxels (list): Number of voxels along x, y, z axis.
+        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
+            config.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
+            training hyper-parameters. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
+            hyper-parameters. Defaults to None.
+        data_preprocessor (dict or ConfigDict, optional): The pre-process
+            config of :class:`BaseDataPreprocessor`. It usually includes
+            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
+        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
+            config. Defaults to None.
+    """

     def __init__(self,
-                 backbone,
-                 neck,
-                 neck_3d,
-                 bbox_head,
-                 n_voxels,
-                 anchor_generator,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+                 backbone: ConfigType,
+                 neck: ConfigType,
+                 neck_3d: ConfigType,
+                 bbox_head: ConfigType,
+                 n_voxels: List,
+                 anchor_generator: ConfigType,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptConfigType = None):
+        super().__init__(
+            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
         self.backbone = MODELS.build(backbone)
         self.neck = MODELS.build(neck)
         self.neck_3d = MODELS.build(neck_3d)
@@ -34,22 +57,59 @@ class ImVoxelNet(BaseDetector):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

-    def extract_feat(self, img, img_metas):
+    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
+        """Convert results list to `Det3DDataSample`.
+
+        Args:
+            results_list (list[:obj:`InstanceData`]): 3D Detection results of
+                each image.
+
+        Returns:
+            list[:obj:`Det3DDataSample`]: 3D Detection results of the
+            input images. Each Det3DDataSample usually contains
+            'pred_instances_3d', and the ``pred_instances_3d`` usually
+            contains the following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (Tensor): Contains a tensor with shape
+              (num_instances, C) where C >= 7.
+        """
+        out_results_list = []
+        for i in range(len(results_list)):
+            result = Det3DDataSample()
+            result.pred_instances_3d = results_list[i]
+            out_results_list.append(result)
+        return out_results_list
+
+    def extract_feat(self, batch_inputs_dict: dict,
+                     batch_data_samples: SampleList):
         """Extract 3d features from the backbone -> fpn -> 3d projection.

         Args:
-            img (torch.Tensor): Input images of shape (N, C_in, H, W).
-            img_metas (list): Image metas.
+            batch_inputs_dict (dict): The model input dict which includes
+                the 'imgs' key.
+
+                - imgs (torch.Tensor, optional): Image of each sample.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

         Returns:
             torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
         """
+        img = batch_inputs_dict['imgs']
+        batch_img_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
         x = self.backbone(img)
         x = self.neck(x)[0]
         points = self.anchor_generator.grid_anchors(
             [self.n_voxels[::-1]], device=img.device)[0][:, :3]
         volumes = []
-        for feature, img_meta in zip(x, img_metas):
+        for feature, img_meta in zip(x, batch_img_metas):
             img_scale_factor = (
                 points.new_tensor(img_meta['scale_factor'][:2])
                 if 'scale_factor' in img_meta.keys() else 1)
@@ -57,11 +117,12 @@ class ImVoxelNet(BaseDetector):
             img_crop_offset = (
                 points.new_tensor(img_meta['img_crop_offset'])
                 if 'img_crop_offset' in img_meta.keys() else 0)
+            lidar2img = points.new_tensor(img_meta['lidar2img'])
             volume = point_sample(
                 img_meta,
                 img_features=feature[None, ...],
                 points=points,
-                proj_mat=points.new_tensor(img_meta['lidar2img']),
+                proj_mat=lidar2img,
                 coord_type='LIDAR',
                 img_scale_factor=img_scale_factor,
                 img_crop_offset=img_crop_offset,
@@ -75,64 +136,77 @@ class ImVoxelNet(BaseDetector):
         x = self.neck_3d(x)
         return x
-    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
-                      **kwargs):
-        """Forward of training.
+    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
+             **kwargs) -> Union[dict, list]:
+        """Calculate losses from a batch of inputs and data samples.

         Args:
-            img (torch.Tensor): Input images of shape (N, C_in, H, W).
-            img_metas (list): Image metas.
-            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
+            batch_inputs_dict (dict): The model input dict which includes
+                the 'imgs' key.
+
+                - imgs (torch.Tensor, optional): Image of each sample.
+            batch_data_samples (list[:obj:`DetDataSample`]): The batch
+                data samples. It usually includes information such
+                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

         Returns:
-            dict[str, torch.Tensor]: A dictionary of loss components.
+            dict: A dictionary of loss components.
         """
-        x = self.extract_feat(img, img_metas)
-        x = self.bbox_head(x)
-        losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas)
+        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
+        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
         return losses

-    def forward_test(self, img, img_metas, **kwargs):
-        """Forward of testing.
+    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
+                **kwargs) -> SampleList:
+        """Predict results from a batch of inputs and data samples with post-
+        processing.

         Args:
-            img (torch.Tensor): Input images of shape (N, C_in, H, W).
-            img_metas (list): Image metas.
+            batch_inputs_dict (dict): The model input dict which includes
+                the 'imgs' key.

-        Returns:
-            list[dict]: Predicted 3d boxes.
-        """
-        # not supporting aug_test for now
-        return self.simple_test(img, img_metas)
+                - imgs (torch.Tensor, optional): Image of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

-    def simple_test(self, img, img_metas):
-        """Test without augmentations.
-
-        Args:
-            img (torch.Tensor): Input images of shape (N, C_in, H, W).
-            img_metas (list): Image metas.
-
         Returns:
-            list[dict]: Predicted 3d boxes.
+            list[:obj:`Det3DDataSample`]: Detection results of the
+            input images. Each Det3DDataSample usually contains
+            'pred_instances_3d', and the ``pred_instances_3d`` usually
+            contains the following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (Tensor): Contains a tensor with shape
+              (num_instances, C) where C >= 7.
         """
-        x = self.extract_feat(img, img_metas)
-        x = self.bbox_head(x)
-        bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
-        bbox_results = [
-            bbox3d2result(det_bboxes, det_scores, det_labels)
-            for det_bboxes, det_scores, det_labels in bbox_list
-        ]
-        return bbox_results
+        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
+        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
+        predictions = self.convert_to_datasample(results_list)
+        return predictions

-    def aug_test(self, imgs, img_metas, **kwargs):
-        """Test with augmentations.
+    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
+                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
+        """Network forward process. Usually includes backbone, neck and head
+        forward without any post-processing.

         Args:
-            imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
-            img_metas (list): Image metas.
+            batch_inputs_dict (dict): The model input dict which includes
+                the 'imgs' key.
+
+                - imgs (torch.Tensor, optional): Image of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

         Returns:
-            list[dict]: Predicted 3d boxes.
+            tuple[list]: A tuple of features from ``bbox_head`` forward.
         """
-        raise NotImplementedError
+        x = self.extract_feat(batch_inputs_dict, batch_data_samples)
+        results = self.bbox_head.forward(x)
+        return results
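For orientation, the refactored detector follows the MMEngine `BaseModel` calling convention: a single `forward` dispatches to the three entry points above based on `mode`, which is also how the new unit test below drives the model. A schematic sketch (not part of the commit; the variable names are placeholders):

```python
# Schematic only: batch_inputs_dict carries the preprocessed images under the
# 'imgs' key, batch_data_samples the per-sample metainfo and annotations.
losses = model(batch_inputs_dict, batch_data_samples, mode='loss')      # dict of losses
samples = model(batch_inputs_dict, batch_data_samples, mode='predict')  # list[Det3DDataSample]
feats = model(batch_inputs_dict, batch_data_samples, mode='tensor')     # raw bbox_head outputs via _forward()
```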
import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestImVoxelNet(unittest.TestCase):

    def test_imvoxelnet(self):
        import mmdet3d.models

        assert hasattr(mmdet3d.models, 'ImVoxelNet')
        DefaultScope.get_instance('test_ImVoxelNet', scope_name='mmdet3d')
        _setup_seed(0)
        imvoxel_net_cfg = _get_detector_cfg(
            'imvoxelnet/imvoxelnet_4x8_kitti-3d-car.py')
        model = MODELS.build(imvoxel_net_cfg)
        num_gt_instance = 1
        data = [
            _create_detector_inputs(
                with_points=False,
                with_img=True,
                img_size=(384, 1280),
                num_gt_instance=num_gt_instance,
                with_pts_semantic_mask=False,
                with_pts_instance_mask=False)
        ]

        if torch.cuda.is_available():
            model = model.cuda()
            # test predict
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', results[0].pred_instances_3d)
            self.assertIn('scores_3d', results[0].pred_instances_3d)
            self.assertIn('labels_3d', results[0].pred_instances_3d)

            # test loss (run under no_grad to save memory)
            with torch.no_grad():
                losses = model.forward(batch_inputs, data_samples, mode='loss')
            self.assertGreater(losses['loss_cls'], 0)
            self.assertGreater(losses['loss_bbox'], 0)
            self.assertGreater(losses['loss_dir'], 0)
@@ -104,7 +104,13 @@ def _create_detector_inputs(seed=0,
         points = torch.rand([num_points, points_feat_dim])
     else:
         points = None
     if with_img:
-        img = torch.rand(3, img_size, img_size)
-        meta_info['img_shape'] = (img_size, img_size)
-        meta_info['ori_shape'] = (img_size, img_size)
+        if isinstance(img_size, tuple):
+            img = torch.rand(3, img_size[0], img_size[1])
+            meta_info['img_shape'] = img_size
+            meta_info['ori_shape'] = img_size
+        else:
+            img = torch.rand(3, img_size, img_size)
+            meta_info['img_shape'] = (img_size, img_size)
+            meta_info['ori_shape'] = (img_size, img_size)
@@ -126,9 +132,8 @@ def _create_detector_inputs(seed=0,
     gt_instance = InstanceData()
     gt_instance.labels = torch.randint(0, num_classes, [num_gt_instance])
     gt_instance.bboxes = torch.rand(num_gt_instance, 4)
-    gt_instance.bboxes[:,
-                       2:] = gt_instance.bboxes[:, :2] + gt_instance.bboxes[:,
-                                                                            2:]
+    gt_instance.bboxes[:, 2:] = \
+        gt_instance.bboxes[:, :2] + gt_instance.bboxes[:, 2:]
     data_sample.gt_instances = gt_instance
     data_sample.gt_pts_seg = PointData()
......
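The reformatted assignment in `_create_detector_inputs` keeps the original behaviour: the random 2D boxes are effectively drawn as `(x1, y1, w, h)` and the last two columns are shifted to absolute corners. A tiny numeric illustration with made-up values:

```python
import torch

boxes = torch.tensor([[0.2, 0.3, 0.1, 0.4]])  # (x1, y1, w, h), made-up values
boxes[:, 2:] = boxes[:, :2] + boxes[:, 2:]    # -> (x1, y1, x2, y2)
print(boxes)  # tensor([[0.2000, 0.3000, 0.3000, 0.7000]])
```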