Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
cbf194fa
Unverified
Commit
cbf194fa
authored
Jul 21, 2021
by
xiliu8006
Committed by
GitHub
Jul 21, 2021
Browse files
[Enhance] GroupFree3d inherits BaseModule From MMCV (#704)
parent
26c18075
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
13 deletions
+13
-13
mmdet3d/models/dense_heads/groupfree3d_head.py
mmdet3d/models/dense_heads/groupfree3d_head.py
+13
-13
No files found.
mmdet3d/models/dense_heads/groupfree3d_head.py
View file @
cbf194fa
...
...
@@ -2,10 +2,10 @@ import copy
import
numpy
as
np
import
torch
from
mmcv
import
ConfigDict
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn
import
ConvModule
,
xavier_init
from
mmcv.cnn.bricks.transformer
import
(
build_positional_encoding
,
build_transformer_layer
)
from
mmcv.runner
import
force_fp32
from
mmcv.runner
import
BaseModule
,
force_fp32
from
torch
import
nn
as
nn
from
torch.nn
import
functional
as
F
...
...
@@ -19,7 +19,7 @@ from .base_conv_bbox_head import BaseConvBboxHead
EPS
=
1e-6
class
PointsObjClsModule
(
nn
.
Module
):
class
PointsObjClsModule
(
Base
Module
):
"""object candidate point prediction from seed point features.
Args:
...
...
@@ -39,8 +39,9 @@ class PointsObjClsModule(nn.Module):
num_convs
=
3
,
conv_cfg
=
dict
(
type
=
'Conv1d'
),
norm_cfg
=
dict
(
type
=
'BN1d'
),
act_cfg
=
dict
(
type
=
'ReLU'
)):
super
().
__init__
()
act_cfg
=
dict
(
type
=
'ReLU'
),
init_cfg
=
None
):
super
().
__init__
(
init_cfg
=
init_cfg
)
conv_channels
=
[
in_channel
for
_
in
range
(
num_convs
-
1
)]
conv_channels
.
append
(
1
)
...
...
@@ -104,7 +105,7 @@ class GeneralSamplingModule(nn.Module):
@
HEADS
.
register_module
()
class
GroupFree3DHead
(
nn
.
Module
):
class
GroupFree3DHead
(
Base
Module
):
r
"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
Args:
...
...
@@ -162,8 +163,9 @@ class GroupFree3DHead(nn.Module):
size_class_loss
=
None
,
size_res_loss
=
None
,
size_reg_loss
=
None
,
semantic_loss
=
None
):
super
(
GroupFree3DHead
,
self
).
__init__
()
semantic_loss
=
None
,
init_cfg
=
None
):
super
(
GroupFree3DHead
,
self
).
__init__
(
init_cfg
=
init_cfg
)
self
.
num_classes
=
num_classes
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
...
...
@@ -251,15 +253,13 @@ class GroupFree3DHead(nn.Module):
# initialize transformer
for
m
in
self
.
decoder_layers
.
parameters
():
if
m
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
m
)
xavier_init
(
m
,
distribution
=
'uniform'
)
for
m
in
self
.
decoder_self_posembeds
.
parameters
():
if
m
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
m
)
xavier_init
(
m
,
distribution
=
'uniform'
)
for
m
in
self
.
decoder_cross_posembeds
.
parameters
():
if
m
.
dim
()
>
1
:
nn
.
init
.
xavier_
uniform
_
(
m
)
xavier_init
(
m
,
distribution
=
'
uniform
'
)
def
_get_cls_out_channels
(
self
):
"""Return the channel number of classification outputs."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment