Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
69326cef
Commit
69326cef
authored
Jul 13, 2020
by
Shaoshuai Shi
Browse files
speed up PFNLayer
parent
f70902b8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
pcdet/models/backbones_3d/vfe/pillar_vfe.py
pcdet/models/backbones_3d/vfe/pillar_vfe.py
+8
-4
No files found.
pcdet/models/backbones_3d/vfe/pillar_vfe.py
View file @
69326cef
...
@@ -3,6 +3,7 @@ import torch.nn as nn
...
@@ -3,6 +3,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
.vfe_template
import
VFETemplate
from
.vfe_template
import
VFETemplate
class
PFNLayer
(
nn
.
Module
):
class
PFNLayer
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
...
@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
...
@@ -28,12 +29,14 @@ class PFNLayer(nn.Module):
if
inputs
.
shape
[
0
]
>
self
.
part
:
if
inputs
.
shape
[
0
]
>
self
.
part
:
# nn.Linear performs randomly when batch size is too large
# nn.Linear performs randomly when batch size is too large
num_parts
=
inputs
.
shape
[
0
]
//
self
.
part
num_parts
=
inputs
.
shape
[
0
]
//
self
.
part
part_linear_out
=
[
self
.
linear
(
inputs
[
num_part
*
self
.
part
:(
num_part
+
1
)
*
self
.
part
])
for
num_part
in
range
(
num_parts
+
1
)]
part_linear_out
=
[
self
.
linear
(
inputs
[
num_part
*
self
.
part
:(
num_part
+
1
)
*
self
.
part
])
for
num_part
in
range
(
num_parts
+
1
)]
x
=
torch
.
cat
(
part_linear_out
,
dim
=
0
)
x
=
torch
.
cat
(
part_linear_out
,
dim
=
0
)
else
:
else
:
x
=
self
.
linear
(
inputs
)
x
=
self
.
linear
(
inputs
)
total_points
,
voxel_points
,
channels
=
x
.
shape
torch
.
backends
.
cudnn
.
enabled
=
False
x
=
self
.
norm
(
x
.
view
(
-
1
,
channels
)).
view
(
total_points
,
voxel_points
,
channels
)
if
self
.
use_norm
else
x
x
=
self
.
norm
(
x
.
permute
(
0
,
2
,
1
)).
permute
(
0
,
2
,
1
)
if
self
.
use_norm
else
x
torch
.
backends
.
cudnn
.
enabled
=
True
x
=
F
.
relu
(
x
)
x
=
F
.
relu
(
x
)
x_max
=
torch
.
max
(
x
,
dim
=
1
,
keepdim
=
True
)[
0
]
x_max
=
torch
.
max
(
x
,
dim
=
1
,
keepdim
=
True
)[
0
]
...
@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
...
@@ -44,6 +47,7 @@ class PFNLayer(nn.Module):
x_concatenated
=
torch
.
cat
([
x
,
x_repeat
],
dim
=
2
)
x_concatenated
=
torch
.
cat
([
x
,
x_repeat
],
dim
=
2
)
return
x_concatenated
return
x_concatenated
class
PillarVFE
(
VFETemplate
):
class
PillarVFE
(
VFETemplate
):
def
__init__
(
self
,
model_cfg
,
num_point_features
,
voxel_size
,
point_cloud_range
):
def
__init__
(
self
,
model_cfg
,
num_point_features
,
voxel_size
,
point_cloud_range
):
super
().
__init__
(
model_cfg
=
model_cfg
)
super
().
__init__
(
model_cfg
=
model_cfg
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment