Unverified Commit bdeacecd authored by Yezhen Cong's avatar Yezhen Cong Committed by GitHub
Browse files

[Fix] Fix error when tensors are not in the same device (#317)

* fix bug caused by mmcv upgrade; delete pdb breakpoint

* fix typos

* use torch.cuda api

* update unittest for points_in_bbox_gpu and points_in_boxes_batch

* Added comments for explanation
parent 7b947e04
......@@ -21,6 +21,21 @@ def points_in_boxes_gpu(points, boxes):
box_idxs_of_pts = points.new_zeros((batch_size, num_points),
dtype=torch.int).fill_(-1)
# If manually put the tensor 'points' or 'boxes' on a device
# which is not the current device, some temporary variables
# will be created on the current device in the cuda op,
# and the output will be incorrect.
# Therefore, we force the current device to be the same
# as the device of the tensors if it was not.
# Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
# for the incorrect output before the fix.
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
roiaware_pool3d_ext.points_in_boxes_gpu(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
......@@ -75,6 +90,14 @@ def points_in_boxes_batch(points, boxes):
box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
dtype=torch.int).fill_(0)
# Same reason as line 25-32
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
roiaware_pool3d_ext.points_in_boxes_batch(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
......
......@@ -63,6 +63,14 @@ def test_points_in_boxes_gpu():
assert point_indices.shape == torch.Size([2, 8])
assert (point_indices == expected_point_indices).all()
if torch.cuda.device_count() > 1:
pts = pts.to('cuda:1')
boxes = boxes.to('cuda:1')
expected_point_indices = expected_point_indices.to('cuda:1')
point_indices = points_in_boxes_gpu(points=pts, boxes=boxes)
assert point_indices.shape == torch.Size([2, 8])
assert (point_indices == expected_point_indices).all()
def test_points_in_boxes_cpu():
boxes = torch.tensor(
......@@ -110,3 +118,11 @@ def test_points_in_boxes_batch():
dtype=torch.int32).cuda()
assert point_indices.shape == torch.Size([1, 15, 2])
assert (point_indices == expected_point_indices).all()
if torch.cuda.device_count() > 1:
pts = pts.to('cuda:1')
boxes = boxes.to('cuda:1')
expected_point_indices = expected_point_indices.to('cuda:1')
point_indices = points_in_boxes_batch(points=pts, boxes=boxes)
assert point_indices.shape == torch.Size([1, 15, 2])
assert (point_indices == expected_point_indices).all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment