Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
006b7ccf
"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "7a1d53172482b325db5bbd5a03228796975a0024"
Commit
006b7ccf
authored
May 18, 2020
by
wuyuefeng
Committed by
zhangwenwei
May 18, 2020
Browse files
Feat pointnet ops
parent
28e511cd
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
612 additions
and
0 deletions
+612
-0
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
+93
-0
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
+74
-0
mmdet3d/ops/interpolate/three_interpolate.py
mmdet3d/ops/interpolate/three_interpolate.py
+64
-0
mmdet3d/ops/interpolate/three_nn.py
mmdet3d/ops/interpolate/three_nn.py
+42
-0
setup.py
setup.py
+27
-0
tests/test_pointnet_ops.py
tests/test_pointnet_ops.py
+312
-0
No files found.
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
0 → 100644
View file @
006b7ccf
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
}
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
+
pt_idx
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
0 → 100644
View file @
006b7ccf
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/three_interpolate.py
0 → 100644
View file @
006b7ccf
from
typing
import
Tuple
import
torch
from
torch.autograd
import
Function
from
.
import
interpolate_ext
class
ThreeInterpolate
(
Function
):
@
staticmethod
def
forward
(
ctx
,
features
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Performs weighted linear interpolation on 3 features
Args:
features (Tensor): (B, C, M) Features descriptors to be
interpolated from
indices (Tensor): (B, n, 3) index three nearest neighbors
of the target features in features
weight (Tensor): (B, n, 3) weights of interpolation
Returns:
Tensor: (B, C, N) tensor of the interpolated features
"""
assert
features
.
is_contiguous
()
assert
indices
.
is_contiguous
()
assert
weight
.
is_contiguous
()
B
,
c
,
m
=
features
.
size
()
n
=
indices
.
size
(
1
)
ctx
.
three_interpolate_for_backward
=
(
indices
,
weight
,
m
)
output
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
n
)
interpolate_ext
.
three_interpolate_wrapper
(
B
,
c
,
m
,
n
,
features
,
indices
,
weight
,
output
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_out
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Backward of three interpolate
Args:
grad_out (Tensor): (B, C, N) tensor with gradients of outputs
Returns:
Tensor: (B, C, M) tensor with gradients of features
"""
idx
,
weight
,
m
=
ctx
.
three_interpolate_for_backward
B
,
c
,
n
=
grad_out
.
size
()
grad_features
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
m
).
zero_
()
grad_out_data
=
grad_out
.
data
.
contiguous
()
interpolate_ext
.
three_interpolate_grad_wrapper
(
B
,
c
,
n
,
m
,
grad_out_data
,
idx
,
weight
,
grad_features
.
data
)
return
grad_features
,
None
,
None
three_interpolate
=
ThreeInterpolate
.
apply
mmdet3d/ops/interpolate/three_nn.py
0 → 100644
View file @
006b7ccf
from
typing
import
Tuple
import
torch
from
torch.autograd
import
Function
from
.
import
interpolate_ext
class
ThreeNN
(
Function
):
@
staticmethod
def
forward
(
ctx
,
target
:
torch
.
Tensor
,
source
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Find the top-3 nearest neighbors of the target set from the source set.
Args:
target (Tensor): shape (B, N, 3), points set that needs to
find the nearest neighbors.
source (Tensor): shape (B, M, 3), points set that is used
to find the nearest neighbors of points in target set.
Returns:
Tensor: shape (B, N, 3), L2 distance of each point in target
set to their corresponding nearest neighbors.
"""
assert
target
.
is_contiguous
()
assert
source
.
is_contiguous
()
B
,
N
,
_
=
target
.
size
()
m
=
source
.
size
(
1
)
dist2
=
torch
.
cuda
.
FloatTensor
(
B
,
N
,
3
)
idx
=
torch
.
cuda
.
IntTensor
(
B
,
N
,
3
)
interpolate_ext
.
three_nn_wrapper
(
B
,
N
,
m
,
target
,
source
,
dist2
,
idx
)
return
torch
.
sqrt
(
dist2
),
idx
@
staticmethod
def
backward
(
ctx
,
a
=
None
,
b
=
None
):
return
None
,
None
three_nn
=
ThreeNN
.
apply
setup.py
View file @
006b7ccf
...
@@ -269,6 +269,33 @@ if __name__ == '__main__':
...
@@ -269,6 +269,33 @@ if __name__ == '__main__':
'src/roiaware_pool3d_kernel.cu'
,
'src/roiaware_pool3d_kernel.cu'
,
'src/points_in_boxes_cuda.cu'
,
'src/points_in_boxes_cuda.cu'
,
]),
]),
make_cuda_ext
(
name
=
'ball_query_ext'
,
module
=
'mmdet3d.ops.ball_query'
,
sources
=
[
'src/ball_query.cpp'
],
sources_cuda
=
[
'src/ball_query_cuda.cu'
]),
make_cuda_ext
(
name
=
'group_points_ext'
,
module
=
'mmdet3d.ops.group_points'
,
sources
=
[
'src/group_points.cpp'
],
sources_cuda
=
[
'src/group_points_cuda.cu'
]),
make_cuda_ext
(
name
=
'interpolate_ext'
,
module
=
'mmdet3d.ops.interpolate'
,
sources
=
[
'src/interpolate.cpp'
],
sources_cuda
=
[
'src/three_interpolate_cuda.cu'
,
'src/three_nn_cuda.cu'
]),
make_cuda_ext
(
name
=
'furthest_point_sample_ext'
,
module
=
'mmdet3d.ops.furthest_point_sample'
,
sources
=
[
'src/furthest_point_sample.cpp'
],
sources_cuda
=
[
'src/furthest_point_sample_cuda.cu'
]),
make_cuda_ext
(
name
=
'gather_points_ext'
,
module
=
'mmdet3d.ops.gather_points'
,
sources
=
[
'src/gather_points.cpp'
],
sources_cuda
=
[
'src/gather_points_cuda.cu'
])
],
],
cmdclass
=
{
'build_ext'
:
BuildExtension
},
cmdclass
=
{
'build_ext'
:
BuildExtension
},
zip_safe
=
False
)
zip_safe
=
False
)
tests/test_pointnet_ops.py
0 → 100644
View file @
006b7ccf
import
pytest
import
torch
from
mmdet3d.ops
import
(
ball_query
,
furthest_point_sample
,
gather_points
,
grouping_operation
,
three_interpolate
,
three_nn
)
def
test_fps
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
()
xyz
=
torch
.
tensor
([[[
-
0.2748
,
1.0020
,
-
1.1674
],
[
0.1015
,
1.3952
,
-
1.2681
],
[
-
0.8070
,
2.4137
,
-
0.5845
],
[
-
1.0001
,
2.1982
,
-
0.5859
],
[
0.3841
,
1.8983
,
-
0.7431
]],
[[
-
1.0696
,
3.0758
,
-
0.1899
],
[
-
0.2559
,
3.5521
,
-
0.1402
],
[
0.8164
,
4.0081
,
-
0.1839
],
[
-
1.1000
,
3.0213
,
-
0.8205
],
[
-
0.0518
,
3.7251
,
-
0.3950
]]]).
cuda
()
idx
=
furthest_point_sample
(
xyz
,
3
)
expected_idx
=
torch
.
tensor
([[
0
,
2
,
4
],
[
0
,
2
,
1
]]).
cuda
()
assert
torch
.
all
(
idx
==
expected_idx
)
def
test_ball_query
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
()
new_xyz
=
torch
.
tensor
([[[
-
0.0740
,
1.3147
,
-
1.3625
],
[
-
2.2769
,
2.7817
,
-
0.2334
],
[
-
0.4003
,
2.4666
,
-
0.5116
],
[
-
0.0740
,
1.3147
,
-
1.3625
],
[
-
0.0740
,
1.3147
,
-
1.3625
]],
[[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
2.0668
,
6.0278
,
-
0.4875
],
[
0.4066
,
1.4211
,
-
0.2947
],
[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
2.0289
,
2.4952
,
-
0.1708
]]]).
cuda
()
xyz
=
torch
.
tensor
([[[
-
0.0740
,
1.3147
,
-
1.3625
],
[
0.5555
,
1.0399
,
-
1.3634
],
[
-
0.4003
,
2.4666
,
-
0.5116
],
[
-
0.5251
,
2.4379
,
-
0.8466
],
[
-
0.9691
,
1.1418
,
-
1.3733
],
[
-
0.2232
,
0.9561
,
-
1.3626
],
[
-
2.2769
,
2.7817
,
-
0.2334
],
[
-
0.2822
,
1.3192
,
-
1.3645
],
[
0.1533
,
1.5024
,
-
1.0432
],
[
0.4917
,
1.1529
,
-
1.3496
]],
[[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
0.7188
,
0.9956
,
-
0.5096
],
[
-
2.0668
,
6.0278
,
-
0.4875
],
[
-
1.9304
,
3.3092
,
0.6610
],
[
0.0949
,
1.4332
,
0.3140
],
[
-
1.2879
,
2.0008
,
-
0.7791
],
[
-
0.7252
,
0.9611
,
-
0.6371
],
[
0.4066
,
1.4211
,
-
0.2947
],
[
0.3220
,
1.4447
,
0.3548
],
[
-
0.9744
,
2.3856
,
-
1.2000
]]]).
cuda
()
idx
=
ball_query
(
0.2
,
5
,
xyz
,
new_xyz
)
expected_idx
=
torch
.
tensor
([[[
0
,
0
,
0
,
0
,
0
],
[
6
,
6
,
6
,
6
,
6
],
[
2
,
2
,
2
,
2
,
2
],
[
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
]],
[[
0
,
0
,
0
,
0
,
0
],
[
2
,
2
,
2
,
2
,
2
],
[
7
,
7
,
7
,
7
,
7
],
[
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
]]]).
cuda
()
assert
torch
.
all
(
idx
==
expected_idx
)
def
test_grouping_points
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
()
idx
=
torch
.
tensor
([[[
0
,
0
,
0
],
[
3
,
3
,
3
],
[
8
,
8
,
8
],
[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]],
[[
0
,
0
,
0
],
[
6
,
6
,
6
],
[
9
,
9
,
9
],
[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]]]).
int
().
cuda
()
festures
=
torch
.
tensor
([[[
0.5798
,
-
0.7981
,
-
0.9280
,
-
1.3311
,
1.3687
,
0.9277
,
-
0.4164
,
-
1.8274
,
0.9268
,
0.8414
],
[
5.4247
,
1.5113
,
2.3944
,
1.4740
,
5.0300
,
5.1030
,
1.9360
,
2.1939
,
2.1581
,
3.4666
],
[
-
1.6266
,
-
1.0281
,
-
1.0393
,
-
1.6931
,
-
1.3982
,
-
0.5732
,
-
1.0830
,
-
1.7561
,
-
1.6786
,
-
1.6967
]],
[[
-
0.0380
,
-
0.1880
,
-
1.5724
,
0.6905
,
-
0.3190
,
0.7798
,
-
0.3693
,
-
0.9457
,
-
0.2942
,
-
1.8527
],
[
1.1773
,
1.5009
,
2.6399
,
5.9242
,
1.0962
,
2.7346
,
6.0865
,
1.5555
,
4.3303
,
2.8229
],
[
-
0.6646
,
-
0.6870
,
-
0.1125
,
-
0.2224
,
-
0.3445
,
-
1.4049
,
0.4990
,
-
0.7037
,
-
0.9924
,
0.0386
]]]).
cuda
()
output
=
grouping_operation
(
festures
,
idx
)
expected_output
=
torch
.
tensor
([[[[
0.5798
,
0.5798
,
0.5798
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
]],
[[
5.4247
,
5.4247
,
5.4247
],
[
1.4740
,
1.4740
,
1.4740
],
[
2.1581
,
2.1581
,
2.1581
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
]],
[[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6931
,
-
1.6931
,
-
1.6931
],
[
-
1.6786
,
-
1.6786
,
-
1.6786
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
]]],
[[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
]],
[[
1.1773
,
1.1773
,
1.1773
],
[
6.0865
,
6.0865
,
6.0865
],
[
2.8229
,
2.8229
,
2.8229
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
]],
[[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
0.4990
,
0.4990
,
0.4990
],
[
0.0386
,
0.0386
,
0.0386
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
]]]]).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
)
def
test_gather_points
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
()
features
=
torch
.
tensor
([[[
-
1.6095
,
-
0.1029
,
-
0.8876
,
-
1.2447
,
-
2.4031
,
0.3708
,
-
1.1586
,
-
1.4967
,
-
0.4800
,
0.2252
],
[
1.9138
,
3.4979
,
1.6854
,
1.5631
,
3.6776
,
3.1154
,
2.1705
,
2.5221
,
2.0411
,
3.1446
],
[
-
1.4173
,
0.3073
,
-
1.4339
,
-
1.4340
,
-
1.2770
,
-
0.2867
,
-
1.4162
,
-
1.4044
,
-
1.4245
,
-
1.4074
]],
[[
0.2160
,
0.0842
,
0.3661
,
-
0.2749
,
-
0.4909
,
-
0.6066
,
-
0.8773
,
-
0.0745
,
-
0.9496
,
0.1434
],
[
1.3644
,
1.8087
,
1.6855
,
1.9563
,
1.2746
,
1.9662
,
0.9566
,
1.8778
,
1.1437
,
1.3639
],
[
-
0.7172
,
0.1692
,
0.2241
,
0.0721
,
-
0.7540
,
0.0462
,
-
0.6227
,
0.3223
,
-
0.6944
,
-
0.5294
]]]).
cuda
()
idx
=
torch
.
tensor
([[
0
,
1
,
4
,
0
,
0
,
0
],
[
0
,
5
,
6
,
0
,
0
,
0
]]).
int
().
cuda
()
output
=
gather_points
(
features
,
idx
)
expected_output
=
torch
.
tensor
(
[[[
-
1.6095
,
-
0.1029
,
-
2.4031
,
-
1.6095
,
-
1.6095
,
-
1.6095
],
[
1.9138
,
3.4979
,
3.6776
,
1.9138
,
1.9138
,
1.9138
],
[
-
1.4173
,
0.3073
,
-
1.2770
,
-
1.4173
,
-
1.4173
,
-
1.4173
]],
[[
0.2160
,
-
0.6066
,
-
0.8773
,
0.2160
,
0.2160
,
0.2160
],
[
1.3644
,
1.9662
,
0.9566
,
1.3644
,
1.3644
,
1.3644
],
[
-
0.7172
,
0.0462
,
-
0.6227
,
-
0.7172
,
-
0.7172
,
-
0.7172
]]]).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
)
def
test_three_interpolate
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
()
features
=
torch
.
tensor
([[[
2.4350
,
4.7516
,
4.4995
,
2.4350
,
2.4350
,
2.4350
],
[
3.1236
,
2.6278
,
3.0447
,
3.1236
,
3.1236
,
3.1236
],
[
2.6732
,
2.8677
,
2.6436
,
2.6732
,
2.6732
,
2.6732
],
[
0.0124
,
7.0150
,
7.0199
,
0.0124
,
0.0124
,
0.0124
],
[
0.3207
,
0.0000
,
0.3411
,
0.3207
,
0.3207
,
0.3207
]],
[[
0.0000
,
0.9544
,
2.4532
,
0.0000
,
0.0000
,
0.0000
],
[
0.5346
,
1.9176
,
1.4715
,
0.5346
,
0.5346
,
0.5346
],
[
0.0000
,
0.2744
,
2.0842
,
0.0000
,
0.0000
,
0.0000
],
[
0.3414
,
1.5063
,
1.6209
,
0.3414
,
0.3414
,
0.3414
],
[
0.5814
,
0.0103
,
0.0000
,
0.5814
,
0.5814
,
0.5814
]]]).
cuda
()
idx
=
torch
.
tensor
([[[
0
,
1
,
2
],
[
2
,
3
,
4
],
[
2
,
3
,
4
],
[
0
,
1
,
2
],
[
0
,
1
,
2
],
[
0
,
1
,
3
]],
[[
0
,
2
,
3
],
[
1
,
3
,
4
],
[
2
,
1
,
4
],
[
0
,
2
,
4
],
[
0
,
2
,
4
],
[
0
,
1
,
2
]]]).
int
().
cuda
()
weight
=
torch
.
tensor
([[[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
1.0000e+00
,
5.8155e-08
,
2.2373e-08
],
[
1.0000e+00
,
1.7737e-08
,
1.7356e-08
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
]],
[[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
1.0000e+00
,
1.3651e-08
,
7.7312e-09
],
[
1.0000e+00
,
1.7148e-08
,
1.4070e-08
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
]]]).
cuda
()
output
=
three_interpolate
(
features
,
idx
,
weight
)
expected_output
=
torch
.
tensor
([[[
3.8953e+00
,
4.4995e+00
,
4.4995e+00
,
3.8953e+00
,
3.8953e+00
,
3.2072e+00
],
[
2.9320e+00
,
3.0447e+00
,
3.0447e+00
,
2.9320e+00
,
2.9320e+00
,
2.9583e+00
],
[
2.7281e+00
,
2.6436e+00
,
2.6436e+00
,
2.7281e+00
,
2.7281e+00
,
2.7380e+00
],
[
4.6824e+00
,
7.0199e+00
,
7.0199e+00
,
4.6824e+00
,
4.6824e+00
,
2.3466e+00
],
[
2.2060e-01
,
3.4110e-01
,
3.4110e-01
,
2.2060e-01
,
2.2060e-01
,
2.1380e-01
]],
[[
8.1773e-01
,
9.5440e-01
,
2.4532e+00
,
8.1773e-01
,
8.1773e-01
,
1.1359e+00
],
[
8.4689e-01
,
1.9176e+00
,
1.4715e+00
,
8.4689e-01
,
8.4689e-01
,
1.3079e+00
],
[
6.9473e-01
,
2.7440e-01
,
2.0842e+00
,
6.9473e-01
,
6.9473e-01
,
7.8619e-01
],
[
7.6789e-01
,
1.5063e+00
,
1.6209e+00
,
7.6789e-01
,
7.6789e-01
,
1.1562e+00
],
[
3.8760e-01
,
1.0300e-02
,
8.3569e-09
,
3.8760e-01
,
3.8760e-01
,
1.9723e-01
]]]).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
,
1e-4
)
def
test_three_nn
():
known
=
torch
.
tensor
([[[
-
1.8373
,
3.5605
,
-
0.7867
],
[
0.7615
,
2.9420
,
0.2314
],
[
-
0.6503
,
3.6637
,
-
1.0622
],
[
-
1.8373
,
3.5605
,
-
0.7867
],
[
-
1.8373
,
3.5605
,
-
0.7867
]],
[[
-
1.3399
,
1.9991
,
-
0.3698
],
[
-
0.0799
,
0.9698
,
-
0.8457
],
[
0.0858
,
2.4721
,
-
0.1928
],
[
-
1.3399
,
1.9991
,
-
0.3698
],
[
-
1.3399
,
1.9991
,
-
0.3698
]]]).
cuda
()
unknown
=
torch
.
tensor
([[[
-
1.8373
,
3.5605
,
-
0.7867
],
[
0.7615
,
2.9420
,
0.2314
],
[
-
0.6503
,
3.6637
,
-
1.0622
],
[
-
1.5237
,
2.3976
,
-
0.8097
],
[
-
0.0722
,
3.4017
,
-
0.2880
],
[
0.5198
,
3.0661
,
-
0.4605
],
[
-
2.0185
,
3.5019
,
-
0.3236
],
[
0.5098
,
3.1020
,
0.5799
],
[
-
1.6137
,
3.8443
,
-
0.5269
],
[
0.7341
,
2.9626
,
-
0.3189
]],
[[
-
1.3399
,
1.9991
,
-
0.3698
],
[
-
0.0799
,
0.9698
,
-
0.8457
],
[
0.0858
,
2.4721
,
-
0.1928
],
[
-
0.9022
,
1.6560
,
-
1.3090
],
[
0.1156
,
1.6901
,
-
0.4366
],
[
-
0.6477
,
2.3576
,
-
0.1563
],
[
-
0.8482
,
1.1466
,
-
1.2704
],
[
-
0.8753
,
2.0845
,
-
0.3460
],
[
-
0.5621
,
1.4233
,
-
1.2858
],
[
-
0.5883
,
1.3114
,
-
1.2899
]]]).
cuda
()
dist
,
idx
=
three_nn
(
unknown
,
known
)
expected_dist
=
torch
.
tensor
([[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
2.0463
,
2.8588
],
[
0.0000
,
1.2229
,
1.2229
],
[
1.2047
,
1.2047
,
1.2047
],
[
1.0011
,
1.0845
,
1.8411
],
[
0.7433
,
1.4451
,
2.4304
],
[
0.5007
,
0.5007
,
0.5007
],
[
0.4587
,
2.0875
,
2.7544
],
[
0.4450
,
0.4450
,
0.4450
],
[
0.5514
,
1.7206
,
2.6811
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
1.6464
,
1.6952
],
[
0.0000
,
1.5125
,
1.5125
],
[
1.0915
,
1.0915
,
1.0915
],
[
0.8197
,
0.8511
,
1.4894
],
[
0.7433
,
0.8082
,
0.8082
],
[
0.8955
,
1.3340
,
1.3340
],
[
0.4730
,
0.4730
,
0.4730
],
[
0.7949
,
1.3325
,
1.3325
],
[
0.7566
,
1.3727
,
1.3727
]]]).
cuda
()
expected_idx
=
torch
.
tensor
([[[
0
,
3
,
4
],
[
1
,
2
,
0
],
[
2
,
0
,
3
],
[
0
,
3
,
4
],
[
2
,
1
,
0
],
[
1
,
2
,
0
],
[
0
,
3
,
4
],
[
1
,
2
,
0
],
[
0
,
3
,
4
],
[
1
,
2
,
0
]],
[[
0
,
3
,
4
],
[
1
,
2
,
0
],
[
2
,
0
,
3
],
[
0
,
3
,
4
],
[
2
,
1
,
0
],
[
2
,
0
,
3
],
[
1
,
0
,
3
],
[
0
,
3
,
4
],
[
1
,
0
,
3
],
[
1
,
0
,
3
]]]).
cuda
()
assert
torch
.
allclose
(
dist
,
expected_dist
,
1e-4
)
assert
torch
.
all
(
idx
==
expected_idx
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment