Unverified Commit a1b974a5 authored by xizaoqu, committed by GitHub

[Feature] Add Cylinder3D head (#2291)

* add cylinder decode head

* update

* update

* add lovasz loss

* update

* update

* update

* update

* update

* update

* update

* update

* update

* cylinder3d_head

* update

* update

* update

* update
parent ae3c8f80
# Copyright (c) OpenMMLab. All rights reserved.
from .cylinder3d_head import Cylinder3DHead
from .dgcnn_head import DGCNNHead
from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head

__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from mmcv.ops import SparseConvTensor, SparseModule, SubMConv3d
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptMultiConfig
from mmdet3d.utils.typing_utils import ConfigType
from .decode_head import Base3DDecodeHead
@MODELS.register_module()
class Cylinder3DHead(Base3DDecodeHead):
"""Cylinder3D decoder head.
Decoder head used in `Cylinder3D <https://arxiv.org/abs/2011.10033>`_.
    Refer to the
    `official code <https://github.com/xinge008/Cylinder3D>`_.
Args:
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Defaults to 0.
conv_cfg (dict or :obj:`ConfigDict`): Config of conv layers.
Defaults to dict(type='Conv1d').
norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers.
Defaults to dict(type='BN1d').
act_cfg (dict or :obj:`ConfigDict`): Config of activation layers.
Defaults to dict(type='ReLU').
loss_ce (dict or :obj:`ConfigDict`): Config of CrossEntropy loss.
Defaults to dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0).
loss_lovasz (dict or :obj:`ConfigDict`): Config of Lovasz loss.
Defaults to dict(type='LovaszLoss', loss_weight=1.0).
conv_seg_kernel_size (int): The kernel size used in conv_seg.
Defaults to 3.
ignore_index (int): The label index to be ignored. When using masked
BCE loss, ignore_index should be set to None. Defaults to 0.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to None.
"""
def __init__(self,
channels: int,
num_classes: int,
dropout_ratio: float = 0,
conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg: ConfigType = dict(type='ReLU'),
loss_ce: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
loss_lovasz: ConfigType = dict(
type='LovaszLoss', loss_weight=1.0),
conv_seg_kernel_size: int = 3,
ignore_index: int = 0,
init_cfg: OptMultiConfig = None) -> None:
super(Cylinder3DHead, self).__init__(
channels=channels,
num_classes=num_classes,
dropout_ratio=dropout_ratio,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
conv_seg_kernel_size=conv_seg_kernel_size,
init_cfg=init_cfg)
self.loss_lovasz = MODELS.build(loss_lovasz)
self.loss_ce = MODELS.build(loss_ce)
self.ignore_index = ignore_index
def build_conv_seg(self, channels: int, num_classes: int,
kernel_size: int) -> SparseModule:
return SubMConv3d(
channels,
num_classes,
indice_key='logit',
kernel_size=kernel_size,
stride=1,
padding=1,
bias=True)
def forward(self, sparse_voxels: SparseConvTensor) -> SparseConvTensor:
"""Forward function."""
sparse_logits = self.cls_seg(sparse_voxels)
return sparse_logits
def loss_by_feat(self, seg_logit: SparseConvTensor,
batch_data_samples: SampleList) -> dict:
"""Compute semantic segmentation loss.
Args:
seg_logit (SparseConvTensor): Predicted per-voxel
segmentation logits of shape [num_voxels, num_classes]
stored in SparseConvTensor.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_pts_seg`.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
gt_semantic_segs = [
data_sample.gt_pts_seg.voxel_semantic_mask
for data_sample in batch_data_samples
]
seg_label = torch.cat(gt_semantic_segs)
seg_logit_feat = seg_logit.features
loss = dict()
loss['loss_ce'] = self.loss_ce(
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
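        # LovaszLoss expects [B, C, H, W] inputs; reshape the voxel-wise
        # logits from [N, C] into a pseudo image of shape [1, C, N, 1].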
seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :,
None] # pseudo BCHW
loss['loss_lovasz'] = self.loss_lovasz(
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
return loss
def predict(
self,
inputs: SparseConvTensor,
batch_inputs_dict: dict,
batch_data_samples: SampleList,
    ) -> List[torch.Tensor]:
"""Forward function for testing.
Args:
inputs (SparseConvTensor): Feature from backbone.
batch_inputs_dict (dict): Input sample dict which includes 'points'
and 'voxels' keys.
- points (List[Tensor]): Point cloud of each sample.
- voxels (dict): Dict of voxelized voxels and the corresponding
coordinates.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`. We use `point2voxel_map` in this function.
Returns:
List[torch.Tensor]: List of point-wise segmentation logits.
"""
seg_logits = self.forward(inputs).features
seg_pred_list = []
coors = batch_inputs_dict['voxels']['voxel_coors']
for batch_idx in range(len(batch_data_samples)):
seg_logits_sample = seg_logits[coors[:, 0] == batch_idx]
point2voxel_map = batch_data_samples[
batch_idx].gt_pts_seg.point2voxel_map.long()
point_seg_predicts = seg_logits_sample[point2voxel_map]
seg_pred_list.append(point_seg_predicts)
return seg_pred_list
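A minimal config sketch (illustrative, not part of this commit) showing how
the new head could be plugged into a segmentor config; the
SemanticKITTI-style numbers (128 channels, 20 classes) are assumptions:

decode_head = dict(
    type='Cylinder3DHead',
    channels=128,
    num_classes=20,
    loss_ce=dict(
        type='mmdet.CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=None,
        loss_weight=1.0),
    # reduction='none' is required by LovaszLoss when per_sample is False
    loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'))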
@@ -51,6 +51,8 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
     loss_decode (dict or :obj:`ConfigDict`): Config of decode loss.
         Defaults to dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False,
         class_weight=None, loss_weight=1.0).
+    conv_seg_kernel_size (int): The kernel size used in conv_seg.
+        Defaults to 1.
     ignore_index (int): The label index to be ignored. When using masked
         BCE loss, ignore_index should be set to None. Defaults to 255.
     init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
@@ -69,6 +71,7 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
                  use_sigmoid=False,
                  class_weight=None,
                  loss_weight=1.0),
+                 conv_seg_kernel_size: int = 1,
                  ignore_index: int = 255,
                  init_cfg: OptMultiConfig = None) -> None:
         super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg)
@@ -81,7 +84,10 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
         self.loss_decode = MODELS.build(loss_decode)
         self.ignore_index = ignore_index
 
-        self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
+        self.conv_seg = self.build_conv_seg(
+            channels=channels,
+            num_classes=num_classes,
+            kernel_size=conv_seg_kernel_size)
         if dropout_ratio > 0:
             self.dropout = nn.Dropout(dropout_ratio)
         else:
@@ -97,6 +103,11 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
         """Placeholder of forward function."""
         pass
 
+    def build_conv_seg(self, channels: int, num_classes: int,
+                       kernel_size: int) -> nn.Module:
+        """Build Convolutional Segmentation Layers."""
+        return nn.Conv1d(channels, num_classes, kernel_size=kernel_size)
+
     def cls_seg(self, feat: Tensor) -> Tensor:
         """Classify each points."""
         if self.dropout is not None:
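The refactor above turns the previously hard-coded nn.Conv1d classifier into
an overridable build_conv_seg hook. A hedged sketch of a hypothetical
subclass swapping in a different classifier (the class name and choice of
layer are illustrative only):

import torch.nn as nn

from mmdet3d.models.decode_heads.decode_head import Base3DDecodeHead


class LinearSegHead(Base3DDecodeHead):
    """Hypothetical head that classifies with a plain linear layer."""

    def build_conv_seg(self, channels: int, num_classes: int,
                       kernel_size: int) -> nn.Module:
        # kernel_size is ignored; nn.Linear maps [N, channels] features
        # directly to [N, num_classes] logits.
        return nn.Linear(channels, num_classes)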
@@ -3,6 +3,7 @@ from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
 from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss
 from .chamfer_distance import ChamferDistance, chamfer_distance
+from .lovasz_loss import LovaszLoss
 from .multibin_loss import MultiBinLoss
 from .paconv_regularization_loss import PAConvRegularizationLoss
 from .rotated_iou_loss import RotatedIoU3DLoss, rotated_iou_3d_loss
@@ -12,5 +13,5 @@ __all__ = [
     'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
     'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss',
     'PAConvRegularizationLoss', 'UncertainL1Loss', 'UncertainSmoothL1Loss',
-    'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss'
+    'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss', 'LovaszLoss'
 ]
# Copyright (c) OpenMMLab. All rights reserved.
"""Directly borrowed from mmsegmentation.
Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import weight_reduce_loss
from mmengine.utils import is_list_of
from mmdet3d.registry import MODELS
def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
"""Computes gradient of the Lovasz extension w.r.t sorted errors.
See Alg. 1 in paper.
`The Lovasz-Softmax loss. <https://arxiv.org/abs/1705.08790>`_.
Args:
gt_sorted (torch.Tensor): Sorted ground truth.
Return:
torch.Tensor: Gradient of the Lovasz extension.
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
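# Worked example (illustrative, not part of the original file): for
# gt_sorted = tensor([1, 1, 0]), gts = 2, intersection = [1, 0, 0] and
# union = [2, 2, 3], so jaccard = [0.5, 1.0, 1.0]; after the differencing
# step, the returned gradient is [0.5, 0.5, 0.0].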
def flatten_binary_logits(
logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Flatten predictions and labels in the batch (binary case). Remove
tensors whose labels equal to 'ignore_index'.
Args:
        logits (torch.Tensor): Predictions to be modified.
labels (torch.Tensor): Labels to be modified.
ignore_index (int, optional): The label index to be ignored.
Defaults to None.
Return:
tuple(torch.Tensor, torch.Tensor): Modified predictions and labels.
"""
logits = logits.view(-1)
labels = labels.view(-1)
if ignore_index is None:
return logits, labels
valid = (labels != ignore_index)
vlogits = logits[valid]
vlabels = labels[valid]
return vlogits, vlabels
def flatten_probs(
probs: torch.Tensor,
labels: torch.Tensor,
ignore_index: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Flatten predictions and labels in the batch. Remove tensors whose labels
equal to 'ignore_index'.
Args:
probs (torch.Tensor): Predictions to be modified.
labels (torch.Tensor): Labels to be modified.
ignore_index (int, optional): The label index to be ignored.
Defaults to None.
Return:
tuple(torch.Tensor, torch.Tensor): Modified predictions and labels.
"""
if probs.dim() != 2: # for input with P*C
if probs.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probs.size()
probs = probs.view(B, 1, H, W)
B, C, H, W = probs.size()
        # [B, C, H, W] -> [B*H*W, C], i.e. [P, C]
        probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)
labels = labels.view(-1)
if ignore_index is None:
return probs, labels
valid = (labels != ignore_index)
vprobs = probs[valid.nonzero().squeeze()]
vlabels = labels[valid]
return vprobs, vlabels
def lovasz_hinge_flat(logits: torch.Tensor,
labels: torch.Tensor) -> torch.Tensor:
"""Binary Lovasz hinge loss.
Args:
logits (torch.Tensor): Logits at each prediction
(between -infty and +infty) with shape [P].
labels (torch.Tensor): Binary ground truth labels (0 or 1)
with shape [P].
Returns:
torch.Tensor: The calculated loss.
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * signs)
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), grad)
return loss
def lovasz_hinge(logits: torch.Tensor,
labels: torch.Tensor,
classes: Optional[Union[str, List[int]]] = None,
per_sample: bool = False,
class_weight: Optional[List[float]] = None,
reduction: str = 'mean',
avg_factor: Optional[int] = None,
ignore_index: int = 255) -> torch.Tensor:
"""Binary Lovasz hinge loss.
Args:
logits (torch.Tensor): Logits at each pixel
(between -infty and +infty) with shape [B, H, W].
labels (torch.Tensor): Binary ground truth masks (0 or 1)
with shape [B, H, W].
classes (Union[str, list[int]], optional): Placeholder, to be
consistent with other loss. Defaults to None.
per_sample (bool): If per_sample is True, compute the loss per
sample instead of per batch. Defaults to False.
class_weight (list[float], optional): Placeholder, to be consistent
with other loss. Defaults to None.
reduction (str): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_sample is True. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. This parameter only works when per_sample is True.
Defaults to None.
ignore_index (Union[int, None]): The label index to be ignored.
Defaults to 255.
Returns:
torch.Tensor: The calculated loss.
"""
if per_sample:
loss = [
lovasz_hinge_flat(*flatten_binary_logits(
logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
for logit, label in zip(logits, labels)
]
loss = weight_reduce_loss(
torch.stack(loss), None, reduction, avg_factor)
else:
loss = lovasz_hinge_flat(
*flatten_binary_logits(logits, labels, ignore_index))
return loss
def lovasz_softmax_flat(
probs: torch.Tensor,
labels: torch.Tensor,
classes: Union[str, List[int]] = 'present',
class_weight: Optional[List[float]] = None) -> torch.Tensor:
"""Multi-class Lovasz-Softmax loss.
Args:
probs (torch.Tensor): Class probabilities at each prediction
(between 0 and 1) with shape [P, C]
labels (torch.Tensor): Ground truth labels (between 0 and C - 1)
with shape [P].
classes (Union[str, list[int]]): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Defaults to 'present'.
class_weight (list[float], optional): The weight for each class.
Defaults to None.
Returns:
torch.Tensor: The calculated loss.
"""
if probs.numel() == 0:
# only void pixels, the gradients should be 0
return probs * 0.
C = probs.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
if (classes == 'present' and fg.sum() == 0):
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probs[:, 0]
else:
class_pred = probs[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
if class_weight is not None:
loss *= class_weight[c]
losses.append(loss)
return torch.stack(losses).mean()
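# Note: with classes='present', classes that never appear in `labels` are
# skipped, so a batch is not asked to optimize the IoU of classes it does
# not contain (with classes='all' they would still be penalized).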
def lovasz_softmax(probs: torch.Tensor,
labels: torch.Tensor,
classes: Union[str, List[int]] = 'present',
per_sample: bool = False,
                   class_weight: Optional[List[float]] = None,
reduction: str = 'mean',
avg_factor: Optional[int] = None,
ignore_index: int = 255) -> torch.Tensor:
"""Multi-class Lovasz-Softmax loss.
Args:
probs (torch.Tensor): Class probabilities at each
prediction (between 0 and 1) with shape [B, C, H, W].
labels (torch.Tensor): Ground truth labels (between 0 and
C - 1) with shape [B, H, W].
classes (Union[str, list[int]]): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Defaults to 'present'.
per_sample (bool): If per_sample is True, compute the loss per
sample instead of per batch. Defaults to False.
class_weight (list[float], optional): The weight for each class.
Defaults to None.
reduction (str): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_sample is True. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. This parameter only works when per_sample is True.
Defaults to None.
ignore_index (Union[int, None]): The label index to be ignored.
Defaults to 255.
Returns:
torch.Tensor: The calculated loss.
"""
if per_sample:
loss = [
lovasz_softmax_flat(
*flatten_probs(
prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
classes=classes,
class_weight=class_weight)
for prob, label in zip(probs, labels)
]
loss = weight_reduce_loss(
torch.stack(loss), None, reduction, avg_factor)
else:
loss = lovasz_softmax_flat(
*flatten_probs(probs, labels, ignore_index),
classes=classes,
class_weight=class_weight)
return loss
@MODELS.register_module()
class LovaszLoss(nn.Module):
"""LovaszLoss.
This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
for the optimization of the intersection-over-union measure in neural
networks <https://arxiv.org/abs/1705.08790>`_.
Args:
loss_type (str): Binary or multi-class loss.
Defaults to 'multi_class'. Options are "binary" and "multi_class".
classes (Union[str, list[int]]): Classes chosen to calculate loss.
'all' for all classes, 'present' for classes present in labels, or
a list of classes to average. Defaults to 'present'.
per_sample (bool): If per_sample is True, compute the loss per
sample instead of per batch. Defaults to False.
reduction (str): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_sample is True. Defaults to 'mean'.
        class_weight (list[float], optional): Weight of each class.
Defaults to None.
loss_weight (float): Weight of the loss. Defaults to 1.0.
"""
def __init__(self,
loss_type: str = 'multi_class',
classes: Union[str, List[int]] = 'present',
per_sample: bool = False,
reduction: str = 'mean',
class_weight: Optional[List[float]] = None,
loss_weight: float = 1.0):
super().__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'."
if loss_type == 'binary':
self.cls_criterion = lovasz_hinge
else:
self.cls_criterion = lovasz_softmax
assert classes in ('all', 'present') or is_list_of(classes, int)
if not per_sample:
assert reduction == 'none', "reduction should be 'none' when \
per_sample is False."
self.classes = classes
self.per_sample = per_sample
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
def forward(self,
cls_score: torch.Tensor,
label: torch.Tensor,
avg_factor: int = None,
reduction_override: str = None,
**kwargs) -> torch.Tensor:
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# if multi-class loss, transform logits to probs
if self.cls_criterion == lovasz_softmax:
cls_score = F.softmax(cls_score, dim=1)
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
self.classes,
self.per_sample,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_cls
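A minimal standalone sketch (shapes and values assumed, not from this
commit) of calling the registered loss directly on point-wise logits:

import torch

from mmdet3d.models.losses import LovaszLoss

# per_sample defaults to False, which requires reduction='none'
loss_fn = LovaszLoss(loss_type='multi_class', reduction='none')
logits = torch.randn(8, 5)           # [num_points, num_classes]
labels = torch.randint(0, 5, (8, ))  # [num_points]
# forward() applies softmax internally for the multi-class case
loss = loss_fn(logits, labels)  # scalar tensor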
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from mmcv.ops import SparseConvTensor
from mmdet3d.models.decode_heads import Cylinder3DHead
from mmdet3d.structures import Det3DDataSample, PointData
class TestCylinder3DHead(TestCase):
def test_cylinder3d_head_loss(self):
"""Tests Cylinder3D head loss."""
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
cylinder3d_head = Cylinder3DHead(
channels=128,
num_classes=20,
loss_ce=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
loss_lovasz=dict(
type='LovaszLoss', loss_weight=1.0, reduction='none'),
).cuda()
voxel_feats = torch.rand(50, 128).cuda()
coorx = torch.randint(0, 480, (50, 1)).int().cuda()
coory = torch.randint(0, 360, (50, 1)).int().cuda()
coorz = torch.randint(0, 32, (50, 1)).int().cuda()
coorbatch0 = torch.zeros(50, 1).int().cuda()
coors = torch.cat([coorbatch0, coorx, coory, coorz], dim=1)
grid_size = [480, 360, 32]
batch_size = 1
sparse_voxels = SparseConvTensor(voxel_feats, coors, grid_size,
batch_size)
# Test forward
seg_logits = cylinder3d_head.forward(sparse_voxels)
self.assertEqual(seg_logits.features.shape, torch.Size([50, 20]))
# When truth is non-empty then losses
# should be nonzero for random inputs
voxel_semantic_mask = torch.randint(0, 20, (50, )).long().cuda()
gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)
datasample = Det3DDataSample()
datasample.gt_pts_seg = gt_pts_seg
losses = cylinder3d_head.loss_by_feat(seg_logits, [datasample])
loss_ce = losses['loss_ce'].item()
loss_lovasz = losses['loss_lovasz'].item()
self.assertGreater(loss_ce, 0, 'ce loss should be positive')
self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive')
batch_inputs_dict = dict(voxels=dict(voxel_coors=coors))
datasample.gt_pts_seg.point2voxel_map = torch.randint(
0, 50, (100, )).int().cuda()
point_logits = cylinder3d_head.predict(sparse_voxels,
batch_inputs_dict, [datasample])
assert point_logits[0].shape == torch.Size([100, 20])
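# The test above requires CUDA. Assuming the usual mmdet3d test layout
# (the path below is an assumption), it can be run with:
#   pytest tests/test_models/test_decode_heads/test_cylinder3d_head.py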