Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
e6f6151c
Commit
e6f6151c
authored
Jul 23, 2020
by
Shaoshuai Shi
Browse files
add RETURN_ENCODED_TENSOR config for Spconv UNet
parent
b6fbc50d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
15 deletions
+19
-15
pcdet/models/backbones_3d/spconv_unet.py
pcdet/models/backbones_3d/spconv_unet.py
+18
-15
tools/cfgs/kitti_models/PartA2_free.yaml
tools/cfgs/kitti_models/PartA2_free.yaml
+1
-0
No files found.
pcdet/models/backbones_3d/spconv_unet.py
View file @
e6f6151c
...
@@ -91,8 +91,8 @@ class UNetV2(nn.Module):
...
@@ -91,8 +91,8 @@ class UNetV2(nn.Module):
block
(
64
,
64
,
3
,
norm_fn
=
norm_fn
,
padding
=
1
,
indice_key
=
'subm4'
),
block
(
64
,
64
,
3
,
norm_fn
=
norm_fn
,
padding
=
1
,
indice_key
=
'subm4'
),
)
)
last_pad
=
0
if
self
.
model_cfg
.
get
(
'RETURN_ENCODED_TENSOR'
,
True
):
last_pad
=
self
.
model_cfg
.
get
(
'last_pad'
,
last_pad
)
last_pad
=
self
.
model_cfg
.
get
(
'last_pad'
,
0
)
self
.
conv_out
=
spconv
.
SparseSequential
(
self
.
conv_out
=
spconv
.
SparseSequential
(
# [200, 150, 5] -> [200, 150, 2]
# [200, 150, 5] -> [200, 150, 2]
...
@@ -101,6 +101,8 @@ class UNetV2(nn.Module):
...
@@ -101,6 +101,8 @@ class UNetV2(nn.Module):
norm_fn
(
128
),
norm_fn
(
128
),
nn
.
ReLU
(),
nn
.
ReLU
(),
)
)
else
:
self
.
conv_out
=
None
# decoder
# decoder
# [400, 352, 11] <- [200, 176, 5]
# [400, 352, 11] <- [200, 176, 5]
...
@@ -181,9 +183,12 @@ class UNetV2(nn.Module):
...
@@ -181,9 +183,12 @@ class UNetV2(nn.Module):
x_conv3
=
self
.
conv3
(
x_conv2
)
x_conv3
=
self
.
conv3
(
x_conv2
)
x_conv4
=
self
.
conv4
(
x_conv3
)
x_conv4
=
self
.
conv4
(
x_conv3
)
if
self
.
conv_out
is
not
None
:
# for detection head
# for detection head
# [200, 176, 5] -> [200, 176, 2]
# [200, 176, 5] -> [200, 176, 2]
out
=
self
.
conv_out
(
x_conv4
)
out
=
self
.
conv_out
(
x_conv4
)
batch_dict
[
'encoded_spconv_tensor'
]
=
out
batch_dict
[
'encoded_spconv_tensor_stride'
]
=
8
# for segmentation head
# for segmentation head
# [400, 352, 11] <- [200, 176, 5]
# [400, 352, 11] <- [200, 176, 5]
...
@@ -201,6 +206,4 @@ class UNetV2(nn.Module):
...
@@ -201,6 +206,4 @@ class UNetV2(nn.Module):
point_cloud_range
=
self
.
point_cloud_range
point_cloud_range
=
self
.
point_cloud_range
)
)
batch_dict
[
'point_coords'
]
=
torch
.
cat
((
x_up1
.
indices
[:,
0
:
1
].
float
(),
point_coords
),
dim
=
1
)
batch_dict
[
'point_coords'
]
=
torch
.
cat
((
x_up1
.
indices
[:,
0
:
1
].
float
(),
point_coords
),
dim
=
1
)
batch_dict
[
'encoded_spconv_tensor'
]
=
out
batch_dict
[
'encoded_spconv_tensor_stride'
]
=
8
return
batch_dict
return
batch_dict
tools/cfgs/kitti_models/PartA2_free.yaml
View file @
e6f6151c
...
@@ -12,6 +12,7 @@ MODEL:
...
@@ -12,6 +12,7 @@ MODEL:
BACKBONE_3D
:
BACKBONE_3D
:
NAME
:
UNetV2
NAME
:
UNetV2
RETURN_ENCODED_TENSOR
:
False
POINT_HEAD
:
POINT_HEAD
:
NAME
:
PointHeadBox
NAME
:
PointHeadBox
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment