Unverified Commit d6d7447c authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Fix] fix waymo training and eval bug (#1733)

* fix waymo bug

* update ceph replace

* add file_client_args for kitti2waymo
parent e6fb90ca
......@@ -72,12 +72,7 @@ test_pipeline = [
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
file_client_args=file_client_args),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict(type='Pack3DDetInputs', keys=['points']),
]
......@@ -141,6 +136,5 @@ val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
file_client_args=file_client_args)
data_root='./data/waymo/waymo_format')
test_evaluator = val_evaluator
......@@ -70,12 +70,7 @@ test_pipeline = [
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
file_client_args=file_client_args),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict(type='Pack3DDetInputs', keys=['points']),
]
......@@ -139,6 +134,5 @@ val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
file_client_args=file_client_args)
data_root='./data/waymo/waymo_format')
test_evaluator = val_evaluator
......@@ -44,7 +44,8 @@ class KITTI2Waymo(object):
waymo_results_save_dir,
waymo_results_final_path,
prefix,
workers=64):
workers=64,
file_client_args=dict(backend='disk')):
self.kitti_result_files = kitti_result_files
self.waymo_tfrecords_dir = waymo_tfrecords_dir
......@@ -52,10 +53,11 @@ class KITTI2Waymo(object):
self.waymo_results_final_path = waymo_results_final_path
self.prefix = prefix
self.workers = int(workers)
self.file_client_args = file_client_args
self.name2idx = {}
for idx, result in enumerate(kitti_result_files):
if len(result['sample_idx']) > 0:
self.name2idx[str(result['sample_idx'][0])] = idx
if len(result['sample_id']) > 0:
self.name2idx[str(result['sample_id'][0])] = idx
# turn on eager execution for older tensorflow versions
if int(tf.__version__.split('.')[0]) < 2:
......@@ -78,8 +80,23 @@ class KITTI2Waymo(object):
def get_file_names(self):
"""Get file names of waymo raw data."""
self.waymo_tfrecord_pathnames = sorted(
glob(join(self.waymo_tfrecords_dir, '*.tfrecord')))
if 'path_mapping' in self.file_client_args:
for path in self.file_client_args['path_mapping'].keys():
if path in self.waymo_tfrecords_dir:
self.waymo_tfrecords_dir = \
self.waymo_tfrecords_dir.replace(
path, self.file_client_args['path_mapping'][path])
from petrel_client.client import Client
client = Client()
contents = client.list(self.waymo_tfrecords_dir)
self.waymo_tfrecord_pathnames = list()
for content in sorted(list(contents)):
if content.endswith('tfrecord'):
self.waymo_tfrecord_pathnames.append(
join(self.waymo_tfrecords_dir, content))
else:
self.waymo_tfrecord_pathnames = sorted(
glob(join(self.waymo_tfrecords_dir, '*.tfrecord')))
print(len(self.waymo_tfrecord_pathnames), 'tfrecords found.')
def create_folder(self):
......
......@@ -7,6 +7,7 @@ import mmcv
import numpy as np
import torch
from mmcv.utils import print_log
from mmengine import load
from mmengine.logging import MMLogger
from mmdet3d.models.layers import box3d_multiclass_nms
......@@ -102,7 +103,8 @@ class WaymoMetric(KittiMetric):
self.classes = self.dataset_meta['CLASSES']
# load annotations
self.data_infos = self.load_annotations(self.ann_file)['data_list']
self.data_infos = load(
self.ann_file, file_client_args=self.file_client_args)['data_list']
# different from kitti, waymo do not need to convert the ann file
if self.pklfile_prefix is None:
......@@ -223,7 +225,7 @@ class WaymoMetric(KittiMetric):
waymo_save_tmp_dir = tempfile.TemporaryDirectory()
waymo_results_save_dir = waymo_save_tmp_dir.name
waymo_results_final_path = f'{pklfile_prefix}.bin'
from ..core.evaluation.waymo_utils.prediction_kitti_to_waymo import \
from ..functional.waymo_utils.prediction_kitti_to_waymo import \
KITTI2Waymo
converter = KITTI2Waymo(
result_files['pred_instances_3d'],
......
......@@ -8,7 +8,7 @@ def replace_ceph_backend(cfg):
r'''file_client_args = dict(
backend='petrel',
path_mapping=dict({
'.data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/',
'./data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/',
'data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/'
}))
'''
......@@ -19,12 +19,12 @@ def replace_ceph_backend(cfg):
elif 'lyft' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'lyft')
replace_strs = replace_strs.replace('CEPH', 'lyft')
elif 'kitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'kitti')
replace_strs = replace_strs.replace('CEPH', 'kitti')
elif 'waymo' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'waymo')
replace_strs = replace_strs.replace('CEPH', 'waymo')
elif 'kitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'kitti')
replace_strs = replace_strs.replace('CEPH', 'kitti')
elif 'scannet' in cfg_pretty_text:
replace_strs = replace_strs.replace('DATA', 'scannet')
replace_strs = replace_strs.replace('CEPH', 'scannet_processed')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment