Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
4a0e9929
Commit
4a0e9929
authored
Oct 26, 2022
by
ChaimZhu
Committed by
Zaida Zhou
Nov 23, 2022
Browse files
[Fix] Fix the potential NaN bug in calc_square_dist() (#2356)
parent
2d10616b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
9 deletions
+4
-9
mmcv/ops/points_sampler.py
mmcv/ops/points_sampler.py
+4
-9
No files found.
mmcv/ops/points_sampler.py
View file @
4a0e9929
...
@@ -23,16 +23,11 @@ def calc_square_dist(point_feat_a: Tensor,
...
@@ -23,16 +23,11 @@ def calc_square_dist(point_feat_a: Tensor,
torch.Tensor: (B, N, M) Square distance between each point pair.
torch.Tensor: (B, N, M) Square distance between each point pair.
"""
"""
num_channel
=
point_feat_a
.
shape
[
-
1
]
num_channel
=
point_feat_a
.
shape
[
-
1
]
# [bs, n, 1]
dist
=
torch
.
cdist
(
point_feat_a
,
point_feat_b
)
a_square
=
torch
.
sum
(
point_feat_a
.
unsqueeze
(
dim
=
2
).
pow
(
2
),
dim
=-
1
)
# [bs, 1, m]
b_square
=
torch
.
sum
(
point_feat_b
.
unsqueeze
(
dim
=
1
).
pow
(
2
),
dim
=-
1
)
corr_matrix
=
torch
.
matmul
(
point_feat_a
,
point_feat_b
.
transpose
(
1
,
2
))
dist
=
a_square
+
b_square
-
2
*
corr_matrix
if
norm
:
if
norm
:
dist
=
torch
.
sqrt
(
dist
)
/
num_channel
dist
=
dist
/
num_channel
else
:
dist
=
torch
.
square
(
dist
)
return
dist
return
dist
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment