Unverified Commit d6d7447c authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix] fix waymo training and eval bug (#1733)

* fix waymo bug

* update ceph replace

* add file_client_args for kitti2waymo
parent e6fb90ca
...@@ -72,12 +72,7 @@ test_pipeline = [ ...@@ -72,12 +72,7 @@ test_pipeline = [
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [ eval_pipeline = [
dict( dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
file_client_args=file_client_args),
dict(type='Pack3DDetInputs', keys=['points']), dict(type='Pack3DDetInputs', keys=['points']),
] ]
...@@ -141,6 +136,5 @@ val_evaluator = dict( ...@@ -141,6 +136,5 @@ val_evaluator = dict(
type='WaymoMetric', type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl', ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin', waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format', data_root='./data/waymo/waymo_format')
file_client_args=file_client_args)
test_evaluator = val_evaluator test_evaluator = val_evaluator
...@@ -70,12 +70,7 @@ test_pipeline = [ ...@@ -70,12 +70,7 @@ test_pipeline = [
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [ eval_pipeline = [
dict( dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
file_client_args=file_client_args),
dict(type='Pack3DDetInputs', keys=['points']), dict(type='Pack3DDetInputs', keys=['points']),
] ]
...@@ -139,6 +134,5 @@ val_evaluator = dict( ...@@ -139,6 +134,5 @@ val_evaluator = dict(
type='WaymoMetric', type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl', ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin', waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format', data_root='./data/waymo/waymo_format')
file_client_args=file_client_args)
test_evaluator = val_evaluator test_evaluator = val_evaluator
...@@ -44,7 +44,8 @@ class KITTI2Waymo(object): ...@@ -44,7 +44,8 @@ class KITTI2Waymo(object):
waymo_results_save_dir, waymo_results_save_dir,
waymo_results_final_path, waymo_results_final_path,
prefix, prefix,
workers=64): workers=64,
file_client_args=dict(backend='disk')):
self.kitti_result_files = kitti_result_files self.kitti_result_files = kitti_result_files
self.waymo_tfrecords_dir = waymo_tfrecords_dir self.waymo_tfrecords_dir = waymo_tfrecords_dir
...@@ -52,10 +53,11 @@ class KITTI2Waymo(object): ...@@ -52,10 +53,11 @@ class KITTI2Waymo(object):
self.waymo_results_final_path = waymo_results_final_path self.waymo_results_final_path = waymo_results_final_path
self.prefix = prefix self.prefix = prefix
self.workers = int(workers) self.workers = int(workers)
self.file_client_args = file_client_args
self.name2idx = {} self.name2idx = {}
for idx, result in enumerate(kitti_result_files): for idx, result in enumerate(kitti_result_files):
if len(result['sample_idx']) > 0: if len(result['sample_id']) > 0:
self.name2idx[str(result['sample_idx'][0])] = idx self.name2idx[str(result['sample_id'][0])] = idx
# turn on eager execution for older tensorflow versions # turn on eager execution for older tensorflow versions
if int(tf.__version__.split('.')[0]) < 2: if int(tf.__version__.split('.')[0]) < 2:
...@@ -78,8 +80,23 @@ class KITTI2Waymo(object): ...@@ -78,8 +80,23 @@ class KITTI2Waymo(object):
def get_file_names(self): def get_file_names(self):
"""Get file names of waymo raw data.""" """Get file names of waymo raw data."""
self.waymo_tfrecord_pathnames = sorted( if 'path_mapping' in self.file_client_args:
glob(join(self.waymo_tfrecords_dir, '*.tfrecord'))) for path in self.file_client_args['path_mapping'].keys():
if path in self.waymo_tfrecords_dir:
self.waymo_tfrecords_dir = \
self.waymo_tfrecords_dir.replace(
path, self.file_client_args['path_mapping'][path])
from petrel_client.client import Client
client = Client()
contents = client.list(self.waymo_tfrecords_dir)
self.waymo_tfrecord_pathnames = list()
for content in sorted(list(contents)):
if content.endswith('tfrecord'):
self.waymo_tfrecord_pathnames.append(
join(self.waymo_tfrecords_dir, content))
else:
self.waymo_tfrecord_pathnames = sorted(
glob(join(self.waymo_tfrecords_dir, '*.tfrecord')))
print(len(self.waymo_tfrecord_pathnames), 'tfrecords found.') print(len(self.waymo_tfrecord_pathnames), 'tfrecords found.')
def create_folder(self): def create_folder(self):
......
...@@ -7,6 +7,7 @@ import mmcv ...@@ -7,6 +7,7 @@ import mmcv
import numpy as np import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from mmengine import load
from mmengine.logging import MMLogger from mmengine.logging import MMLogger
from mmdet3d.models.layers import box3d_multiclass_nms from mmdet3d.models.layers import box3d_multiclass_nms
...@@ -102,7 +103,8 @@ class WaymoMetric(KittiMetric): ...@@ -102,7 +103,8 @@ class WaymoMetric(KittiMetric):
self.classes = self.dataset_meta['CLASSES'] self.classes = self.dataset_meta['CLASSES']
# load annotations # load annotations
self.data_infos = self.load_annotations(self.ann_file)['data_list'] self.data_infos = load(
self.ann_file, file_client_args=self.file_client_args)['data_list']
# different from kitti, waymo do not need to convert the ann file # different from kitti, waymo do not need to convert the ann file
if self.pklfile_prefix is None: if self.pklfile_prefix is None:
...@@ -223,7 +225,7 @@ class WaymoMetric(KittiMetric): ...@@ -223,7 +225,7 @@ class WaymoMetric(KittiMetric):
waymo_save_tmp_dir = tempfile.TemporaryDirectory() waymo_save_tmp_dir = tempfile.TemporaryDirectory()
waymo_results_save_dir = waymo_save_tmp_dir.name waymo_results_save_dir = waymo_save_tmp_dir.name
waymo_results_final_path = f'{pklfile_prefix}.bin' waymo_results_final_path = f'{pklfile_prefix}.bin'
from ..core.evaluation.waymo_utils.prediction_kitti_to_waymo import \ from ..functional.waymo_utils.prediction_kitti_to_waymo import \
KITTI2Waymo KITTI2Waymo
converter = KITTI2Waymo( converter = KITTI2Waymo(
result_files['pred_instances_3d'], result_files['pred_instances_3d'],
......
...@@ -8,7 +8,7 @@ def replace_ceph_backend(cfg): ...@@ -8,7 +8,7 @@ def replace_ceph_backend(cfg):
r'''file_client_args = dict( r'''file_client_args = dict(
backend='petrel', backend='petrel',
path_mapping=dict({ path_mapping=dict({
'.data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/', './data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/',
'data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/' 'data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/'
})) }))
''' '''
...@@ -19,12 +19,12 @@ def replace_ceph_backend(cfg): ...@@ -19,12 +19,12 @@ def replace_ceph_backend(cfg):
elif 'lyft' in cfg_pretty_text: elif 'lyft' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'lyft') replace_strs = replace_strs.replace('DATA', 'lyft')
replace_strs = replace_strs.replace('CEPH', 'lyft') replace_strs = replace_strs.replace('CEPH', 'lyft')
elif 'kitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'kitti')
replace_strs = replace_strs.replace('CEPH', 'kitti')
elif 'waymo' in cfg_pretty_text: elif 'waymo' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'waymo') replace_strs = replace_strs.replace('DATA', 'waymo')
replace_strs = replace_strs.replace('CEPH', 'waymo') replace_strs = replace_strs.replace('CEPH', 'waymo')
elif 'kitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'kitti')
replace_strs = replace_strs.replace('CEPH', 'kitti')
elif 'scannet' in cfg_pretty_text: elif 'scannet' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'scannet') replace_strs = replace_strs.replace('DATA', 'scannet')
replace_strs = replace_strs.replace('CEPH', 'scannet_processed') replace_strs = replace_strs.replace('CEPH', 'scannet_processed')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment