Unverified Commit c84285da authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feautre]: Modify primitive head to support SUN-RGBD dataset (#136)

* modify primitive head

* modify primitive head

* rename upper and lower

* modify doc string
parent f05828b0
...@@ -354,10 +354,25 @@ class PrimitiveHead(nn.Module): ...@@ -354,10 +354,25 @@ class PrimitiveHead(nn.Module):
# Semantic information of primitive center # Semantic information of primitive center
point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1]) point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1])
# Generate pts_semantic_mask and pts_instance_mask when they are None
if pts_semantic_mask is None or pts_instance_mask is None:
points2box_mask = gt_bboxes_3d.points_in_boxes(points)
assignment = points2box_mask.argmax(1)
background_mask = points2box_mask.max(1)[0] == 0
if pts_semantic_mask is None:
pts_semantic_mask = gt_labels_3d[assignment]
pts_semantic_mask[background_mask] = self.num_classes
if pts_instance_mask is None:
pts_instance_mask = assignment
pts_instance_mask[background_mask] = gt_labels_3d.shape[0]
instance_flag = torch.nonzero( instance_flag = torch.nonzero(
pts_semantic_mask != self.num_classes).squeeze(1) pts_semantic_mask != self.num_classes).squeeze(1)
instance_labels = pts_instance_mask[instance_flag].unique() instance_labels = pts_instance_mask[instance_flag].unique()
with_yaw = gt_bboxes_3d.with_yaw
for i, i_instance in enumerate(instance_labels): for i, i_instance in enumerate(instance_labels):
indices = instance_flag[pts_instance_mask[instance_flag] == indices = instance_flag[pts_instance_mask[instance_flag] ==
i_instance] i_instance]
...@@ -366,8 +381,6 @@ class PrimitiveHead(nn.Module): ...@@ -366,8 +381,6 @@ class PrimitiveHead(nn.Module):
# Bbox Corners # Bbox Corners
cur_corners = gt_bboxes_3d.corners[i] cur_corners = gt_bboxes_3d.corners[i]
xmin, ymin, zmin = cur_corners.min(0)[0]
xmax, ymax, zmax = cur_corners.max(0)[0]
plane_lower_temp = points.new_tensor( plane_lower_temp = points.new_tensor(
[0, 0, 1, -cur_corners[7, -1]]) [0, 0, 1, -cur_corners[7, -1]])
...@@ -392,10 +405,10 @@ class PrimitiveHead(nn.Module): ...@@ -392,10 +405,10 @@ class PrimitiveHead(nn.Module):
point2plane_dist, selected = self.match_point2plane( point2plane_dist, selected = self.match_point2plane(
plane_lower, coords) plane_lower, coords)
# Get lower four lines # Get bottom four lines
if self.primitive_mode == 'line': if self.primitive_mode == 'line':
point2line_matching = self.match_point2line( point2line_matching = self.match_point2line(
coords[selected], xmin, xmax, ymin, ymax) coords[selected], cur_corners, with_yaw, mode='bottom')
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_line_targets(point_mask, self._assign_primitive_line_targets(point_mask,
...@@ -406,7 +419,9 @@ class PrimitiveHead(nn.Module): ...@@ -406,7 +419,9 @@ class PrimitiveHead(nn.Module):
cur_cls_label, cur_cls_label,
point2line_matching, point2line_matching,
cur_corners, cur_corners,
[1, 1, 0, 0]) [1, 1, 0, 0],
with_yaw,
mode='bottom')
# Set the surface labels here # Set the surface labels here
if self.primitive_mode == 'z' and \ if self.primitive_mode == 'z' and \
...@@ -421,16 +436,18 @@ class PrimitiveHead(nn.Module): ...@@ -421,16 +436,18 @@ class PrimitiveHead(nn.Module):
coords[selected], coords[selected],
indices[selected], indices[selected],
cur_cls_label, cur_cls_label,
cur_corners) cur_corners,
with_yaw,
mode='bottom')
# Get the boundary points here # Get the boundary points here
point2plane_dist, selected = self.match_point2plane( point2plane_dist, selected = self.match_point2plane(
plane_upper, coords) plane_upper, coords)
# Get upper four lines # Get top four lines
if self.primitive_mode == 'line': if self.primitive_mode == 'line':
point2line_matching = self.match_point2line( point2line_matching = self.match_point2line(
coords[selected], xmin, xmax, ymin, ymax) coords[selected], cur_corners, with_yaw, mode='top')
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_line_targets(point_mask, self._assign_primitive_line_targets(point_mask,
...@@ -441,7 +458,9 @@ class PrimitiveHead(nn.Module): ...@@ -441,7 +458,9 @@ class PrimitiveHead(nn.Module):
cur_cls_label, cur_cls_label,
point2line_matching, point2line_matching,
cur_corners, cur_corners,
[1, 1, 0, 0]) [1, 1, 0, 0],
with_yaw,
mode='top')
if self.primitive_mode == 'z' and \ if self.primitive_mode == 'z' and \
selected.sum() > self.train_cfg['num_point'] and \ selected.sum() > self.train_cfg['num_point'] and \
...@@ -455,7 +474,9 @@ class PrimitiveHead(nn.Module): ...@@ -455,7 +474,9 @@ class PrimitiveHead(nn.Module):
coords[selected], coords[selected],
indices[selected], indices[selected],
cur_cls_label, cur_cls_label,
cur_corners) cur_corners,
with_yaw,
mode='top')
# Get left two lines # Get left two lines
plane_left_temp = self._get_plane_fomulation( plane_left_temp = self._get_plane_fomulation(
...@@ -480,20 +501,16 @@ class PrimitiveHead(nn.Module): ...@@ -480,20 +501,16 @@ class PrimitiveHead(nn.Module):
point2plane_dist, selected = self.match_point2plane( point2plane_dist, selected = self.match_point2plane(
plane_left, coords) plane_left, coords)
# Get upper four lines # Get left four lines
if self.primitive_mode == 'line': if self.primitive_mode == 'line':
_, _, line_sel1, line_sel2 = self.match_point2line( point2line_matching = self.match_point2line(
coords[selected], xmin, xmax, ymin, ymax) coords[selected], cur_corners, with_yaw, mode='left')
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_line_targets(point_mask, self._assign_primitive_line_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], point2line_matching[2:], cur_corners, [2, 2],
indices[selected], with_yaw, mode='left')
cur_cls_label,
[line_sel1, line_sel2],
cur_corners,
[2, 2])
if self.primitive_mode == 'xy' and \ if self.primitive_mode == 'xy' and \
selected.sum() > self.train_cfg['num_point'] and \ selected.sum() > self.train_cfg['num_point'] and \
...@@ -501,32 +518,26 @@ class PrimitiveHead(nn.Module): ...@@ -501,32 +518,26 @@ class PrimitiveHead(nn.Module):
self.train_cfg['var_thresh']: self.train_cfg['var_thresh']:
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_surface_targets(point_mask, self._assign_primitive_surface_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], cur_corners, with_yaw, mode='left')
indices[selected],
cur_cls_label,
cur_corners)
# Get the boundary points here # Get the boundary points here
point2plane_dist, selected = self.match_point2plane( point2plane_dist, selected = self.match_point2plane(
plane_right, coords) plane_right, coords)
# Get right four lines
if self.primitive_mode == 'line': if self.primitive_mode == 'line':
_, _, line_sel1, line_sel2 = self.match_point2line( point2line_matching = self.match_point2line(
coords[selected], xmin, xmax, ymin, ymax) coords[selected], cur_corners, with_yaw, mode='right')
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_line_targets(point_mask, self._assign_primitive_line_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], point2line_matching[2:], cur_corners, [2, 2],
indices[selected], with_yaw, mode='right')
cur_cls_label,
[line_sel1, line_sel2],
cur_corners,
[2, 2])
if self.primitive_mode == 'xy' and \ if self.primitive_mode == 'xy' and \
selected.sum() > self.train_cfg['num_point'] and \ selected.sum() > self.train_cfg['num_point'] and \
...@@ -534,13 +545,10 @@ class PrimitiveHead(nn.Module): ...@@ -534,13 +545,10 @@ class PrimitiveHead(nn.Module):
self.train_cfg['var_thresh']: self.train_cfg['var_thresh']:
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_surface_targets(point_mask, self._assign_primitive_surface_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], cur_corners, with_yaw, mode='right')
indices[selected],
cur_cls_label,
cur_corners)
plane_front_temp = self._get_plane_fomulation( plane_front_temp = self._get_plane_fomulation(
cur_corners[0] - cur_corners[4], cur_corners[0] - cur_corners[4],
...@@ -570,13 +578,10 @@ class PrimitiveHead(nn.Module): ...@@ -570,13 +578,10 @@ class PrimitiveHead(nn.Module):
self.train_cfg['var_thresh']: self.train_cfg['var_thresh']:
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_surface_targets(point_mask, self._assign_primitive_surface_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], cur_corners, with_yaw, mode='front')
indices[selected],
cur_cls_label,
cur_corners)
# Get the boundary points here # Get the boundary points here
point2plane_dist, selected = self.match_point2plane( point2plane_dist, selected = self.match_point2plane(
...@@ -588,13 +593,10 @@ class PrimitiveHead(nn.Module): ...@@ -588,13 +593,10 @@ class PrimitiveHead(nn.Module):
self.train_cfg['var_thresh']: self.train_cfg['var_thresh']:
point_mask, point_offset, point_sem = \ point_mask, point_offset, point_sem = \
self._assign_primitive_surface_targets(point_mask, self._assign_primitive_surface_targets(
point_offset, point_mask, point_offset, point_sem,
point_sem, coords[selected], indices[selected], cur_cls_label,
coords[selected], cur_corners, with_yaw, mode='back')
indices[selected],
cur_cls_label,
cur_corners)
return (point_mask, point_sem, point_offset) return (point_mask, point_sem, point_offset)
...@@ -652,24 +654,65 @@ class PrimitiveHead(nn.Module): ...@@ -652,24 +654,65 @@ class PrimitiveHead(nn.Module):
return (points[:, 2] + return (points[:, 2] +
plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh'] plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh']
def match_point2line(self, points, xmin, xmax, ymin, ymax): def point2line_dist(self, points, pts_a, pts_b):
"""Calculate the distance from point to line.
Args:
points (torch.Tensor): Points of input.
pts_a (torch.Tensor): Point on the specific line.
pts_b (torch.Tensor): Point on the specific line.
Returns:
torch.Tensor: Distance between each point to line.
"""
line_a2b = pts_b - pts_a
line_a2pts = points - pts_a
length = (line_a2pts * line_a2b.view(1, 3)).sum(1) / \
line_a2b.norm()
dist = (line_a2pts.norm(dim=1)**2 - length**2).sqrt()
return dist
def match_point2line(self, points, corners, with_yaw, mode='bottom'):
"""Match points to corresponding line. """Match points to corresponding line.
Args: Args:
points (torch.Tensor): Points of input. points (torch.Tensor): Points of input.
xmin (float): Min of X-axis. corners (torch.Tensor): Eight corners of a bounding box.
xmax (float): Max of X-axis. with_yaw (Bool): Whether the boundind box is with rotation.
ymin (float): Min of Y-axis. mode (str, optional): Specify which line should be matched,
ymax (float): Max of Y-axis. available mode are ('bottom', 'top', 'left', 'right').
Defaults to 'bottom'.
Returns: Returns:
Tuple: Flag of matching correspondence. Tuple: Flag of matching correspondence.
""" """
sel1 = torch.abs(points[:, 0] - xmin) < self.train_cfg['line_thresh'] if with_yaw:
sel2 = torch.abs(points[:, 0] - xmax) < self.train_cfg['line_thresh'] corners_pair = {
sel3 = torch.abs(points[:, 1] - ymin) < self.train_cfg['line_thresh'] 'bottom': [[0, 3], [4, 7], [0, 4], [3, 7]],
sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] 'top': [[1, 2], [5, 6], [1, 5], [2, 6]],
return sel1, sel2, sel3, sel4 'left': [[0, 1], [3, 2], [0, 1], [3, 2]],
'right': [[4, 5], [7, 6], [4, 5], [7, 6]]
}
selected_list = []
for pair_index in corners_pair[mode]:
selected = self.point2line_dist(
points, corners[pair_index[0]], corners[pair_index[1]]) \
< self.train_cfg['line_thresh']
selected_list.append(selected)
else:
xmin, ymin, _ = corners.min(0)[0]
xmax, ymax, _ = corners.max(0)[0]
sel1 = torch.abs(points[:, 0] -
xmin) < self.train_cfg['line_thresh']
sel2 = torch.abs(points[:, 0] -
xmax) < self.train_cfg['line_thresh']
sel3 = torch.abs(points[:, 1] -
ymin) < self.train_cfg['line_thresh']
sel4 = torch.abs(points[:, 1] -
ymax) < self.train_cfg['line_thresh']
selected_list = [sel1, sel2, sel3, sel4]
return selected_list
def match_point2plane(self, plane, points): def match_point2plane(self, plane, points):
"""Match points to plane. """Match points to plane.
...@@ -757,10 +800,18 @@ class PrimitiveHead(nn.Module): ...@@ -757,10 +800,18 @@ class PrimitiveHead(nn.Module):
center = center + offset * selected.unsqueeze(-1) center = center + offset * selected.unsqueeze(-1)
return center, pred_indices return center, pred_indices
def _assign_primitive_line_targets(self, point_mask, point_offset, def _assign_primitive_line_targets(self,
point_sem, coords, indices, cls_label, point_mask,
point2line_matching, corners, point_offset,
center_axises): point_sem,
coords,
indices,
cls_label,
point2line_matching,
corners,
center_axises,
with_yaw,
mode='bottom'):
"""Generate targets of line primitive. """Generate targets of line primitive.
Args: Args:
...@@ -778,16 +829,35 @@ class PrimitiveHead(nn.Module): ...@@ -778,16 +829,35 @@ class PrimitiveHead(nn.Module):
corners (torch.Tensor): Corners of the ground truth bounding box. corners (torch.Tensor): Corners of the ground truth bounding box.
center_axises (list[int]): Indicate in which axis the line center center_axises (list[int]): Indicate in which axis the line center
should be refined. should be refined.
with_yaw (Bool): Whether the boundind box is with rotation.
mode (str, optional): Specify which line should be matched,
available mode are ('bottom', 'top', 'left', 'right').
Defaults to 'bottom'.
Returns: Returns:
Tuple: Targets of the line primitive. Tuple: Targets of the line primitive.
""" """
for line_select, center_axis in zip(point2line_matching, corners_pair = {
center_axises): 'bottom': [[0, 3], [4, 7], [0, 4], [3, 7]],
'top': [[1, 2], [5, 6], [1, 5], [2, 6]],
'left': [[0, 1], [3, 2]],
'right': [[4, 5], [7, 6]]
}
corners_pair = corners_pair[mode]
assert len(corners_pair) == len(point2line_matching) == len(
center_axises)
for line_select, center_axis, pair_index in zip(
point2line_matching, center_axises, corners_pair):
if line_select.sum() > self.train_cfg['num_point_line']: if line_select.sum() > self.train_cfg['num_point_line']:
point_mask[indices[line_select]] = 1.0 point_mask[indices[line_select]] = 1.0
line_center = coords[line_select].mean(dim=0)
line_center[center_axis] = corners[:, center_axis].mean() if with_yaw:
line_center = (corners[pair_index[0]] +
corners[pair_index[1]]) / 2
else:
line_center = coords[line_select].mean(dim=0)
line_center[center_axis] = corners[:, center_axis].mean()
point_offset[indices[line_select]] = \ point_offset[indices[line_select]] = \
line_center - coords[line_select] line_center - coords[line_select]
point_sem[indices[line_select]] = \ point_sem[indices[line_select]] = \
...@@ -795,9 +865,16 @@ class PrimitiveHead(nn.Module): ...@@ -795,9 +865,16 @@ class PrimitiveHead(nn.Module):
line_center[2], cls_label]) line_center[2], cls_label])
return point_mask, point_offset, point_sem return point_mask, point_offset, point_sem
def _assign_primitive_surface_targets(self, point_mask, point_offset, def _assign_primitive_surface_targets(self,
point_sem, coords, indices, point_mask,
cls_label, corners): point_offset,
point_sem,
coords,
indices,
cls_label,
corners,
with_yaw,
mode='bottom'):
"""Generate targets for primitive z and primitive xy. """Generate targets for primitive z and primitive xy.
Args: Args:
...@@ -811,29 +888,64 @@ class PrimitiveHead(nn.Module): ...@@ -811,29 +888,64 @@ class PrimitiveHead(nn.Module):
indices (torch.Tensor): Indices of the selected points. indices (torch.Tensor): Indices of the selected points.
cls_label (int): Class label of the ground truth bounding box. cls_label (int): Class label of the ground truth bounding box.
corners (torch.Tensor): Corners of the ground truth bounding box. corners (torch.Tensor): Corners of the ground truth bounding box.
with_yaw (Bool): Whether the boundind box is with rotation.
mode (str, optional): Specify which line should be matched,
available mode are ('bottom', 'top', 'left', 'right',
'front', 'back').
Defaults to 'bottom'.
Returns: Returns:
Tuple: Targets of the center primitive. Tuple: Targets of the center primitive.
""" """
point_mask[indices] = 1.0 point_mask[indices] = 1.0
corners_pair = {
'bottom': [0, 7],
'top': [1, 6],
'left': [0, 1],
'right': [4, 5],
'front': [0, 1],
'back': [3, 2]
}
pair_index = corners_pair[mode]
if self.primitive_mode == 'z': if self.primitive_mode == 'z':
center = point_mask.new_tensor([ if with_yaw:
corners[:, 0].mean(), corners[:, 1].mean(), coords[:, center = (corners[pair_index[0]] +
2].mean() corners[pair_index[1]]) / 2.0
]) center[2] = coords[:, 2].mean()
point_sem[indices] = point_sem.new_tensor([ point_sem[indices] = point_sem.new_tensor([
center[0], center[1], center[2], center[0], center[1],
corners[:, 0].max() - corners[:, 0].min(), center[2], (corners[4] - corners[0]).norm(),
corners[:, 1].max() - corners[:, 1].min(), cls_label (corners[3] - corners[0]).norm(), cls_label
]) ])
else:
center = point_mask.new_tensor([
corners[:, 0].mean(), corners[:, 1].mean(),
coords[:, 2].mean()
])
point_sem[indices] = point_sem.new_tensor([
center[0], center[1], center[2],
corners[:, 0].max() - corners[:, 0].min(),
corners[:, 1].max() - corners[:, 1].min(), cls_label
])
elif self.primitive_mode == 'xy': elif self.primitive_mode == 'xy':
center = point_mask.new_tensor([ if with_yaw:
coords[:, 0].mean(), coords[:, 1].mean(), corners[:, 2].mean() center = coords.mean(0)
]) center[2] = (corners[pair_index[0], 2] +
point_sem[indices] = point_sem.new_tensor([ corners[pair_index[1], 2]) / 2.0
center[0], center[1], center[2], point_sem[indices] = point_sem.new_tensor([
corners[:, 2].max() - corners[:, 2].min(), cls_label center[0], center[1], center[2],
]) corners[pair_index[1], 2] - corners[pair_index[0], 2],
cls_label
])
else:
center = point_mask.new_tensor([
coords[:, 0].mean(), coords[:, 1].mean(),
corners[:, 2].mean()
])
point_sem[indices] = point_sem.new_tensor([
center[0], center[1], center[2],
corners[:, 2].max() - corners[:, 2].min(), cls_label
])
point_offset[indices] = center - coords point_offset[indices] = center - coords
return point_mask, point_offset, point_sem return point_mask, point_offset, point_sem
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment