Commit c89a92f9 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

update back to original version of torch.div and torch.meshgrid to also...

update back to original version of torch.div and torch.meshgrid to also support lower version pytorch
parent e496d248
...@@ -139,11 +139,11 @@ def _topk(scores, K=40): ...@@ -139,11 +139,11 @@ def _topk(scores, K=40):
topk_scores, topk_inds = torch.topk(scores.flatten(2, 3), K) topk_scores, topk_inds = torch.topk(scores.flatten(2, 3), K)
topk_inds = topk_inds % (height * width) topk_inds = topk_inds % (height * width)
topk_ys = torch.div(topk_inds, width, rounding_mode='floor').float() topk_ys = (topk_inds // width).float()
topk_xs = (topk_inds % width).int().float() topk_xs = (topk_inds % width).int().float()
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_classes = torch.div(topk_ind, K, rounding_mode='floor').int() topk_classes = (topk_ind // K).int()
topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
......
...@@ -215,7 +215,7 @@ class VectorPoolLocalInterpolateModule(nn.Module): ...@@ -215,7 +215,7 @@ class VectorPoolLocalInterpolateModule(nn.Module):
self.max_neighbour_distance, self.nsample, self.neighbor_type, self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
) )
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs) self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
dist_recip = 1.0 / (dist + 1e-8) dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True) norm = torch.sum(dist_recip, dim=-1, keepdim=True)
...@@ -349,7 +349,7 @@ class VectorPoolAggregationModule(nn.Module): ...@@ -349,7 +349,7 @@ class VectorPoolAggregationModule(nn.Module):
x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device) x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device)
y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device) y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device)
z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device) z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device)
x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids, indexing='ij') # shape: [num_x, num_y, num_z] x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z]
xyz_offset = torch.cat(( xyz_offset = torch.cat((
x_offset.contiguous().view(-1, 1), x_offset.contiguous().view(-1, 1),
y_offset.contiguous().view(-1, 1), y_offset.contiguous().view(-1, 1),
......
...@@ -348,7 +348,7 @@ class ThreeNNForVectorPoolByTwoStep(Function): ...@@ -348,7 +348,7 @@ class ThreeNNForVectorPoolByTwoStep(Function):
stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids
) )
return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, avg_length_of_neighbor_idxs return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, torch.tensor(avg_length_of_neighbor_idxs)
three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment