Unverified Commit d5655a41 authored by Martin Hahner's avatar Martin Hahner Committed by GitHub
Browse files

Add tqdm to WOD preprocessing (#372)

parent 53b2b93d
...@@ -8,6 +8,8 @@ import pickle ...@@ -8,6 +8,8 @@ import pickle
import copy import copy
import numpy as np import numpy as np
import torch import torch
import multiprocessing
from tqdm import tqdm
from pathlib import Path from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils from ...utils import box_utils, common_utils
...@@ -74,7 +76,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -74,7 +76,7 @@ class WaymoDataset(DatasetTemplate):
return sequence_file return sequence_file
def get_infos(self, raw_data_path, save_path, num_workers=4, has_label=True, sampled_interval=1): def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1):
import concurrent.futures as futures import concurrent.futures as futures
from functools import partial from functools import partial
from . import waymo_utils from . import waymo_utils
...@@ -92,8 +94,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -92,8 +94,8 @@ class WaymoDataset(DatasetTemplate):
# process_single_sequence(sample_sequence_file_list[0]) # process_single_sequence(sample_sequence_file_list[0])
with futures.ThreadPoolExecutor(num_workers) as executor: with futures.ThreadPoolExecutor(num_workers) as executor:
sequence_infos = executor.map(process_single_sequence, sample_sequence_file_list) sequence_infos = list(tqdm(executor.map(process_single_sequence, sample_sequence_file_list),
sequence_infos = list(sequence_infos) total=len(sample_sequence_file_list)))
all_sequences_infos = [item for infos in sequence_infos for item in infos] all_sequences_infos = [item for infos in sequence_infos for item in infos]
return all_sequences_infos return all_sequences_infos
...@@ -305,7 +307,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -305,7 +307,8 @@ class WaymoDataset(DatasetTemplate):
def create_waymo_infos(dataset_cfg, class_names, data_path, save_path, def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data', workers=4): raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
workers=multiprocessing.cpu_count()):
dataset = WaymoDataset( dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger() training=False, logger=common_utils.create_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment