#=====================================================================================
#  Author: Aobo Li
#  Contact: liaobo77@gmail.com
#
#  Last Modified: Aug. 29, 2021
#
#  * The PyTorch dataset classes for KamNet
#=====================================================================================
import numpy as np
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tool import label_data, create_table, create_table_zpos, get_roc, create_table_energy, look_table
from settings import FILE_UPPERLIM
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class DetectorDataset(Dataset):
    """Base class for all KamNet datasets.

    The methods below read ``self.size``, ``self.image_shape``,
    ``self.trainX`` and ``self.trainY``. None of them is assigned in this
    class; presumably subclasses populate them -- TODO confirm.
    """

    def __init__(self, json_name):
        """Remember the key under which event images are stored.

        Args:
            json_name: Key used to look up the list of events in the
                dictionaries passed to ``get_sparse_nhit``.
        """
        self.json_name = json_name

    def __len__(self):
        """Return the number of events, as stored in ``self.size``."""
        return self.size

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for event ``idx``.

        The image is a zero-filled float32 array of shape
        ``self.image_shape``. Every entry of ``self.trainX[idx]`` is
        densified with ``.todense()`` and written to consecutive positions
        along the leading axis; positions past the last entry stay zero.
        """
        image = np.zeros(self.image_shape, dtype=np.float32)
        for channel, sparse_map in enumerate(self.trainX[idx]):
            image[channel] = sparse_map.todense()
        return image, self.trainY[idx]

    def return_time_channel(self):
        """Return the time-channel count and one hit-map dimension of the input.

        E.g. if it returns (28, 38), the input has 28 time channels, where
        each channel contains a 38*38 hit map.

        The channel count is taken from the first item produced by
        ``__getitem__`` (so subclass overrides are honoured); the hit-map
        dimension is ``self.image_shape[1]``.
        """
        first_image = self[0][0]
        return (first_image.shape[0], self.image_shape[1])

    def cap_resample(self, input, cap=5000):
        """Randomly keep at most ``cap`` entries of ``input`` along axis 0.

        Args:
            input: Array-like exposing ``.shape`` and supporting indexing
                with an integer index array. (The name shadows the builtin
                but is kept so keyword callers continue to work.)
            cap: Number of entries to keep.

        Returns:
            ``input`` itself when it holds fewer than ``cap`` entries,
            otherwise ``cap`` entries drawn without replacement.
        """
        n_events = input.shape[0]
        if n_events < cap:
            return input
        # An integer first argument samples from np.arange(n_events).
        keep = np.random.choice(n_events, cap, replace=False)
        return input[keep]

    def get_sparse_nhit(self, sparse_dict):
        """Return the Nhit of every event in ``sparse_dict`` as an int array.

        If the dictionary has an ``"Nhit"`` entry it is read out directly
        and flattened. Otherwise Nhit is computed per event by summing, over
        the event's slices, the number of indices reported by
        ``slice.nonzero()``.

        Both branches return an integer array. The computed branch uses the
        builtin ``sum`` and an explicit ``dtype=int`` so that an event with
        no slices, or an empty event list, does not turn the result into
        floats.
        """
        if "Nhit" in sparse_dict:
            return np.array(sparse_dict["Nhit"], dtype=int).flatten()

        events = np.array(sparse_dict[self.json_name], dtype=object)
        nhit_per_event = [
            sum(len(hit_map.nonzero()[0]) for hit_map in event)
            for event in tqdm(events)
        ]
        return np.array(nhit_per_event, dtype=int)

    # NOTE(review): ``match_nhit`` is incomplete in the reviewed copy of this
    # file. The text breaks off inside the loop body and the comparison
    # expressions on its last line are garbled, so the method cannot be
    # reproduced reliably. The fragment is kept below token-for-token; restore
    # the method from the original source before relying on it.
    #
    # def match_nhit(self, signal_dict, background_dict, multiplier=1.0):
    #     ''' Perform Nhit matching between input signal and output background '''
    #     signal_images = np.array(signal_dict[self.json_name], dtype=object)
    #     background_images = np.array(background_dict[self.json_name], dtype=object)
    #     nhit_range = np.arange(0,2000,1)
    #     signal_nhit = np.array(self.get_sparse_nhit(signal_dict))
    #     bkg_nhit = np.array(self.get_sparse_nhit(background_dict))
    #     signal_list = []
    #     bkg_list = []
    #     for (nlow, nhi) in tqdm(zip(list(nhit_range[:-1]), list(nhit_range[1:])),0):
    #         signal_index = np.where((signal_nhit >= nlow) & (signal_nhit = nlow) & (bkg_nhit = nlow) & (signal_nhit = nlow) & (bkg_nhit