Commit 46d558ed authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: cumsum floordiv in vectorpool

parent 55d12ff5
......@@ -215,7 +215,7 @@ class VectorPoolLocalInterpolateModule(nn.Module):
self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
)
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True)
......
......@@ -337,7 +337,7 @@ class ThreeNNForVectorPoolByTwoStep(Function):
avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier,
nsample, neighbor_type
)
avg_length_of_neighbor_idxs = cumsum[0] // num_new_xyz + int(cumsum[0] % num_new_xyz > 0)
avg_length_of_neighbor_idxs = cumsum[0].item() // num_new_xyz + int(cumsum[0].item() % num_new_xyz > 0)
if cumsum[0] <= num_max_sum_points:
break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment