Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
2d73eafe
Unverified
Commit
2d73eafe
authored
Oct 23, 2021
by
pc
Committed by
GitHub
Oct 23, 2021
Browse files
add mmdet3d op (#1425)
Co-authored-by:
zhouzaida
<
zhouzaida@163.com
>
parent
75cae78c
Changes
43
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
9 deletions
+11
-9
mmcv/ops/knn.py
mmcv/ops/knn.py
+4
-2
mmcv/ops/three_interpolate.py
mmcv/ops/three_interpolate.py
+4
-4
mmcv/ops/three_nn.py
mmcv/ops/three_nn.py
+3
-3
No files found.
mmcv/ops/knn.py
View file @
2d73eafe
...
@@ -61,9 +61,11 @@ class KNN(Function):
...
@@ -61,9 +61,11 @@ class KNN(Function):
idx
=
center_xyz
.
new_zeros
((
B
,
npoint
,
k
)).
int
()
idx
=
center_xyz
.
new_zeros
((
B
,
npoint
,
k
)).
int
()
dist2
=
center_xyz
.
new_zeros
((
B
,
npoint
,
k
)).
float
()
dist2
=
center_xyz
.
new_zeros
((
B
,
npoint
,
k
)).
float
()
ext_module
.
knn_forward
(
B
,
N
,
npoint
,
k
,
xyz
,
center_xyz
,
idx
,
dist2
)
ext_module
.
knn_forward
(
xyz
,
center_xyz
,
idx
,
dist2
,
b
=
B
,
n
=
N
,
m
=
npoint
,
nsample
=
k
)
# idx shape to [B, k, npoint]
# idx shape to [B, k, npoint]
idx
=
idx
.
transpose
(
2
,
1
).
contiguous
()
idx
=
idx
.
transpose
(
2
,
1
).
contiguous
()
if
torch
.
__version__
!=
'parrots'
:
ctx
.
mark_non_differentiable
(
idx
)
ctx
.
mark_non_differentiable
(
idx
)
return
idx
return
idx
...
...
mmcv/ops/three_interpolate.py
View file @
2d73eafe
...
@@ -39,8 +39,8 @@ class ThreeInterpolate(Function):
...
@@ -39,8 +39,8 @@ class ThreeInterpolate(Function):
ctx
.
three_interpolate_for_backward
=
(
indices
,
weight
,
m
)
ctx
.
three_interpolate_for_backward
=
(
indices
,
weight
,
m
)
output
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
n
)
output
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
n
)
ext_module
.
three_interpolate_forward
(
B
,
c
,
m
,
n
,
features
,
indices
,
ext_module
.
three_interpolate_forward
(
weight
,
output
)
features
,
indices
,
weight
,
output
,
b
=
B
,
c
=
c
,
m
=
m
,
n
=
n
)
return
output
return
output
@
staticmethod
@
staticmethod
...
@@ -60,8 +60,8 @@ class ThreeInterpolate(Function):
...
@@ -60,8 +60,8 @@ class ThreeInterpolate(Function):
grad_features
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
m
).
zero_
()
grad_features
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
m
).
zero_
()
grad_out_data
=
grad_out
.
data
.
contiguous
()
grad_out_data
=
grad_out
.
data
.
contiguous
()
ext_module
.
three_interpolate_backward
(
B
,
c
,
n
,
m
,
grad_out_data
,
idx
,
ext_module
.
three_interpolate_backward
(
weight
,
grad_features
.
data
)
grad_out_data
,
idx
,
weight
,
grad_features
.
data
,
b
=
B
,
c
=
c
,
n
=
n
,
m
=
m
)
return
grad_features
,
None
,
None
return
grad_features
,
None
,
None
...
...
mmcv/ops/three_nn.py
View file @
2d73eafe
...
@@ -37,8 +37,8 @@ class ThreeNN(Function):
...
@@ -37,8 +37,8 @@ class ThreeNN(Function):
dist2
=
torch
.
cuda
.
FloatTensor
(
B
,
N
,
3
)
dist2
=
torch
.
cuda
.
FloatTensor
(
B
,
N
,
3
)
idx
=
torch
.
cuda
.
IntTensor
(
B
,
N
,
3
)
idx
=
torch
.
cuda
.
IntTensor
(
B
,
N
,
3
)
ext_module
.
three_nn_forward
(
B
,
N
,
m
,
target
,
source
,
dist2
,
idx
)
ext_module
.
three_nn_forward
(
target
,
source
,
dist2
,
idx
,
b
=
B
,
n
=
N
,
m
=
m
)
if
torch
.
__version__
!=
'parrots'
:
ctx
.
mark_non_differentiable
(
idx
)
ctx
.
mark_non_differentiable
(
idx
)
return
torch
.
sqrt
(
dist2
),
idx
return
torch
.
sqrt
(
dist2
),
idx
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment