Commit 14e2adc5 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

update check_sequence_name_with_all_version of WOD to support train/val/test prefix

parent 738ba6f9
...@@ -69,11 +69,18 @@ class WaymoDataset(DatasetTemplate): ...@@ -69,11 +69,18 @@ class WaymoDataset(DatasetTemplate):
@staticmethod @staticmethod
def check_sequence_name_with_all_version(sequence_file): def check_sequence_name_with_all_version(sequence_file):
if '_with_camera_labels' not in str(sequence_file) and not sequence_file.exists(): if not sequence_file.exists():
sequence_file = Path(str(sequence_file)[:-9] + '_with_camera_labels.tfrecord') found_sequence_file = sequence_file
if '_with_camera_labels' in str(sequence_file) and not sequence_file.exists(): for pre_text in ['training', 'validation', 'testing']:
sequence_file = Path(str(sequence_file).replace('_with_camera_labels', '')) if not sequence_file.exists():
temp_sequence_file = Path(str(sequence_file).replace('segment', pre_text + '_segment'))
if temp_sequence_file.exists():
found_sequence_file = temp_sequence_file
break
if not found_sequence_file.exists():
found_sequence_file = Path(str(sequence_file).replace('_with_camera_labels', ''))
if found_sequence_file.exists():
sequence_file = found_sequence_file
return sequence_file return sequence_file
def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1): def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment