# processing_kamland_new_mc.py (committed by maming)
#=====================================================================================
#    Author: Aobo Li
#    Contact: liaobo77@gmail.com
#    
#    Last Modified: Aug. 29, 2021
#    
#    * This code is used to convert MC simulated .root file into a 2D square grid
#    * Save each event and other variables as a CSR sparse matrix in .pickle format.
#    * Only applicable to the KLGSim simulation by the KamLAND-Zen group. To use this on your
#      own experiment, please modify this code to adapt to your own MC data structures.
#=====================================================================================
import argparse
import math
import os
import json
import pickle
from scipy import sparse
from scipy import constants as const
from sklearn.preprocessing import StandardScaler, Normalizer, MinMaxScaler
from random import *
import numpy as np
import time
from ROOT import TFile
from datetime import datetime
from tqdm import tqdm
import matplotlib.gridspec as gridspec
from clock import clock
from tools import *
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from settings import COLS, FV_CUT_LOW, FV_CUT_HI, good_hit, only_17inch, use_charge, ELOW, EHI, PLOT_HITMAP
# Colormap shared by the hit-time histogram and the hit-map panels.
# pyplot.get_cmap is used instead of matplotlib.cm.get_cmap because the latter
# was deprecated in Matplotlib 3.7 and removed in 3.9; the pyplot function is
# available in both older and newer releases.
colormap_normal = plt.get_cmap("cool")
# Default font size for all figures; also passed explicitly to subplot titles.
FSIZE = 20
plt.rcParams['font.size'] = FSIZE

# Transcribe hits to a 4D tensor
def transcribe_hits(input, outputdir, PMT_POSITION, elow, ehi):
    """Convert one simulated ROOT file into pickled, time-binned PMT hit maps.

    Reads the ``nt`` tree from ``input`` and keeps the events whose
    ``EnergyA2`` lies within ``[ELOW, EHI]`` and whose ``r`` lies within
    ``[FV_CUT_LOW, FV_CUT_HI]`` (all four bounds come from ``settings``).
    For every accepted event the PMT hits are accumulated into an array of
    shape ``(current_clock.clock_size(), ROWS, COLS)``; each time slice is
    then stored as a CSR sparse matrix and one dictionary per event is
    pickled, back to back, into
    ``<outputdir>/eventfile_<input name>_<elow>_<ehi>.pickle``.

    When ``PLOT_HITMAP`` is set the function runs in a demonstration mode:
    at most the first 10000 entries are read, a hit-time histogram plus four
    hit-map panels of the first accepted event are saved to ``event.png``,
    and the function then aborts with ``AssertionError`` without pickling
    any event (the output file has already been created, and stays empty).

    NOTE(review): ``elow`` and ``ehi`` only label the output file; the
    energy cut itself uses ``ELOW``/``EHI`` from ``settings``. Confirm that
    callers keep the two pairs in sync.

    NOTE(review): ``ROWS`` is not imported by name in this file; presumably
    it is provided by ``from tools import *`` -- verify.

    Args:
        input: Path of the ROOT file to read.
        outputdir: Directory in which the pickle file is written.
        PMT_POSITION: PMT lookup structure forwarded to ``xyz_to_row_col``.
        elow: Lower energy bound, used only in the output file name.
        ehi: Upper energy bound, used only in the output file name.

    Returns:
        Always ``0``.

    Raises:
        AssertionError: In ``PLOT_HITMAP`` mode, after the figure for the
            first accepted event has been saved and shown.
    """
    current_clock = clock(0)
    f1 = TFile(input)
    tree = f1.Get("nt")  # Read the ROOT tree
    start_evt = 0
    end_evt = tree.GetEntries()
    # Hit times relative to T0 for every hit of every accepted event; only
    # consumed by the demonstration histogram.
    tz = []
    if PLOT_HITMAP:
        # Demonstration mode inspects at most 10000 entries, but must never
        # request entries beyond the end of the tree.
        end_evt = min(end_evt, 10000)
    input_name = os.path.basename(input).split('.')[0]

    event_map = []
    for evt_index in tqdm(range(start_evt, end_evt)):
        tree.GetEntry(evt_index)
        # FV/ROI cut
        try:
            energy = tree.EnergyA2  # energy used for the ROI cut
            position = tree.r       # radius used for the fiducial-volume cut
            # Not used below, but an entry whose NhitID cannot be read is
            # skipped, exactly as before.
            tree.NhitID
            if (energy < ELOW) or (energy > EHI) or (position > FV_CUT_HI) or (position < FV_CUT_LOW):
                continue
        except Exception:
            # Deliberately not a bare ``except``: Ctrl-C / SystemExit must
            # still be able to stop a long conversion.
            print("error")
            continue

        # Read out PMT hit list, time and charge.
        if good_hit:
            good_pmt_list = np.array(tree.pmtlist_good)
            good_pmt_time_list = np.array(tree.pmtt_good)
            good_pmt_charge_list = np.array(tree.pmtq_good)
        else:
            # Getting PMT information for an event
            good_pmt_list = np.array(tree.pmtlist)
            good_pmt_time_list = np.array(tree.pmtt)
            good_pmt_charge_list = np.array(tree.pmtq)

        event = np.zeros((current_clock.clock_size(), ROWS, COLS))

        # Read out only 17 inch PMTs (that is, PMT index < 1325).
        if only_17inch:
            good_index = good_pmt_list < 1325
            good_pmt_list = good_pmt_list[good_index]
            good_pmt_time_list = good_pmt_time_list[good_index]
            good_pmt_charge_list = good_pmt_charge_list[good_index]

        # Time of flight. In KLZ simulation the TOF is already subtracted, so
        # it is simply zero for every hit.
        good_pmt_tof = np.zeros(len(good_pmt_list), dtype=int)

        tzero = tree.T0

        # One row per hit: (PMT id, hit time, charge, time of flight).
        stacked_pmt_info = np.dstack((good_pmt_list, good_pmt_time_list, good_pmt_charge_list, good_pmt_tof))[0]

        for pmt_id, hit_time, charge, _tof in stacked_pmt_info:
            if charge == 0.0:
                # Skip PMT with 0 charge
                continue
            # NOTE(review): the helper is named ``xyz_to_row_col`` but its
            # result is unpacked as (col, row); kept as in the original --
            # confirm the intended order.
            col, row = xyz_to_row_col(pmt_id, PMT_POSITION)
            t_center = hit_time - tzero
            tz.append(t_center)
            # Index of the time slice this hit falls into. Named ``tick`` so
            # that the imported ``time`` module is not shadowed.
            tick = current_clock.tick(t_center)
            if use_charge:
                event[tick][row][col] += charge
            else:
                event[tick][row][col] += 1.0

        event_dic = {}
        event_dic['id'] = tree.EventNumber
        event_dic['run'] = tree.run
        event_dic['Nhit'] = np.count_nonzero(event)
        event_dic['energy'] = energy
        event_dic['vertex'] = tree.r
        event_dic['zpos'] = tree.z
        event_dic['event'] = event

        event_map.append(event_dic)

    # Everything needed from the tree has been copied into Python/NumPy
    # objects, so the ROOT file can be released before plotting and writing.
    f1.Close()

    if PLOT_HITMAP:
        # Demonstration plot: the lower panel shows the hit-time distribution
        # with the selected time slices highlighted; the upper row is filled
        # with the matching hit maps while the first event is written below.
        plt.figure(figsize=(15, 15))
        spec = gridspec.GridSpec(ncols=4, nrows=2, height_ratios=[1, 2])
        plt.subplot(spec[1, :])
        idx_pool = [5, 11, 14, 18]
        plt.hist(tz, bins=np.arange(-20, 40, 1.5), density=True, color=colormap_normal(0.2))
        plt.axvline(x=-20, color="red", label="KamNet Window")
        plt.axvline(x=22, color="red")
        for idxc in idx_pool:
            begin, end = current_clock.get_range_from_tick(idxc)
            plt.axvspan(xmin=begin, xmax=end, color=colormap_normal(0.7), alpha=0.5)
        plt.ylim(0, 0.08)
        plt.legend(frameon=False)
        plt.xlabel("Proper Hit Time [ns]", fontsize=25, labelpad=20)
        plt.ylabel("Normalized Amplitude", fontsize=25, labelpad=20)
        # plt.savefig("th.png",dpi=600)

    output_path = os.path.join(outputdir, "eventfile_%s_%.2f_%.2f.pickle" % (input_name, elow, ehi))
    with open(output_path, 'wb') as handle:
        numev = 0
        print(len(event_map))
        for eventd in event_map:
            evnt = eventd['event']
            eventd['nhit'] = np.count_nonzero(evnt)
            numev += 1
            time_sequence = []
            subplot_index = 0
            for idx, maps in enumerate(evnt):
                if PLOT_HITMAP and (idx in idx_pool):
                    ax = plt.subplot(spec[0, subplot_index])
                    begin, end = current_clock.get_range_from_tick(idx)
                    if begin == -9999:
                        # -9999 is treated as "no lower edge" for the slice.
                        plt.title("(Past, %.1f ns)" % (end), fontsize=FSIZE)
                    else:
                        plt.title("(%s ns, %.1f ns)" % (begin, end), fontsize=FSIZE)
                    subplot_index += 1
                    ax.axes.get_xaxis().set_visible(False)
                    ax.axes.get_yaxis().set_visible(False)
                    ax.imshow(maps, cmap=colormap_normal, norm=matplotlib.colors.LogNorm(vmin=0.3, vmax=10.0))
                    # plt.colorbar()
                    if subplot_index > 49:
                        break
                time_sequence.append(sparse.csr_matrix(maps))  # Save each time slice as a CSR sparse matrix
            if PLOT_HITMAP:
                plt.tight_layout()
                plt.savefig("event.png", dpi=600)
                plt.show()
                # Demonstration mode stops after the first event. Raised
                # explicitly because an ``assert`` statement would be
                # stripped under ``python -O``.
                raise AssertionError("PLOT_HITMAP demonstration mode stops after plotting the first event")
            eventd['event'] = time_sequence
            pickle.dump(eventd, handle, protocol=pickle.HIGHEST_PROTOCOL)  # dump event into .pickle file
        print("Number of Events: ", numev)
    return 0




def main():
    """Parse the command-line options and convert a single ROOT file.

    The PMT lookup is built with ``PMT_setup`` from ``--pmt_file_index`` and
    handed, together with the input path, output directory and energy
    labels, to ``transcribe_hits``. ``--pmt_file_size`` and
    ``--process_index`` are accepted on the command line but are not read
    here.
    """
    cli = argparse.ArgumentParser()

    # String-valued options and their default locations.
    path_options = (
        ("--input", "/projectnb2/snoplus/KLZ_NEW2/machine_learning/CDM/CDM_deltaM16.5_XeLS_8.root"),
        ("--outputdir", "/projectnb/snoplus/sphere_data/c10_2MeV"),
        ("--pmt_file_index", "/project/snoplus/ml2/data/pmt_xyz.dat"),
        ("--pmt_file_size", "/projectnb/snoplus/machine_learning/prototype/pmt.txt"),
    )
    for flag, default_path in path_options:
        cli.add_argument(flag, default=default_path)

    cli.add_argument("--process_index", type=int, default=-1)

    # Energy window labels, in the same units as the energy cut.
    for flag, default_energy in (("--elow", 2.0), ("--ehi", 3.0)):
        cli.add_argument(flag, type=float, default=default_energy)

    options = cli.parse_args()

    pmt_positions = PMT_setup(options.pmt_file_index)

    transcribe_hits(
        input=options.input,
        outputdir=options.outputdir,
        PMT_POSITION=pmt_positions,
        elow=options.elow,
        ehi=options.ehi,
    )





# Run the conversion only when this file is executed as a script, so that
# importing the module (for example to reuse transcribe_hits) does not parse
# command-line arguments or start processing.
if __name__ == "__main__":
    main()