import random
import esim_py
import torch
import h5py
import numpy as np
import math
import bisect
from pathlib import Path

project_dir = Path(__file__).resolve().parent.parent
import sys
sys.path.append(str(project_dir))  # make the project root importable
from others.event_utils.lib.representations.voxel_grid import events_to_voxel_torch
from glob import glob
# from PIL import Image
import cv2
import os


def package_images(image_root, h5_path):
    """Store every PNG under image_root as an "images/<index>" dataset in the HDF5 file."""
    for ip in glob(os.path.join(image_root, "*.png")):
        image = cv2.imread(ip)
        image = np.array(image)
        # e.g. ".../frame_000123.png" -> dataset key "000123"
        image_name = ip.split(os.sep)[-1].split('.')[0].split("_")[-1]
        with h5py.File(h5_path, 'a') as h5f:
            h5f.create_dataset(f"images/{image_name}", data=image, compression="gzip")


def vid2events(image_root, sensor_size_height, sensor_size_width):
    """Simulate events from the image sequence with ESIM (esim_py), using randomly sampled contrast thresholds."""
    config = {
        'refractory_period': 1e-4,
        'CT_range': [0.05, 0.5],
        'max_CT': 0.5,
        'min_CT': 0.02,
        'mu': 1,
        'sigma': 0.1,
        'H': sensor_size_height,
        'W': sensor_size_width,
        'log_eps': 1e-3,
        'use_log': True,
    }
    # Sample the positive contrast threshold uniformly, derive the negative one
    # from it, then clamp both to [min_CT, max_CT].
    Cp = random.uniform(config['CT_range'][0], config['CT_range'][1])
    Cn = random.gauss(config['mu'], config['sigma']) * Cp
    Cp = min(max(Cp, config['min_CT']), config['max_CT'])
    Cn = min(max(Cn, config['min_CT']), config['max_CT'])
    esim = esim_py.EventSimulator(Cp, Cn, config['refractory_period'], config['log_eps'], config['use_log'])
    # Generate events as an [N, 4] array with columns (x, y, t, p).
    events = esim.generateFromFolder(f"{image_root}/images", f"{image_root}/timestamps.txt")
    return events


def voxel_normalization(voxel):
    """
    Normalize the voxel grid as in https://arxiv.org/abs/1912.01584, Section 3.1.
    Params:
        voxel: torch.Tensor of shape [num_bins, H, W]
    Returns:
        normalized voxel grid
    """
    # Return early if every element of the voxel grid is zero.
    a, b, c = voxel.shape
    tmp = torch.zeros(a, b, c)
    if torch.equal(voxel, tmp):
        return voxel
    abs_voxel, _ = torch.sort(torch.abs(voxel).view(-1, 1).squeeze(1))
    first_non_zero_idx = torch.nonzero(abs_voxel)[0].item()
    non_zero_voxel = abs_voxel[first_non_zero_idx:]
    # Use the 98th percentile of the non-zero magnitudes as the normalization constant.
    norm_idx = math.floor(non_zero_voxel.shape[0] * 0.98)
    ones = torch.ones_like(voxel)
    normed_voxel = torch.where(torch.abs(voxel) < non_zero_voxel[norm_idx], voxel / non_zero_voxel[norm_idx], voxel)
    normed_voxel = torch.where(normed_voxel >= non_zero_voxel[norm_idx], ones, normed_voxel)
    normed_voxel = torch.where(normed_voxel <= -non_zero_voxel[norm_idx], -ones, normed_voxel)
    return normed_voxel


def package_bidirectional_event_voxels(x, y, t, p, timestamp_list, backward, bins, sensor_size, h5_name, error_txt):
    """
    Split the event stream at the frame timestamps, build one voxel grid per
    frame interval, and append the normalized grids to an HDF5 file.
    Params:
        x: ndarray, x-positions of the events
        y: ndarray, y-positions of the events
        t: ndarray, timestamps of the events
        p: ndarray, polarities of the events
        timestamp_list: list, frame timestamps used to split the events
        backward: bool, build backward (time-reversed) voxels instead of forward ones
        bins: int, number of temporal bins per voxel grid
        sensor_size: tuple, (height, width) of the sensor
        h5_name: str, path of the output HDF5 file
        error_txt: str, path of the error log file
    Returns:
        None.
    """
    # Step 1: convert the event arrays to tensors.
    assert x.shape == y.shape == t.shape == p.shape
    x = torch.from_numpy(x.astype(np.int16))
    y = torch.from_numpy(y.astype(np.int16))
    t = torch.from_numpy(t.astype(np.float32))
    p = torch.from_numpy(p.astype(np.int16))
    assert x.shape == y.shape == t.shape == p.shape

    # Step 2: group the events between consecutive frames by timestamp.
    temp = t.numpy().tolist()
    output = [
        temp[bisect.bisect_left(temp, timestamp_list[i]):bisect.bisect_left(temp, timestamp_list[i + 1])]
        for i in range(len(timestamp_list) - 1)
    ]

    # Debug: sanity-check the grouped events before packaging.
    assert len(output) == len(timestamp_list) - 1, \
        f"len(output) is {len(output)}, but len(timestamp_list) is {len(timestamp_list)}"
    sum_output = []
    total = 0
    for i in range(len(output)):
        if len(output[i]) == 0:
            raise ValueError(f"{h5_name} len(output[{i}]) == 0")
        elif len(output[i]) == 1:
            raise ValueError(f"{h5_name} len(output[{i}]) == 1")
        total += len(output[i])
        sum_output.append(total)
    assert len(sum_output) == len(output)

    # Step 3: after checking the data, build one voxel grid per frame interval.
    start_idx = 0
    for voxel_idx in range(len(timestamp_list) - 1):
        if len(output[voxel_idx]) == 0 or len(output[voxel_idx]) == 1:
            print(f'{h5_name} len(output[{voxel_idx}]): ', len(output[voxel_idx]))
            with open(error_txt, 'a+') as f:
                f.write(h5_name + '\n')
            return
        end_idx = start_idx + len(output[voxel_idx])
        if end_idx > len(t):
            with open(error_txt, 'a+') as f:
                f.write(f"{h5_name} voxel_idx: {voxel_idx}, start_idx {start_idx} end_idx {end_idx} exceed bound." + '\n')
            print(f"{h5_name} voxel_idx: {voxel_idx}, start_idx {start_idx} end_idx {end_idx} exceed bound len(t) {len(t)}.")
            return
        xs = x[start_idx:end_idx]
        ys = y[start_idx:end_idx]
        ts = t[start_idx:end_idx]
        ps = p[start_idx:end_idx]
        # print(len(xs), len(ys), len(ts), len(ps))
        if ts.shape == torch.Size([]) or ts.shape == torch.Size([1]) or ts.shape == torch.Size([0]):
            with open(error_txt, 'a+') as f:
                f.write(f"{h5_name} len(output[{voxel_idx}]) backward {backward} start_idx {start_idx} end_idx {end_idx} is error! Please check the data." + '\n')
            print(f"{h5_name} len(output[{voxel_idx}]) backward {backward} start_idx {start_idx} end_idx {end_idx} is error! Please check the data.")
            return
        if backward:
            # Reverse time within the interval and flip the polarities.
            t_start = timestamp_list[voxel_idx]
            t_end = timestamp_list[voxel_idx + 1]
            xs = torch.flip(xs, dims=[0])
            ys = torch.flip(ys, dims=[0])
            ts = torch.flip(t_end - ts + t_start, dims=[0])
            ps = torch.flip(-ps, dims=[0])
        voxel = events_to_voxel_torch(xs, ys, ts, ps, bins, device=None, sensor_size=sensor_size)
        normed_voxel = voxel_normalization(voxel)
        np_voxel = normed_voxel.numpy()
        with h5py.File(h5_name, 'a') as events_file:
            if backward:
                events_file.create_dataset("voxels_b/{:06d}".format(voxel_idx),
                                           data=np_voxel, dtype=np.dtype(np.float32), compression="gzip")
            else:
                events_file.create_dataset("voxels_f/{:06d}".format(voxel_idx),
                                           data=np_voxel, dtype=np.dtype(np.float32), compression="gzip")
        start_idx = end_idx


def events(args):
    # 1. Simulate events from the image sequence.
    events = vid2events(args.image_root, args.sensor_size_height, args.sensor_size_width)
    # 2. Split the events at the frame timestamps and package them as voxel grids.
    timestamp_list = []
    with open(f"{args.image_root}/timestamps.txt", "r") as f:
        for line in f.readlines():
            timestamp_list.append(float(line.strip()))
    package_bidirectional_event_voxels(
        events[:, 0], events[:, 1], events[:, 2], events[:, 3],
        timestamp_list, args.backward, args.bins,
        (args.sensor_size_height, args.sensor_size_width),
        args.h5_path, args.error
    )


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_root", help="root directory of the input images")
    parser.add_argument("--backward", action="store_true")
    parser.add_argument("--sensor_size_height", type=int)
    parser.add_argument("--sensor_size_width", type=int)
    parser.add_argument("--bins", type=int, default=5)
    parser.add_argument("--h5_path", type=str)
    parser.add_argument("--error", type=str, help="path of the error log file")
    args = parser.parse_args()
    if not os.path.exists(args.h5_path):
        print("Processing images")
        package_images(f"{args.image_root}/images", args.h5_path)
    else:
        print("Already processed")
    print("backward?", args.backward)
    events(args)
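
# ---------------------------------------------------------------------------
# Example invocation (a sketch, not part of the original pipeline): the script
# name, paths, and sensor dimensions below are placeholders. Run once without
# --backward for forward voxels and once with --backward for time-reversed
# voxels; both runs append to the same HDF5 file.
#
#   python make_event_voxels.py --image_root /data/seq_0000 \
#       --sensor_size_height 180 --sensor_size_width 240 --bins 5 \
#       --h5_path /data/seq_0000/events.h5 --error /data/error_log.txt
#
# The HDF5 file then holds "images/<index>", "voxels_f/<index>" and, after the
# backward run, "voxels_b/<index>" datasets. A minimal read-back sketch:
#
#   import h5py
#   with h5py.File("/data/seq_0000/events.h5", "r") as f:
#       voxel_0 = f["voxels_f/000000"][:]  # float32 array of shape (bins, H, W)
# ---------------------------------------------------------------------------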