Unverified Commit bc97d76a authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Enhancement]: Add assertions of data dimensions in points_in_boxes.py (#357)

* Avoid 4 dims points

* add assert prinf

* modify print info
parent 1feb8917
......@@ -15,8 +15,15 @@ def points_in_boxes_gpu(points, boxes):
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
assert boxes.shape[0] == points.shape[0]
assert boxes.shape[2] == 7
assert boxes.shape[0] == points.shape[0], \
f'Points and boxes should have the same batch size, ' \
f'got {boxes.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
box_idxs_of_pts = points.new_zeros((batch_size, num_points),
......@@ -59,8 +66,12 @@ def points_in_boxes_cpu(points, boxes):
point_indices (torch.Tensor): (N, npoints)
"""
# TODO: Refactor this function as a CPU version of points_in_boxes_gpu
assert boxes.shape[1] == 7
assert points.shape[1] == 3
assert boxes.shape[1] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[1] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
point_indices = points.new_zeros((boxes.shape[0], points.shape[0]),
dtype=torch.int)
......@@ -83,8 +94,15 @@ def points_in_boxes_batch(points, boxes):
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0
"""
assert boxes.shape[0] == points.shape[0]
assert boxes.shape[2] == 7
assert boxes.shape[0] == points.shape[0], \
f'Points and boxes should have the same batch size, ' \
f'got {boxes.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
f'boxes dimension should be 7, ' \
f'got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
f'points dimension should be 3, ' \
f'got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
num_boxes = boxes.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment