Unverified Commit 333536f6 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Release v1.0.0rc1

parents 9c7270d0 f747daab
......@@ -86,7 +86,8 @@ class LyftDataset(Custom3DDataset):
modality=None,
box_type_3d='LiDAR',
filter_empty_gt=True,
test_mode=False):
test_mode=False,
**kwargs):
self.load_interval = load_interval
super().__init__(
data_root=data_root,
......@@ -96,7 +97,8 @@ class LyftDataset(Custom3DDataset):
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
test_mode=test_mode,
**kwargs)
if self.modality is None:
self.modality = dict(
......@@ -116,7 +118,8 @@ class LyftDataset(Custom3DDataset):
Returns:
list[dict]: List of annotations sorted by timestamps.
"""
data = mmcv.load(ann_file)
# loading data from a file-like object needs file format
data = mmcv.load(ann_file, file_format='pkl')
data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
data_infos = data_infos[::self.load_interval]
self.metadata = data['metadata']
......
......@@ -125,7 +125,8 @@ class NuScenesDataset(Custom3DDataset):
filter_empty_gt=True,
test_mode=False,
eval_version='detection_cvpr_2019',
use_valid_flag=False):
use_valid_flag=False,
**kwargs):
self.load_interval = load_interval
self.use_valid_flag = use_valid_flag
super().__init__(
......@@ -136,7 +137,8 @@ class NuScenesDataset(Custom3DDataset):
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
test_mode=test_mode,
**kwargs)
self.with_velocity = with_velocity
self.eval_version = eval_version
......@@ -184,7 +186,8 @@ class NuScenesDataset(Custom3DDataset):
Returns:
list[dict]: List of annotations sorted by timestamps.
"""
data = mmcv.load(ann_file)
# loading data from a file-like object needs file format
data = mmcv.load(ann_file, file_format='pkl')
data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp']))
data_infos = data_infos[::self.load_interval]
self.metadata = data['metadata']
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
import mmcv
import numpy as np
......@@ -104,7 +105,8 @@ class DataBaseSampler(object):
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=[0, 1, 2, 3])):
use_dim=[0, 1, 2, 3]),
file_client_args=dict(backend='disk')):
super().__init__()
self.data_root = data_root
self.info_path = info_path
......@@ -114,8 +116,20 @@ class DataBaseSampler(object):
self.cat2label = {name: i for i, name in enumerate(classes)}
self.label2cat = {i: name for i, name in enumerate(classes)}
self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)
self.file_client = mmcv.FileClient(**file_client_args)
db_infos = mmcv.load(info_path)
# load data base infos
if hasattr(self.file_client, 'get_local_path'):
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmcv.load(open(local_path, 'rb'), file_format='pkl')
else:
warnings.warn(
'The used MMCV version does not have get_local_path. '
f'We treat the {info_path} as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.')
db_infos = mmcv.load(info_path)
# filter database infos
from mmdet3d.utils import get_root_logger
......
......@@ -518,7 +518,7 @@ class LoadAnnotations3D(LoadAnnotations):
with_seg=False,
with_bbox_depth=False,
poly2mask=True,
seg_3d_dtype='int',
seg_3d_dtype=np.int64,
file_client_args=dict(backend='disk')):
super().__init__(
with_bbox,
......@@ -600,11 +600,11 @@ class LoadAnnotations3D(LoadAnnotations):
self.file_client = mmcv.FileClient(**self.file_client_args)
try:
mask_bytes = self.file_client.get(pts_instance_mask_path)
pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int)
pts_instance_mask = np.frombuffer(mask_bytes, dtype=np.int64)
except ConnectionError:
mmcv.check_file_exist(pts_instance_mask_path)
pts_instance_mask = np.fromfile(
pts_instance_mask_path, dtype=np.long)
pts_instance_mask_path, dtype=np.int64)
results['pts_instance_mask'] = pts_instance_mask
results['pts_mask_fields'].append('pts_instance_mask')
......@@ -631,7 +631,7 @@ class LoadAnnotations3D(LoadAnnotations):
except ConnectionError:
mmcv.check_file_exist(pts_semantic_mask_path)
pts_semantic_mask = np.fromfile(
pts_semantic_mask_path, dtype=np.long)
pts_semantic_mask_path, dtype=np.int64)
results['pts_semantic_mask'] = pts_semantic_mask
results['pts_seg_fields'].append('pts_semantic_mask')
......
......@@ -356,7 +356,7 @@ class ObjectSample(object):
input_dict['img'] = sampled_dict['img']
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.long)
input_dict['gt_labels_3d'] = gt_labels_3d.astype(np.int64)
input_dict['points'] = points
return input_dict
......@@ -907,9 +907,9 @@ class PointSample(object):
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth >= sample_range)[0]
near_inds = np.where(depth < sample_range)[0]
dist = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(dist >= sample_range)[0]
near_inds = np.where(dist < sample_range)[0]
# in case there are too many far points
if len(far_inds) > num_samples:
far_inds = np.random.choice(
......@@ -936,12 +936,6 @@ class PointSample(object):
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points']
# Points in Camera coord can provide the depth information.
# TODO: Need to support distance-based sampling for other coord system.
if self.sample_range is not None:
from mmdet3d.core.points import CameraPoints
assert isinstance(points, CameraPoints), \
'Sampling based on distance is only applicable for CAM coord'
points, choices = self._points_random_sampling(
points,
self.num_points,
......
......@@ -54,7 +54,8 @@ class S3DISDataset(Custom3DDataset):
modality=None,
box_type_3d='Depth',
filter_empty_gt=True,
test_mode=False):
test_mode=False,
**kwargs):
super().__init__(
data_root=data_root,
ann_file=ann_file,
......@@ -63,7 +64,8 @@ class S3DISDataset(Custom3DDataset):
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
test_mode=test_mode,
**kwargs)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
......@@ -85,10 +87,10 @@ class S3DISDataset(Custom3DDataset):
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
np.float32) # k, 6
gt_labels_3d = info['annos']['class'].astype(np.long)
gt_labels_3d = info['annos']['class'].astype(np.int64)
else:
gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
gt_labels_3d = np.zeros((0, ), dtype=np.long)
gt_labels_3d = np.zeros((0, ), dtype=np.int64)
# to target box structure
gt_bboxes_3d = DepthInstance3DBoxes(
......@@ -205,7 +207,8 @@ class _S3DISSegDataset(Custom3DSegDataset):
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None):
scene_idxs=None,
**kwargs):
super().__init__(
data_root=data_root,
......@@ -216,7 +219,8 @@ class _S3DISSegDataset(Custom3DSegDataset):
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs)
scene_idxs=scene_idxs,
**kwargs)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
......@@ -347,7 +351,8 @@ class S3DISSegDataset(_S3DISSegDataset):
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None):
scene_idxs=None,
**kwargs):
# make sure that ann_files and scene_idxs have same length
ann_files = self._check_ann_files(ann_files)
......@@ -363,7 +368,8 @@ class S3DISSegDataset(_S3DISSegDataset):
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs[0])
scene_idxs=scene_idxs[0],
**kwargs)
datasets = [
_S3DISSegDataset(
......
......@@ -5,7 +5,7 @@ from os import path as osp
import numpy as np
from mmdet3d.core import show_result, show_seg_result
from mmdet3d.core import instance_seg_eval, show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS
from mmseg.datasets import DATASETS as SEG_DATASETS
......@@ -58,7 +58,8 @@ class ScanNetDataset(Custom3DDataset):
modality=dict(use_camera=False, use_depth=True),
box_type_3d='Depth',
filter_empty_gt=True,
test_mode=False):
test_mode=False,
**kwargs):
super().__init__(
data_root=data_root,
ann_file=ann_file,
......@@ -67,7 +68,8 @@ class ScanNetDataset(Custom3DDataset):
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
test_mode=test_mode,
**kwargs)
assert 'use_camera' in self.modality and \
'use_depth' in self.modality
assert self.modality['use_camera'] or self.modality['use_depth']
......@@ -143,10 +145,10 @@ class ScanNetDataset(Custom3DDataset):
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
np.float32) # k, 6
gt_labels_3d = info['annos']['class'].astype(np.long)
gt_labels_3d = info['annos']['class'].astype(np.int64)
else:
gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
gt_labels_3d = np.zeros((0, ), dtype=np.long)
gt_labels_3d = np.zeros((0, ), dtype=np.int64)
# to target box structure
gt_bboxes_3d = DepthInstance3DBoxes(
......@@ -322,7 +324,8 @@ class ScanNetSegDataset(Custom3DSegDataset):
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None):
scene_idxs=None,
**kwargs):
super().__init__(
data_root=data_root,
......@@ -333,7 +336,8 @@ class ScanNetSegDataset(Custom3DSegDataset):
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs)
scene_idxs=scene_idxs,
**kwargs)
def get_ann_info(self, index):
"""Get annotation info according to the given index.
......@@ -460,3 +464,151 @@ class ScanNetSegDataset(Custom3DSegDataset):
outputs.append(dict(seg_mask=pred_label))
return outputs, tmp_dir
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetInstanceSegDataset(Custom3DSegDataset):
    """ScanNet dataset for 3D point-cloud instance segmentation.

    Registered in both the detection (``DATASETS``) and segmentation
    (``SEG_DATASETS``) registries so it can be built from either config
    namespace.
    """

    # The 18 instance categories evaluated on ScanNet.
    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')

    # Raw ScanNet category ids corresponding to CLASSES, in the same order.
    VALID_CLASS_IDS = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
                       36, 39)

    # Every raw category id present in the ScanNet annotation (0-40).
    ALL_CLASS_IDS = tuple(range(41))

    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: annotation information consists of the following keys:

                - pts_semantic_mask_path (str): Path of semantic masks.
                - pts_instance_mask_path (str): Path of instance masks.
        """
        # Use index to get the annos, thus the evalhook could also use this api
        info = self.data_infos[index]

        pts_instance_mask_path = osp.join(self.data_root,
                                          info['pts_instance_mask_path'])
        pts_semantic_mask_path = osp.join(self.data_root,
                                          info['pts_semantic_mask_path'])

        anns_results = dict(
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)
        return anns_results

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset. Palette is simply ignored for
        instance segmentation.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
                Defaults to None.
            palette (Sequence[Sequence[int]]] | np.ndarray | None):
                The palette of segmentation map. If None is given, random
                palette will be generated. Defaults to None.
        """
        # NOTE(review): unlike the docstring suggests, a str `classes` is
        # returned as-is here — no file loading happens in this override.
        if classes is not None:
            return classes, None
        return self.CLASSES, None

    def _build_default_pipeline(self):
        """Build the default pipeline for this dataset.

        Used when ``evaluate`` / ``show`` is called without an explicit
        pipeline: loads points plus both semantic and instance masks.
        """
        pipeline = [
            dict(
                type='LoadPointsFromFile',
                coord_type='DEPTH',
                shift_height=False,
                use_color=True,
                load_dim=6,
                use_dim=[0, 1, 2, 3, 4, 5]),
            dict(
                type='LoadAnnotations3D',
                with_bbox_3d=False,
                with_label_3d=False,
                with_mask_3d=True,
                with_seg_3d=True),
            dict(
                # Remaps raw ScanNet category ids to train ids; ids outside
                # VALID_CLASS_IDS are mapped to the ignore label.
                type='PointSegClassMapping',
                valid_cat_ids=self.VALID_CLASS_IDS,
                max_cat_id=40),
            dict(
                type='DefaultFormatBundle3D',
                with_label=False,
                class_names=self.CLASSES),
            dict(
                type='Collect3D',
                keys=['points', 'pts_semantic_mask', 'pts_instance_mask'])
        ]
        return Compose(pipeline)

    def evaluate(self,
                 results,
                 metric=None,
                 options=None,
                 logger=None,
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluation in instance segmentation protocol.

        Args:
            results (list[dict]): List of results.
            metric (str | list[str]): Metrics to be evaluated.
            options (dict, optional): options for instance_seg_eval.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Defaults to None.
            show (bool, optional): Whether to visualize.
                Defaults to False.
            out_dir (str, optional): Path to save the visualization results.
                Defaults to None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: Evaluation results.

        Raises:
            NotImplementedError: If ``show`` is True (visualization is not
                implemented for this dataset yet).
        """
        # Each element of `results` must carry per-scene predictions keyed by
        # 'instance_mask' / 'instance_label' / 'instance_score'.
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'

        load_pipeline = self._get_pipeline(pipeline)
        pred_instance_masks = [result['instance_mask'] for result in results]
        pred_instance_labels = [result['instance_label'] for result in results]
        pred_instance_scores = [result['instance_score'] for result in results]
        # Re-load ground-truth masks for every scene through the (default or
        # user-supplied) loading pipeline so GT and predictions align.
        gt_semantic_masks, gt_instance_masks = zip(*[
            self._extract_data(
                index=i,
                pipeline=load_pipeline,
                key=['pts_semantic_mask', 'pts_instance_mask'],
                load_annos=True) for i in range(len(self.data_infos))
        ])
        ret_dict = instance_seg_eval(
            gt_semantic_masks,
            gt_instance_masks,
            pred_instance_masks,
            pred_instance_labels,
            pred_instance_scores,
            valid_class_ids=self.VALID_CLASS_IDS,
            class_labels=self.CLASSES,
            options=options,
            logger=logger)

        if show:
            raise NotImplementedError('show is not implemented for now')

        return ret_dict
......@@ -54,7 +54,8 @@ class SUNRGBDDataset(Custom3DDataset):
modality=dict(use_camera=True, use_lidar=True),
box_type_3d='Depth',
filter_empty_gt=True,
test_mode=False):
test_mode=False,
**kwargs):
super().__init__(
data_root=data_root,
ann_file=ann_file,
......@@ -63,7 +64,8 @@ class SUNRGBDDataset(Custom3DDataset):
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
test_mode=test_mode,
**kwargs)
assert 'use_camera' in self.modality and \
'use_lidar' in self.modality
assert self.modality['use_camera'] or self.modality['use_lidar']
......@@ -137,10 +139,10 @@ class SUNRGBDDataset(Custom3DDataset):
if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
np.float32) # k, 6
gt_labels_3d = info['annos']['class'].astype(np.long)
gt_labels_3d = info['annos']['class'].astype(np.int64)
else:
gt_bboxes_3d = np.zeros((0, 7), dtype=np.float32)
gt_labels_3d = np.zeros((0, ), dtype=np.long)
gt_labels_3d = np.zeros((0, ), dtype=np.int64)
# to target box structure
gt_bboxes_3d = DepthInstance3DBoxes(
......
......@@ -66,7 +66,8 @@ class WaymoDataset(KittiDataset):
filter_empty_gt=True,
test_mode=False,
load_interval=1,
pcd_limit_range=[-85, -85, -5, 85, 85, 5]):
pcd_limit_range=[-85, -85, -5, 85, 85, 5],
**kwargs):
super().__init__(
data_root=data_root,
ann_file=ann_file,
......@@ -78,7 +79,8 @@ class WaymoDataset(KittiDataset):
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode,
pcd_limit_range=pcd_limit_range)
pcd_limit_range=pcd_limit_range,
**kwargs)
# to load a subset, just set the load_interval in the dataset config
self.data_infos = self.data_infos[::load_interval]
......
......@@ -3,6 +3,7 @@ import copy
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.ops import nms_bev as nms_gpu
from mmcv.runner import BaseModule, force_fp32
from torch import nn
......@@ -11,7 +12,6 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
from mmdet3d.models import builder
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.utils import clip_sigmoid
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
from mmdet.core import build_bbox_coder, multi_apply
......
......@@ -7,13 +7,14 @@ from mmcv import ConfigDict
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer)
from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import gather_points
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.ops import Points_Sampler, gather_points
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
from .base_conv_bbox_head import BaseConvBboxHead
......
......@@ -3,10 +3,11 @@ from __future__ import division
import numpy as np
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import force_fp32
from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
from mmdet.models import HEADS
from .anchor3d_head import Anchor3DHead
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS, build_loss
......
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule, force_fp32
from torch.nn import functional as F
......@@ -8,7 +9,7 @@ from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
from .base_conv_bbox_head import BaseConvBboxHead
......
......@@ -4,13 +4,13 @@ from os import path as osp
import mmcv
import torch
from mmcv.ops import Voxelization
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
merge_aug_bboxes_3d, show_result)
from mmdet3d.ops import Voxelization
from mmdet.core import multi_apply
from mmdet.models import DETECTORS
from .. import builder
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Voxelization
from torch.nn import functional as F
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .two_stage import TwoStage3DDetector
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS
from .. import builder
from .single_stage import SingleStage3DDetector
......
......@@ -50,14 +50,14 @@ class PointPillarsScatter(nn.Module):
dtype=voxel_features.dtype,
device=voxel_features.device)
indices = coors[:, 1] * self.nx + coors[:, 2]
indices = coors[:, 2] * self.nx + coors[:, 3]
indices = indices.long()
voxels = voxel_features.t()
# Now scatter the blob back to the canvas.
canvas[:, indices] = voxels
# Undo the column stacking to final 4-dim tensor
canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
return [canvas]
return canvas
def forward_batch(self, voxel_features, coors, batch_size):
"""Scatter features of single sample.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.ops import SparseConvTensor, SparseSequential
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..builder import MIDDLE_ENCODERS
......@@ -109,9 +109,8 @@ class SparseEncoder(nn.Module):
dict: Backbone features.
"""
coors = coors.int()
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
self.sparse_shape,
batch_size)
input_sp_tensor = SparseConvTensor(voxel_features, coors,
self.sparse_shape, batch_size)
x = self.conv_input(input_sp_tensor)
encode_features = []
......@@ -150,7 +149,7 @@ class SparseEncoder(nn.Module):
int: The number of encoder output channels.
"""
assert block_type in ['conv_module', 'basicblock']
self.encoder_layers = spconv.SparseSequential()
self.encoder_layers = SparseSequential()
for i, blocks in enumerate(self.encoder_channels):
blocks_list = []
......@@ -201,6 +200,6 @@ class SparseEncoder(nn.Module):
conv_type='SubMConv3d'))
in_channels = out_channels
stage_name = f'encoder_layer{i + 1}'
stage_layers = spconv.SparseSequential(*blocks_list)
stage_layers = SparseSequential(*blocks_list)
self.encoder_layers.add_module(stage_name, stage_layers)
return out_channels
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import SparseConvTensor, SparseSequential
from mmcv.runner import BaseModule, auto_fp16
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..builder import MIDDLE_ENCODERS
......@@ -108,9 +108,8 @@ class SparseUNet(BaseModule):
dict[str, torch.Tensor]: Backbone features.
"""
coors = coors.int()
input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
self.sparse_shape,
batch_size)
input_sp_tensor = SparseConvTensor(voxel_features, coors,
self.sparse_shape, batch_size)
x = self.conv_input(input_sp_tensor)
encode_features = []
......@@ -200,7 +199,7 @@ class SparseUNet(BaseModule):
Returns:
int: The number of encoder output channels.
"""
self.encoder_layers = spconv.SparseSequential()
self.encoder_layers = SparseSequential()
for i, blocks in enumerate(self.encoder_channels):
blocks_list = []
......@@ -231,7 +230,7 @@ class SparseUNet(BaseModule):
conv_type='SubMConv3d'))
in_channels = out_channels
stage_name = f'encoder_layer{i + 1}'
stage_layers = spconv.SparseSequential(*blocks_list)
stage_layers = SparseSequential(*blocks_list)
self.encoder_layers.add_module(stage_name, stage_layers)
return out_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment