Commit 201f04b4 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed to run PillarNet on WOD

parent 4e0962e9
......@@ -7,10 +7,11 @@ data/
venv/
*.idea/
*.so
*.yaml
*.sh
*.pth
*.pkl
*.zip
*.bin
output
version.py
\ No newline at end of file
version.py
from .mean_vfe import MeanVFE
from .pillar_vfe import PillarVFE
from .dynamic_mean_vfe import DynamicMeanVFE
from .dynamic_pillar_vfe import DynamicPillarVFE, DynamicPillarPFE
from .dynamic_pillar_vfe import DynamicPillarVFE, DynamicPillarVFESimple2D
from .image_vfe import ImageVFE
from .vfe_template import VFETemplate
......@@ -12,5 +12,5 @@ __all__ = {
'ImageVFE': ImageVFE,
'DynMeanVFE': DynamicMeanVFE,
'DynPillarVFE': DynamicPillarVFE,
'DynamicPillarPFE': DynamicPillarPFE
'DynamicPillarVFESimple2D': DynamicPillarVFESimple2D
}
......@@ -137,12 +137,12 @@ class DynamicPillarVFE(VFETemplate):
), dim=1)
voxel_coords = voxel_coords[:, [0, 3, 2, 1]]
batch_dict['pillar_features'] = features
batch_dict['voxel_features'] = batch_dict['pillar_features'] = features
batch_dict['voxel_coords'] = voxel_coords
return batch_dict
class DynamicPillarPFE(VFETemplate):
class DynamicPillarVFESimple2D(VFETemplate):
def __init__(self, model_cfg, num_point_features, voxel_size, grid_size, point_cloud_range, **kwargs):
super().__init__(model_cfg=model_cfg)
......
......@@ -11,6 +11,7 @@ from .centerpoint import CenterPoint
from .pv_rcnn_plusplus import PVRCNNPlusPlus
from .mppnet import MPPNet
from .mppnet_e2e import MPPNetE2E
from .pillarnet import PillarNet
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
......@@ -23,6 +24,7 @@ __all__ = {
'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN,
'CenterPoint': CenterPoint,
'PillarNet': PillarNet,
'PVRCNNPlusPlus': PVRCNNPlusPlus,
'MPPNet': MPPNet,
'MPPNetE2E': MPPNetE2E
......
......@@ -100,7 +100,7 @@ class Detector3DTemplate(nn.Module):
backbone_2d_module = backbones_2d.__all__[self.model_cfg.BACKBONE_2D.NAME](
model_cfg=self.model_cfg.BACKBONE_2D,
input_channels=model_info_dict['num_bev_features']
input_channels=model_info_dict.get('num_bev_features', None)
)
model_info_dict['module_list'].append(backbone_2d_module)
model_info_dict['num_bev_features'] = backbone_2d_module.num_bev_features
......
......@@ -8,7 +8,7 @@ MODEL:
NAME: PillarNet
VFE:
NAME: DynamicPillarPFE
NAME: DynamicPillarVFESimple2D
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_CLUSTER_XYZ: False
......
......@@ -64,7 +64,7 @@ MODEL:
NAME: PillarNet
VFE:
NAME: DynamicPillarPFE
NAME: DynamicPillarVFESimple2D
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_CLUSTER_XYZ: False
......
......@@ -7,7 +7,7 @@ MODEL:
NAME: PillarNet
VFE:
NAME: DynamicPillarPFE
NAME: DynamicPillarVFESimple2D
WITH_DISTANCE: False
USE_ABSLOTE_XYZ: True
USE_CLUSTER_XYZ: False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment