Commit 4a0e9929 authored by ChaimZhu's avatar ChaimZhu Committed by Zaida Zhou
Browse files

[Fix] Fix the potential NaN bug in calc_square_dist() (#2356)

parent 2d10616b
...@@ -23,16 +23,11 @@ def calc_square_dist(point_feat_a: Tensor, ...@@ -23,16 +23,11 @@ def calc_square_dist(point_feat_a: Tensor,
torch.Tensor: (B, N, M) Square distance between each point pair. torch.Tensor: (B, N, M) Square distance between each point pair.
""" """
num_channel = point_feat_a.shape[-1] num_channel = point_feat_a.shape[-1]
# [bs, n, 1] dist = torch.cdist(point_feat_a, point_feat_b)
a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
# [bs, 1, m]
b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
dist = a_square + b_square - 2 * corr_matrix
if norm: if norm:
dist = torch.sqrt(dist) / num_channel dist = dist / num_channel
else:
dist = torch.square(dist)
return dist return dist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment