Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
e970a318
"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "9e42f33c29b6192965232b402b30d0db565f706d"
Unverified
Commit
e970a318
authored
Sep 07, 2021
by
zhanggefan
Committed by
GitHub
Sep 07, 2021
Browse files
fix #768 (#915)
parent
075f4442
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
9 deletions
+50
-9
mmdet3d/ops/voxel/src/scatter_points_cuda.cu
mmdet3d/ops/voxel/src/scatter_points_cuda.cu
+9
-4
tests/test_models/test_voxel_encoder/test_dynamic_scatter.py
tests/test_models/test_voxel_encoder/test_dynamic_scatter.py
+41
-5
No files found.
mmdet3d/ops/voxel/src/scatter_points_cuda.cu
View file @
e970a318
...
@@ -204,10 +204,15 @@ std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu(
...
@@ -204,10 +204,15 @@ std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu(
std
::
tie
(
out_coors
,
coors_map
,
reduce_count
)
=
std
::
tie
(
out_coors
,
coors_map
,
reduce_count
)
=
at
::
unique_dim
(
coors_clean
,
0
,
true
,
true
,
true
);
at
::
unique_dim
(
coors_clean
,
0
,
true
,
true
,
true
);
// the first element of out_coors is always (-1,-1,-1) and should be removed
if
(
out_coors
.
index
({
0
,
0
}).
lt
(
0
).
item
<
bool
>
())
{
out_coors
=
out_coors
.
slice
(
0
,
1
);
// the first element of out_coors (-1,-1,-1) and should be removed
reduce_count
=
reduce_count
.
slice
(
0
,
1
).
to
(
torch
::
kInt32
);
out_coors
=
out_coors
.
slice
(
0
,
1
);
coors_map
=
coors_map
.
to
(
torch
::
kInt32
)
-
1
;
reduce_count
=
reduce_count
.
slice
(
0
,
1
);
coors_map
=
coors_map
-
1
;
}
coors_map
=
coors_map
.
to
(
torch
::
kInt32
);
reduce_count
=
reduce_count
.
to
(
torch
::
kInt32
);
auto
reduced_feats
=
auto
reduced_feats
=
at
::
empty
({
out_coors
.
size
(
0
),
num_feats
},
feats
.
options
());
at
::
empty
({
out_coors
.
size
(
0
),
num_feats
},
feats
.
options
());
...
...
tests/test_models/test_voxel_encoder/test_dynamic_scatter.py
View file @
e970a318
...
@@ -10,11 +10,6 @@ def test_dynamic_scatter():
...
@@ -10,11 +10,6 @@ def test_dynamic_scatter():
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
feats
=
torch
.
rand
(
size
=
(
200000
,
3
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
100
-
50
coors
=
torch
.
randint
(
low
=-
1
,
high
=
20
,
size
=
(
200000
,
3
),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
dsmean
=
DynamicScatter
([
0.32
,
0.32
,
6
],
dsmean
=
DynamicScatter
([
0.32
,
0.32
,
6
],
[
-
74.88
,
-
74.88
,
-
2
,
74.88
,
74.88
,
4
],
True
)
[
-
74.88
,
-
74.88
,
-
2
,
74.88
,
74.88
,
4
],
True
)
dsmax
=
DynamicScatter
([
0.32
,
0.32
,
6
],
dsmax
=
DynamicScatter
([
0.32
,
0.32
,
6
],
...
@@ -54,6 +49,47 @@ def test_dynamic_scatter():
...
@@ -54,6 +49,47 @@ def test_dynamic_scatter():
assert
(
empty_o_feats
.
grad
==
0
).
all
()
assert
(
empty_o_feats
.
grad
==
0
).
all
()
# test non-empty input
# test non-empty input
feats
=
torch
.
rand
(
size
=
(
200000
,
3
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
100
-
50
coors
=
torch
.
randint
(
low
=-
1
,
high
=
20
,
size
=
(
200000
,
3
),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
ref_voxel_coors
=
coors
.
unique
(
dim
=
0
,
sorted
=
True
)
ref_voxel_coors
=
ref_voxel_coors
[
ref_voxel_coors
.
min
(
dim
=-
1
).
values
>=
0
]
ref_voxel_feats_mean
=
[]
ref_voxel_feats_max
=
[]
for
ref_voxel_coor
in
ref_voxel_coors
:
voxel_mask
=
(
coors
==
ref_voxel_coor
).
all
(
dim
=-
1
)
ref_voxel_feats_mean
.
append
(
feats
[
voxel_mask
].
mean
(
dim
=
0
))
ref_voxel_feats_max
.
append
(
feats
[
voxel_mask
].
max
(
dim
=
0
).
values
)
ref_voxel_feats_mean
=
torch
.
stack
(
ref_voxel_feats_mean
)
ref_voxel_feats_max
=
torch
.
stack
(
ref_voxel_feats_max
)
feats_out_mean
,
coors_out_mean
=
dsmean
(
feats
,
coors
)
seq_mean
=
(
coors_out_mean
[:,
0
]
*
400
+
coors_out_mean
[:,
1
]
*
20
+
coors_out_mean
[:,
2
]).
argsort
()
feats_out_mean
=
feats_out_mean
[
seq_mean
]
coors_out_mean
=
coors_out_mean
[
seq_mean
]
feats_out_max
,
coors_out_max
=
dsmax
(
feats
,
coors
)
seq_max
=
(
coors_out_max
[:,
0
]
*
400
+
coors_out_max
[:,
1
]
*
20
+
coors_out_max
[:,
2
]).
argsort
()
feats_out_max
=
feats_out_max
[
seq_max
]
coors_cout_max
=
coors_out_max
[
seq_max
]
assert
(
coors_out_mean
==
ref_voxel_coors
).
all
()
assert
torch
.
allclose
(
feats_out_mean
,
ref_voxel_feats_mean
,
atol
=
1e-2
,
rtol
=
1e-5
)
assert
(
coors_cout_max
==
ref_voxel_coors
).
all
()
assert
torch
.
allclose
(
feats_out_max
,
ref_voxel_feats_max
,
atol
=
1e-2
,
rtol
=
1e-5
)
# test non-empty input without any point out of bound
feats
=
torch
.
rand
(
size
=
(
200000
,
3
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
100
-
50
coors
=
torch
.
randint
(
low
=
0
,
high
=
20
,
size
=
(
200000
,
3
),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
ref_voxel_coors
=
coors
.
unique
(
dim
=
0
,
sorted
=
True
)
ref_voxel_coors
=
coors
.
unique
(
dim
=
0
,
sorted
=
True
)
ref_voxel_coors
=
ref_voxel_coors
[
ref_voxel_coors
.
min
(
dim
=-
1
).
values
>=
0
]
ref_voxel_coors
=
ref_voxel_coors
[
ref_voxel_coors
.
min
(
dim
=-
1
).
values
>=
0
]
ref_voxel_feats_mean
=
[]
ref_voxel_feats_mean
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment