Unverified Commit d5655a41 authored by Martin Hahner's avatar Martin Hahner Committed by GitHub
Browse files

Add tqdm to WOD preprocessing (#372)

parent 53b2b93d
......@@ -8,6 +8,8 @@ import pickle
import copy
import numpy as np
import torch
import multiprocessing
from tqdm import tqdm
from pathlib import Path
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils
......@@ -74,7 +76,7 @@ class WaymoDataset(DatasetTemplate):
return sequence_file
def get_infos(self, raw_data_path, save_path, num_workers=4, has_label=True, sampled_interval=1):
def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1):
import concurrent.futures as futures
from functools import partial
from . import waymo_utils
......@@ -92,8 +94,8 @@ class WaymoDataset(DatasetTemplate):
# process_single_sequence(sample_sequence_file_list[0])
with futures.ThreadPoolExecutor(num_workers) as executor:
sequence_infos = executor.map(process_single_sequence, sample_sequence_file_list)
sequence_infos = list(sequence_infos)
sequence_infos = list(tqdm(executor.map(process_single_sequence, sample_sequence_file_list),
total=len(sample_sequence_file_list)))
all_sequences_infos = [item for infos in sequence_infos for item in infos]
return all_sequences_infos
......@@ -305,7 +307,8 @@ class WaymoDataset(DatasetTemplate):
def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data', workers=4):
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
workers=multiprocessing.cpu_count()):
dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment