Unverified Commit ee6cc047 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Feature] Add MinkUNet segmentor (#2294)

* add cylindrical voxelization & voxel feature encoder

* add cylindrical voxelization & voxel feature encoder

* add voxel-wise label & voxelization UT

* fix vfe

* fix vfe UT

* rename voxel encoder & add more test case

* fix type hint

* temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for data_preprocesser

* _forward

* del checkpoints

* add if tp

* add predict

* fix vfe init bug & fix UT

* add grid_size & move voxelization code

* fix import bug

* keep radian to follow origin

* add doc string

* fix type hint

* add minkunet voxelization and loss function

* fix data

* init train

* fix sparsetensor typehint

* rename dir

* fix data config

* fix data config

* fix batch_size & replace dynamic_scatter

* fix conflicts 2

* fix conflicts on s_70

* Alignment of the original implementation

* rename config

* add worker_init_fn_hook

* remove test_config & worker hook

* add UT

* fix polarmix UT

* add seed for cr0p5

* format

* rename SemanticKittiDataset

* add platte & fix visual bug

* add platte & fix data info bug

* fix ut

* fix semantic_kitti ut

* fix docstring

* fix config name

* rename layer

* fix doc string

* fix review

* remove filter data

* fix coors typo

* fix ut

* pred in segmentor

* fix get voxel seg

* resolve comments
parent be2029d1
model = dict(
type='MinkUNet',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='minkunet',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
backbone=dict(
type='MinkUNetBackbone',
in_channels=4,
base_channels=32,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
init_cfg=None),
decode_head=dict(
type='MinkUNetHead',
channels=96,
num_classes=19,
dropout_ratio=0,
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
train_cfg=dict(),
test_cfg=dict())
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
model = dict(
backbone=dict(
base_channels=16,
encoder_channels=[16, 32, 64, 128],
decoder_channels=[128, 64, 48, 48]),
decode_head=dict(channels=48))
# NOTE: Due to TorchSparse backend, the model performance is relatively
# dependent on random seeds, and if random seeds are not specified the
# model performance will be different (± 1.5 mIoU).
randomness = dict(seed=1588147245)
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
model = dict(
backbone=dict(
base_channels=20,
encoder_channels=[20, 40, 81, 163],
decoder_channels=[163, 81, 61, 61]),
decode_head=dict(channels=61))
_base_ = [
'../_base_/datasets/semantickitti.py', '../_base_/models/minkunet.py',
'../_base_/default_runtime.py'
]
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32',
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping'),
dict(
type='GlobalRotScaleTrans',
rot_range=[0., 6.28318531],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0],
),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
lr = 0.24
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))
param_scheduler = [
dict(
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=15,
by_epoch=True,
eta_min=1e-5,
convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
env_cfg = dict(cudnn_benchmark=True)
......@@ -5,6 +5,7 @@ from .cylinder3d import Asymm3DSpconv
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .mink_resnet import MinkResNet
from .minkunet_backbone import MinkUNetBackbone
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
......@@ -14,5 +15,6 @@ from .second import SECOND
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
'MinkUNetBackbone'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn
from mmdet3d.models.layers import (TorchSparseConvModule,
TorchSparseResidualBlock)
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig
if IS_TORCHSPARSE_AVAILABLE:
import torchsparse
from torchsparse.tensor import SparseTensor
else:
SparseTensor = None
@MODELS.register_module()
class MinkUNetBackbone(BaseModule):
r"""MinkUNet backbone with TorchSparse backend.
Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.
Args:
in_channels (int): Number of input voxel feature channels.
Defaults to 4.
base_channels (int): The input channels for first encoder layer.
Defaults to 32.
encoder_channels (List[int]): Convolutional channels of each encode
layer. Defaults to [32, 64, 128, 256].
decoder_channels (List[int]): Convolutional channels of each decode
layer. Defaults to [256, 128, 96, 96].
num_stages (int): Number of stages in encoder and decoder.
Defaults to 4.
init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
, optional): Initialization config dict.
"""
def __init__(self,
in_channels: int = 4,
base_channels: int = 32,
encoder_channels: List[int] = [32, 64, 128, 256],
decoder_channels: List[int] = [256, 128, 96, 96],
num_stages: int = 4,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg)
assert num_stages == len(encoder_channels) == len(decoder_channels)
self.num_stages = num_stages
self.conv_input = nn.Sequential(
TorchSparseConvModule(in_channels, base_channels, kernel_size=3),
TorchSparseConvModule(base_channels, base_channels, kernel_size=3))
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
encoder_channels.insert(0, base_channels)
decoder_channels.insert(0, encoder_channels[-1])
for i in range(num_stages):
self.encoder.append(
nn.Sequential(
TorchSparseConvModule(
encoder_channels[i],
encoder_channels[i],
kernel_size=2,
stride=2),
TorchSparseResidualBlock(
encoder_channels[i],
encoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
encoder_channels[i + 1],
encoder_channels[i + 1],
kernel_size=3)))
self.decoder.append(
nn.ModuleList([
TorchSparseConvModule(
decoder_channels[i],
decoder_channels[i + 1],
kernel_size=2,
stride=2,
transposed=True),
nn.Sequential(
TorchSparseResidualBlock(
decoder_channels[i + 1] + encoder_channels[-2 - i],
decoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
decoder_channels[i + 1],
decoder_channels[i + 1],
kernel_size=3))
]))
def forward(self, voxel_features: Tensor, coors: Tensor) -> SparseTensor:
"""Forward function.
Args:
voxel_features (Tensor): Voxel features in shape (N, C).
coors (Tensor): Coordinates in shape (N, 4),
the columns in the order of (x_idx, y_idx, z_idx, batch_idx).
Returns:
SparseTensor: Backbone features.
"""
x = torchsparse.SparseTensor(voxel_features, coors)
x = self.conv_input(x)
laterals = [x]
for encoder_layer in self.encoder:
x = encoder_layer(x)
laterals.append(x)
laterals = laterals[:-1][::-1]
decoder_outs = []
for i, decoder_layer in enumerate(self.decoder):
x = decoder_layer[0](x)
x = torchsparse.cat((x, laterals[i]))
x = decoder_layer[1](x)
decoder_outs.append(x)
return decoder_outs[-1]
......@@ -415,6 +415,33 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
coors.append(res_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'minkunet':
voxels, coors = [], []
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
res_coors = torch.round(res[:, :3] / voxel_size).int()
res_coors -= res_coors.min(0)[0]
res_coors_numpy = res_coors.cpu().numpy()
inds, voxel2point_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
if self.training:
if len(inds) > 80000:
inds = np.random.choice(inds, 80000, replace=False)
inds = torch.from_numpy(inds).cuda()
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.voxel2point_map = voxel2point_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')
......@@ -445,3 +472,53 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique().
Args:
x (np.ndarray): The voxel coordinates of points, Nx3.
Returns:
np.ndarray: Voxels coordinates hash.
"""
assert x.ndim == 2, x.shape
x = x - np.min(x, axis=0)
x = x.astype(np.uint64, copy=False)
xmax = np.max(x, axis=0).astype(np.uint64) + 1
h = np.zeros(x.shape[0], dtype=np.uint64)
for k in range(x.shape[1] - 1):
h += x[:, k]
h *= xmax[k + 1]
h += x[:, -1]
return h
def sparse_quantize(self,
coords: np.ndarray,
return_index: bool = False,
return_inverse: bool = False) -> List[np.ndarray]:
"""Sparse Quantization for voxel coordinates used in Minkunet.
Args:
coords (np.ndarray): The voxel coordinates of points, Nx3.
return_index (bool): Whether to return the indices of the
unique coords, shape (M,).
return_inverse (bool): Whether to return the indices of the
original coords shape (N,).
Returns:
List[np.ndarray] or None: Return index and inverse map if
return_index and return_inverse is True.
"""
_, indices, inverse_indices = np.unique(
self.ravel_hash(coords), return_index=True, return_inverse=True)
coords = coords[indices]
outputs = []
if return_index:
outputs += [indices]
if return_inverse:
outputs += [inverse_indices]
return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from .cylinder3d_head import Cylinder3DHead
from .dgcnn_head import DGCNNHead
from .minkunet_head import MinkUNetHead
from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head
__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead']
__all__ = [
'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead',
'MinkUNetHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from torch import Tensor
from torch import nn as nn
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from .decode_head import Base3DDecodeHead
if IS_TORCHSPARSE_AVAILABLE:
from torchsparse import SparseTensor
else:
SparseTensor = None
@MODELS.register_module()
class MinkUNetHead(Base3DDecodeHead):
r"""MinkUNet decoder head with TorchSparse backend.
Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.
Args:
channels (int): The input channel of conv_seg.
num_classes (int): Number of classes.
"""
def __init__(self, channels: int, num_classes: int, **kwargs) -> None:
super().__init__(channels, num_classes, **kwargs)
def build_conv_seg(self, channels: int, num_classes: int,
kernel_size: int) -> nn.Module:
"""Build Convolutional Segmentation Layers."""
return nn.Linear(channels, num_classes)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
"""Concat voxel-wise Groud Truth."""
gt_semantic_segs = [
data_sample.gt_pts_seg.voxel_semantic_mask
for data_sample in batch_data_samples
]
return torch.cat(gt_semantic_segs)
def predict(self, inputs: SparseTensor,
batch_data_samples: SampleList) -> List[Tensor]:
"""Forward function for testing.
Args:
inputs (SparseTensor): Features from backone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
data samples.
Returns:
List[Tensor]: The segmentation prediction mask of each batch.
"""
seg_logits = self.forward(inputs)
batch_idx = inputs.C[:, -1]
seg_logit_list = []
for i, data_sample in enumerate(batch_data_samples):
seg_logit = seg_logits[batch_idx == i]
seg_logit = seg_logit[data_sample.voxel2point_map]
seg_logit_list.append(seg_logit)
return seg_logit_list
def forward(self, x: SparseTensor) -> Tensor:
"""Forward function.
Args:
x (SparseTensor): Features from backbone.
Returns:
Tensor: Segmentation map of shape [N, C].
Note that output contains all points from each batch.
"""
output = self.cls_seg(x.F)
return output
......@@ -2,5 +2,6 @@
from .base import Base3DSegmentor
from .cylinder3d import Cylinder3D
from .encoder_decoder import EncoderDecoder3D
from .minkunet import MinkUNet
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D']
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet']
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList
from .encoder_decoder import EncoderDecoder3D
if IS_TORCHSPARSE_AVAILABLE:
from torchsparse import SparseTensor
else:
SparseTensor = None
@MODELS.register_module()
class MinkUNet(EncoderDecoder3D):
r"""MinkUNet is the implementation of `4D Spatio-Temporal ConvNets.
<https://arxiv.org/abs/1904.08755>`_ with TorchSparse backend.
Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.
Args:
kwargs (dict): Arguments are the same as those in
:class:`EncoderDecoder3D`.
"""
def __init__(self, **kwargs) -> None:
if not IS_TORCHSPARSE_AVAILABLE:
raise ImportError(
'Please follow `get_started.md` to install Torchsparse.`')
super().__init__(**kwargs)
def loss(self, inputs: dict, data_samples: SampleList):
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'voxels' keys.
- points (List[Tensor]): Point cloud of each sample.
- voxels (dict): Voxel feature and coords after voxelization.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(inputs)
losses = self.decode_head.loss(x, data_samples, self.train_cfg)
return losses
def predict(self, inputs: dict, data_samples: SampleList) -> SampleList:
"""Simple test with single scene.
Args:
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'voxels' keys.
- points (List[Tensor]): Point cloud of each sample.
- voxels (dict): Voxel feature and coords after voxelization.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
Returns:
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
segmentation.
"""
x = self.extract_feat(inputs)
seg_logits = self.decode_head.predict(x, data_samples)
seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits]
return self.postprocess_result(seg_preds, data_samples)
def _forward(self,
batch_inputs_dict: dict,
batch_data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'voxels' keys.
- points (List[Tensor]): Point cloud of each sample.
- voxels (dict): Voxel feature and coords after voxelization.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`. Defaults to None.
Returns:
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(batch_inputs_dict)
return self.decode_head.forward(x)
def extract_feat(self, batch_inputs_dict: dict) -> SparseTensor:
"""Extract features from voxels.
Args:
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'voxels' keys.
- points (List[Tensor]): Point cloud of each sample.
- voxels (dict): Voxel feature and coords after voxelization.
Returns:
SparseTensor: voxels with features.
"""
voxel_dict = batch_inputs_dict['voxels']
x = self.backbone(voxel_dict['voxels'], voxel_dict['coors'])
if self.with_neck:
x = self.neck(x)
return x
......@@ -84,6 +84,7 @@ def create_detector_inputs(seed=0,
gt_bboxes_dim=7,
with_pts_semantic_mask=False,
with_pts_instance_mask=False,
with_eval_ann_info=False,
bboxes_3d_type='lidar'):
setup_seed(seed)
assert bboxes_3d_type in ('lidar', 'depth', 'cam')
......@@ -145,5 +146,9 @@ def create_detector_inputs(seed=0,
if with_pts_semantic_mask:
pts_semantic_mask = torch.randint(0, num_classes, [num_points])
data_sample.gt_pts_seg['pts_semantic_mask'] = pts_semantic_mask
if with_eval_ann_info:
data_sample.eval_ann_info = dict()
else:
data_sample.eval_ann_info = None
return dict(inputs=inputs_dict, data_samples=[data_sample])
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmdet3d.registry import MODELS
def test_minkunet_backbone():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
try:
import torchsparse # noqa: F401
except ImportError:
pytest.skip('test requires Torchsparse installation')
coordinates, features = [], []
for i in range(2):
c = torch.randint(0, 10, (100, 3)).int()
c = F.pad(c, (0, 1), mode='constant', value=i)
coordinates.append(c)
f = torch.rand(100, 4)
features.append(f)
features = torch.cat(features, dim=0).cuda()
coordinates = torch.cat(coordinates, dim=0).cuda()
cfg = dict(type='MinkUNetBackbone')
self = MODELS.build(cfg).cuda()
self.init_weights()
y = self(features, coordinates)
assert y.F.shape == torch.Size([200, 96])
assert y.C.shape == torch.Size([200, 4])
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
import torch.nn.functional as F
from mmdet3d.models.decode_heads import MinkUNetHead
from mmdet3d.structures import Det3DDataSample, PointData
class TestMinkUNetHead(TestCase):
def test_minkunet_head_loss(self):
"""Tests PAConv head loss."""
try:
import torchsparse
except ImportError:
pytest.skip('test requires Torchsparse installation')
if torch.cuda.is_available():
minkunet_head = MinkUNetHead(channels=4, num_classes=19)
minkunet_head.cuda()
coordinates, features = [], []
for i in range(2):
c = torch.randint(0, 10, (100, 3)).int()
c = F.pad(c, (0, 1), mode='constant', value=i)
coordinates.append(c)
f = torch.rand(100, 4)
features.append(f)
features = torch.cat(features, dim=0).cuda()
coordinates = torch.cat(coordinates, dim=0).cuda()
x = torchsparse.SparseTensor(feats=features, coords=coordinates)
# Test forward
seg_logits = minkunet_head.forward(x)
self.assertEqual(seg_logits.shape, torch.Size([200, 19]))
# When truth is non-empty then losses
# should be nonzero for random inputs
voxel_semantic_mask = torch.randint(0, 19, (100, )).long().cuda()
gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)
datasample = Det3DDataSample()
datasample.gt_pts_seg = gt_pts_seg
gt_losses = minkunet_head.loss(x, [datasample, datasample], {})
gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()
self.assertGreater(gt_sem_seg_loss, 0,
'semantic seg loss should be positive')
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import pytest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from mmdet3d.testing import (create_detector_inputs, get_detector_cfg,
setup_seed)
class TestMinkUNet(unittest.TestCase):
def test_minkunet(self):
try:
import torchsparse # noqa
except ImportError:
pytest.skip('test requires Torchsparse installation')
import mmdet3d.models
assert hasattr(mmdet3d.models, 'MinkUNet')
DefaultScope.get_instance('test_minkunet', scope_name='mmdet3d')
setup_seed(0)
model_cfg = get_detector_cfg('_base_/models/minkunet.py')
model = MODELS.build(model_cfg)
num_gt_instance = 3
packed_inputs = create_detector_inputs(
num_gt_instance=num_gt_instance,
num_classes=19,
with_pts_semantic_mask=True)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('pts_semantic_mask', results[0].pred_pts_seg)
losses = model.forward(**data, mode='loss')
self.assertGreater(losses['loss_sem_seg'], 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment