Commit b5881ee2 authored by maming

Initial commit
import numpy as np
from numpy.fft import fft2, ifft2, fftshift, ifftshift
from .FFTBase import FFTBase
class T2FFT(FFTBase):
"""
The Fast Fourier Transform on the 2-Torus.
The torus is parameterized by two cyclic variables (x, y).
The standard domain is (x, y) in [0, 1) x [0, 1), in which case the Fourier basis functions are:
exp( i 2 pi xi^T (x; y))
where xi is the spectral variable, xi in Z^2.
The Fourier transform is
\hat{f}[p, q] = int_0^1 int_0^1 f(x, y) exp(-i 2 pi (p; q)^T (x; y)) dx dy
but this class allows one to use arbitrarily scaled and shifted domains D = [l_x, u_x) x [l_y, u_y).
Let the widths of the domain be given by
alpha_x = u_x - l_x
alpha_y = u_y - l_y
The basis functions on [l_x, u_x) x [l_y, u_y) are
exp( i 2 pi xi^T ((x - l_x) / alpha_x; (y - l_y) / alpha_y))
where xi is the spectral variable, xi in Z^2.
The normalized Haar measure is dx dy / (alpha_x * alpha_y) (in terms of the Lebesgue measure dx dy).
So the Fourier transform on this particular parameterization of the torus is:
\hat{f}_pq = 1/(alpha_x alpha_y) int_lx^ux int_ly^uy f(x, y) e^{-2 pi i (p; q)^T ((x - lx) / alpha_x; (y - ly) / alpha_y)} dx dy
This is what the current class computes, given discrete samples in the domain D.
The samples are assumed to come from the following sampling grid:
(x_i, y_j), i = 0, ..., N_x - 1; j = 0, ..., N_y - 1
x_i = lx + alpha_x * (i / N_x)
y_j = ly + alpha_y * (j / N_y)
This is the output of
x = np.linspace(lx, ux, N_x, endpoint=False)
y = np.linspace(ly, uy, N_y, endpoint=False)
X, Y = np.meshgrid(x, y)
"""
def __init__(self, lower_bound=(0., 0.), upper_bound=(1., 1.)):
self.lower_bound = np.array(lower_bound)
self.upper_bound = np.array(upper_bound)
@staticmethod
def analyze(f, axes=(0, 1)):
"""
Compute the Fourier Transform of the discretely sampled function f : T^2 -> C.
Let f : T^2 -> C be a band-limited function on the torus.
The samples f(theta_k, phi_l) correspond to points on a regular grid on the torus
(each axis sampled as returned by spaces.T1.linspace):
theta_k = 2 pi k / N_x,  phi_l = 2 pi l / N_y
for k = 0, ..., N_x - 1 and l = 0, ..., N_y - 1.
This function computes
\hat{f}[p, q] = 1/(N_x N_y) \sum_{k=0}^{N_x-1} \sum_{l=0}^{N_y-1} f(theta_k, phi_l) e^{-i (p theta_k + q phi_l)}
which, if f has band-limit less than (N_x, N_y), is equal to:
\hat{f}[p, q] = \int_0^{2pi} \int_0^{2pi} f(theta, phi) e^{-i (p theta + q phi)} dtheta dphi / (2pi)^2
             = <f, e^{i (p theta + q phi)}>
where dtheta dphi / (2pi)^2 is the normalized Haar measure on T^2, and < , > denotes the inner product
on Hilbert space, with respect to which this transform is unitary.
Along an axis with N samples, the range of frequencies is -floor(N/2) <= p <= ceil(N/2) - 1.
:param f: array of samples of f, with the two torus axes given by `axes`
:param axes: the two axes of f along which to transform
:return: the (fftshifted, normalized) Fourier coefficients of f
"""
# The numpy FFT returns coefficients in a different order than we want them,
# and with a different normalization.
f_hat = fft2(f, axes=axes)
f_hat = fftshift(f_hat, axes=axes)
size = np.prod([f.shape[ax] for ax in axes])
return f_hat / size
@staticmethod
def synthesize(f_hat, axes=(0, 1)):
"""
Compute the inverse of `analyze`: evaluate the Fourier series with coefficients f_hat on the
regular sampling grid described in the class docstring.
:param f_hat: the (fftshifted, normalized) Fourier coefficients, as returned by analyze
:param axes: the two axes of f_hat along which to transform
:return: the spatial samples f
"""
size = np.prod([f_hat.shape[ax] for ax in axes])
f_hat = ifftshift(f_hat * size, axes=axes)
f = ifft2(f_hat, axes=axes)
return f
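# A minimal usage sketch (added for illustration, not part of the original module), assuming the
# default domain [0, 1) x [0, 1): a pure Fourier mode yields a single unit coefficient, and
# analyze/synthesize round-trip the samples (up to floating point error).
def _example_roundtrip(nx=8, ny=8):
    x = np.linspace(0., 1., nx, endpoint=False)
    y = np.linspace(0., 1., ny, endpoint=False)
    X, Y = np.meshgrid(x, y, indexing='ij')
    f = np.exp(2j * np.pi * (2 * X + 3 * Y))        # the mode with xi = (2, 3)
    f_hat = T2FFT.analyze(f)
    assert np.isclose(np.abs(f_hat).sum(), 1.)      # a single coefficient of magnitude 1
    assert np.allclose(T2FFT.synthesize(f_hat), f)  # round trip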
import numpy as np
from pynfft import nfft
from pynfft.solver import Solver
from .T2FFT import T2FFT
class FourierInterpolator(object):
def __init__(self, cartesian_grid_shape, nonequispaced_grid):
"""
The FourierInterpolator can interpolate data on an equispaced Cartesian grid to a non-equispaced grid.
The interpolation works by first computing the Fourier coefficients of the input grid, and then evaluating
the Fourier series defined by those coefficients at the non-equispaced grid.
This operation is exactly invertible, as long as the Fourier coefficients are recoverable
from the non-equispaced output samples.
:param cartesian_grid_shape: the shape (nx, ny) of the input grid.
Samples are assumed to be in [-.5, .5) x [-.5, .5)
:param nonequispaced_grid: the output grid points. Shape (M, 2)
"""
self.cartesian_grid_shape = cartesian_grid_shape
self.nonequispaced_grid_shape = nonequispaced_grid.shape[:-1]
self.nonequispaced_grid = nonequispaced_grid.reshape(-1, 2)
self.nfft = nfft.NFFT(N=cartesian_grid_shape, M=np.prod(nonequispaced_grid.shape[:-1]),
n=None, m=12, flags=None)
self.nfft.x = self.nonequispaced_grid
self.nfft.precompute()
self.solver = Solver(self.nfft)
@staticmethod
def init_cartesian_to_polar(nr, nt, nx, ny):
# On the computation of the polar FFT
# Markus Fenn, Stefan Kunis, Daniel Potts
r = np.linspace(0, 1. / np.sqrt(2), nr) # radius = sqrt((0 - 0.5)^2 + (0 - 0.5)^2) = sqrt(0.5) = 1/sqrt(2)
t = np.linspace(0, 2 * np.pi, nt, endpoint=False)
R, T = np.meshgrid(r, t, indexing='ij')
X = R * np.cos(T)
Y = R * np.sin(T)
C = np.c_[X[..., None], Y[..., None]]
return FourierInterpolator(cartesian_grid_shape=(nx, ny), nonequispaced_grid=C)
def forward(self, f):
"""
:param f: samples of the function on the equispaced Cartesian grid (shape cartesian_grid_shape)
:return: samples of the function at the non-equispaced grid points (shape nonequispaced_grid_shape)
"""
# Fourier transform x:
# Perform a regular FFT:
f_hat = T2FFT.analyze(f)
print(f_hat)
# Since this equispaced FFT assumes spatial samples theta_k in [0, 1)
# [i.e. basis functions exp(i 2 pi n theta), not exp(i n theta)],
# we shift by 0.5, i.e. multiply by exp(-i pi n) = (-1)^n
# (see the numerical check sketched after this class).
f_hat *= ((-1) ** np.arange(-np.floor(f.shape[0] / 2), np.ceil(f.shape[0] / 2)))[:, None]
f_hat *= ((-1) ** np.arange(-np.floor(f.shape[1] / 2), np.ceil(f.shape[1] / 2)))[None, :]
print(f_hat)
# Use NFFT to evaluate the function defined by these Fourier coefficients at the non-equispaced output grid
self.nfft.f_hat = f_hat
f_resampled = self.nfft.trafo().reshape(self.nonequispaced_grid_shape).copy()
print(f_resampled)
return f_resampled
def backward(self, f):
"""
:param f: samples of the function at the non-equispaced grid points
:return: the function resampled on the equispaced Cartesian grid
"""
self.solver.y = f
self.solver.before_loop()
for i in range(40):
self.solver.loop_one_step()
f_hat = self.solver.f_hat_iter
# Since this equispaced FFT assumes spatial samples theta_k in [0, 1)
# [i.e. basis functions exp(i 2 pi n theta), not exp(i n theta)],
# we undo the shift by 0.5, i.e. divide by exp(-i pi n) = (-1)^n.
f_hat /= ((-1) ** np.arange(-np.floor(f_hat.shape[0] / 2), np.ceil(f_hat.shape[0] / 2)))[:, None]
f_hat /= ((-1) ** np.arange(-np.floor(f_hat.shape[1] / 2), np.ceil(f_hat.shape[1] / 2)))[None, :]
f = T2FFT.synthesize(f_hat)
return f
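# A small numerical check (added for illustration, not part of the original module) of the
# half-period shift identity used in forward() and backward() above: if g(x) = f(x + 1/2), then
# the Fourier coefficients satisfy \hat{g}_n = (-1)^n \hat{f}_n, since exp(i 2 pi n / 2) = (-1)^n.
def _check_half_period_shift(N=8):
    n = np.arange(-(N // 2), N // 2)                 # centered frequencies, in fftshift order
    x = np.linspace(0., 1., N, endpoint=False)
    f = np.exp(2j * np.pi * 3 * x) + 0.5 * np.exp(-2j * np.pi * 2 * x)
    g = np.exp(2j * np.pi * 3 * (x + 0.5)) + 0.5 * np.exp(-2j * np.pi * 2 * (x + 0.5))
    f_hat = np.fft.fftshift(np.fft.fft(f)) / N
    g_hat = np.fft.fftshift(np.fft.fft(g)) / N
    assert np.allclose(g_hat, ((-1.) ** n) * f_hat)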
def test2():
nr = 100
nt = 100
nx = 20
ny = 20
F = FourierInterpolator.init_cartesian_to_polar(nr=nr, nt=nt, nx=nx, ny=ny)
X, Y = np.meshgrid(np.linspace(-0.5, 0.5, nx, endpoint=False), np.linspace(-0.5, 0.5, ny, endpoint=False),
indexing='ij')
f = np.exp(2*np.pi*1j*(X+0.5)) # + np.exp(2*np.pi * 1j * (3*(Y+0.5)))
C = np.c_[X[..., None], Y[..., None]].reshape(-1, 2)
#F = FourierInterpolator(grid_in_shape=X.shape, grid_out=C)
print('aa')
fp = F.forward(f)
fr = F.backward(fp)
return F, f, fp, fr
def test(sx=0, sy=0):
nx = 33
ny = 37
nt = 16
nr = 16
f = np.zeros((nx, ny), dtype='complex')
F = nfft.NFFT(N=(nx, ny), M=nx * ny)
X, Y = np.meshgrid(np.linspace(-0.5, 0.5, nx, endpoint=False), np.linspace(-0.5, 0.5, ny, endpoint=False),
indexing='ij')
f = np.exp(2*np.pi*1j*(X+0.5))
F.x = np.c_[X[..., None], Y[..., None]].reshape(-1, 2)
F.precompute()
f_hat = T2FFT.analyze(f)
tf_hat = f_hat.copy()
tf_hat *= np.exp((2. * np.pi * 1j * sx * np.arange(-np.floor(f.shape[0] / 2.), np.ceil(f.shape[0] / 2.))[:, None]) / f.shape[0])
tf_hat *= np.exp((2. * np.pi * 1j * sy * np.arange(-np.floor(f.shape[1] / 2.), np.ceil(f.shape[1] / 2.))[None, :]) / f.shape[1])
F.f_hat = f_hat.conj()
f_reconst1 = F.trafo().copy().conj()
F.f_hat = tf_hat.conj()
f_reconst2 = F.trafo().copy().conj()
return F, f, f_hat, tf_hat, f_reconst1, f_reconst2
[build-system]
requires = [
'setuptools',
'setuptools-scm',
'cython',
'numpy ; python_version>="3.0"',
'numpy<1.17 ; python_version<"3.0"',
]
build-backend = "setuptools.build_meta"
[project]
name = "lie_learn"
version = "0.0.2"
description = "A python package that knows how to do various tricky computations related to Lie groups and manifolds (mainly the sphere S2 and rotation group SO3)."
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">2.7,!=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
classifiers = [
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
'requests',
'numpy ; python_version>="3.0"',
'scipy ; python_version>="3.0"',
'numpy<1.17 ; python_version<"3.0"',
'scipy<1.3 ; python_version<"3.0"',
# 'pynfft',  # This installation is complicated; install it yourself.
]
[project.urls]
'Source Code' = "https://github.com/AMLab-Amsterdam/lie_learn"
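# Usage sketch (added for illustration): with the build-system table above, the package can be built
# and installed from a source checkout with `pip install .`; pynfft, used by some of the spectral
# modules, is intentionally not listed as a dependency and must be installed separately.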
# pylint: disable=missing-docstring
import numpy as np
from Cython.Build import cythonize
from setuptools import setup, find_packages
setup(
ext_modules=cythonize('lie_learn/**/*.pyx', language_level=2),
include_dirs=[np.get_include()],
)
import numpy as np
import lie_learn.spaces.S3 as S3
from lie_learn.representations.SO3.wigner_d import wigner_D_function
def test_S3_quadint_equals_numint():
"""Test if SO(3) quadrature integration gives the same result as scipy numerical integration"""
b = 10
for l in range(2):
for m in range(-l, l + 1):
for n in range(-l, l + 1):
check_S3_quadint_equals_numint(l, m, n, b)
def check_S3_quadint_equals_numint(l=1, m=1, n=1, b=10):
# Create grids on the sphere
x = S3.meshgrid(b=b, grid_type='SOFT')
x = np.c_[x[0][..., None], x[1][..., None], x[2][..., None]]
# Compute quadrature weights
w = S3.quadrature_weights(b=b, grid_type='SOFT')
# Define a test function (the squared modulus of a Wigner-D function), to be evaluated at one point or at an array of points
def f1(alpha, beta, gamma):
df = wigner_D_function(l=l, m=m, n=n, alpha=alpha, beta=beta, gamma=gamma)
return df * df.conj()
def f1a(xs):
d = np.zeros(x.shape[:-1])
for i in range(d.shape[0]):
for j in range(d.shape[1]):
for k in range(d.shape[2]):
d[i, j, k] = f1(xs[i, j, k, 0], xs[i, j, k, 1], xs[i, j, k, 2])
return d
# Obtain the "true" value of the integral of the function over the sphere, using scipy's numerical integration
# routines
i1 = S3.integrate(f1, normalize=True)
# Compute the integral using the quadrature formulae
i1_w = S3.integrate_quad(f1a(x), grid_type='SOFT', normalize=True, w=w)
# Check error
print(b, l, m, n, 'results:', i1_w, i1, 'diff:', np.abs(i1_w - i1))
assert np.isclose(np.abs(i1_w - i1), 0.0)
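# Side note (added for illustration, not part of the original test): by the orthogonality of the
# Wigner-D functions with respect to the normalized Haar measure on SO(3),
#     int |D^l_{mn}(g)|^2 dg = 1 / (2l + 1),
# so for l = 1 both i1 and i1_w above should be close to 1/3.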
import lie_learn.spaces.S2 as S2
import numpy as np
def test_spherical_quadrature():
"""
Testing spherical quadrature rule versus numerical integration.
"""
b = 8 # 10
# Create grids on the sphere
x_gl = S2.meshgrid(b=b, grid_type='Gauss-Legendre')
x_cc = S2.meshgrid(b=b, grid_type='Clenshaw-Curtis')
x_soft = S2.meshgrid(b=b, grid_type='SOFT')
x_gl = np.c_[x_gl[0][..., None], x_gl[1][..., None]]
x_cc = np.c_[x_cc[0][..., None], x_cc[1][..., None]]
x_soft = np.c_[x_soft[0][..., None], x_soft[1][..., None]]
# Compute quadrature weights
w_gl = S2.quadrature_weights(b=b, grid_type='Gauss-Legendre')
w_cc = S2.quadrature_weights(b=b, grid_type='Clenshaw-Curtis')
w_soft = S2.quadrature_weights(b=b, grid_type='SOFT')
# Define a polynomial function, to be evaluated at one point or at an array of points
def f1a(xs):
xc = S2.change_coordinates(coords=xs, p_from='S', p_to='C')
return xc[..., 0] ** 2 * xc[..., 1] - 1.4 * xc[..., 2] * xc[..., 1] ** 3 + xc[..., 1] - xc[..., 2] ** 2 + 2.
def f1(theta, phi):
xs = np.array([theta, phi])
return f1a(xs)
# Obtain the "true" value of the integral of the function over the sphere, using scipy's numerical integration
# routines
i1 = S2.integrate(f1, normalize=False)
# Compute the integral using the quadrature formulae
# i1_gl_w = (w_gl * f1a(x_gl)).sum()
i1_gl_w = S2.integrate_quad(f1a(x_gl), grid_type='Gauss-Legendre', normalize=False, w=w_gl)
print(i1_gl_w, i1, 'diff:', np.abs(i1_gl_w - i1))
assert np.isclose(np.abs(i1_gl_w - i1), 0.0)
# i1_cc_w = (w_cc * f1a(x_cc)).sum()
i1_cc_w = S2.integrate_quad(f1a(x_cc), grid_type='Clenshaw-Curtis', normalize=False, w=w_cc)
print(i1_cc_w, i1, 'diff:', np.abs(i1_cc_w - i1))
assert np.isclose(np.abs(i1_cc_w - i1), 0.0)
i1_soft_w = (w_soft * f1a(x_soft)).sum()
print(i1_soft_w, i1, 'diff:', np.abs(i1_soft_w - i1))
print(i1_soft_w)
print(i1)
# assert np.isclose(np.abs(i1_soft_w - i1), 0.0)  # TODO
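# Side note (added for illustration, not part of the original test): assuming normalize=False
# integrates against the un-normalized surface measure (total area 4*pi), the odd monomials in
# f1a integrate to zero and int z^2 dOmega = 4*pi/3, so i1 should be 2*4*pi - 4*pi/3 = 20*pi/3 ≈ 20.9.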
import numpy as np
import lie_learn.spaces.S2 as S2
from lie_learn.representations.SO3.spherical_harmonics import sh
from lie_learn.spectral.S2FFT import setup_legendre_transform, setup_legendre_transform_indices, sphere_fft, S2_FT_Naive
def test_S2_FT_Naive():
L_max = 6
for grid_type in ('Gauss-Legendre', 'Clenshaw-Curtis'):
theta, phi = S2.meshgrid(b=L_max + 1, grid_type=grid_type)
for field in ('real', 'complex'):
for normalization in ('quantum', 'seismology'): # TODO Others should work but are not normalized
for condon_shortley in ('cs', 'nocs'):
fft = S2_FT_Naive(L_max, grid_type=grid_type,
field=field, normalization=normalization, condon_shortley=condon_shortley)
for l in range(L_max):
for m in range(-l, l + 1):
y_true = sh(
l, m, theta, phi,
field=field, normalization=normalization, condon_shortley=condon_shortley == 'cs')
y_hat = fft.analyze(y_true)
# The flat index for (l, m) is l^2 + l + m
# Before the harmonics of degree l, there are this many harmonics:
# sum_{i=0}^{l-1} 2i+1 = l^2
# There are 2l+1 harmonics of degree l, with order m=0 at the center,
# so the harmonic of order m sits at offset l + m within the block of degree l
# (see the lm_flat_index helper sketched after this function).
y_hat_true = np.zeros_like(y_hat)
y_hat_true[l**2 + l + m] = 1
y = fft.synthesize(y_hat_true)
diff = np.sum(np.abs(y_hat - y_hat_true))
print(grid_type, field, normalization, condon_shortley, l, m, diff)
assert np.isclose(diff, 0.)
diff = np.sum(np.abs(y - y_true))
print(grid_type, field, normalization, condon_shortley, l, m, diff)
assert np.isclose(diff, 0.)
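# A small helper (added for illustration, not part of the original tests), making the flat-index
# convention used above and in test_S2FFT explicit: harmonics are ordered by degree l, with order
# m = 0 at the center of each degree block, so (l, m) sits at index l^2 + l + m.
def lm_flat_index(l, m):
    # sum_{i=0}^{l-1} (2i + 1) = l^2 harmonics precede degree l; within the block, m runs from -l to l.
    return l ** 2 + l + m
assert lm_flat_index(0, 0) == 0
assert lm_flat_index(1, -1) == 1 and lm_flat_index(1, 1) == 3
assert lm_flat_index(2, 0) == 6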
def test_S2FFT():
L_max = 10
beta, alpha = S2.meshgrid(b=L_max + 1, grid_type='Driscoll-Healy')
lt = setup_legendre_transform(b=L_max + 1)
lti = setup_legendre_transform_indices(b=L_max + 1)
for l in range(L_max):
for m in range(-l, l + 1):
Y = sh(l, m, beta, alpha,
field='complex', normalization='seismology', condon_shortley=True)
y_hat = sphere_fft(Y, lt, lti)
# The flat index for (l, m) is l^2 + l + m
# Before the harmonics of degree l, there are this many harmonics: sum_{i=0}^{l-1} 2i+1 = l^2
# There are 2l+1 harmonics of degree l, with order m=0 at the center,
# so the harmonic of order m sits at offset l + m within the block of degree l.
y_hat_true = np.zeros_like(y_hat)
y_hat_true[l**2 + l + m] = 1
diff = np.sum(np.abs(y_hat - y_hat_true))
nz = 1. - np.isclose(y_hat, 0.)
diff_nz = np.sum(np.abs(nz - y_hat_true))
print(l, m, diff, diff_nz)
print(np.round(y_hat, 4))
print(y_hat_true)
# assert np.isclose(diff, 0.) # TODO make this work
print(nz)
assert np.isclose(diff_nz, 0.)
import lie_learn.spaces.S2 as S2
from lie_learn.spectral.S2FFT_NFFT import S2FFT_NFFT
from lie_learn.representations.SO3.spherical_harmonics import *
def test_S2FFT_NFFT():
"""
Testing S2FFT NFFT
"""
b = 8
convention = 'Gauss-Legendre'
#convention = 'Clenshaw-Curtis'
x = S2.meshgrid(b=b, grid_type=convention)
print(x[0].shape, x[1].shape)
x = np.c_[x[0][..., None], x[1][..., None]]#.reshape(-1, 2)
print(x.shape)
x = x.reshape(-1, 2)
w = S2.quadrature_weights(b=b, grid_type=convention).flatten()
F = S2FFT_NFFT(L_max=b, x=x, w=w)
for l in range(0, b):
for m in range(-l, l + 1):
#l = b; m = b
f = sh(l, m, x[..., 0], x[..., 1], field='real', normalization='quantum', condon_shortley=True)
#f2 = np.random.randn(*f.shape)
print(f)
f_hat = F.analyze(f)
print(np.round(f_hat, 3))
f_reconst = F.synthesize(f_hat)
#print np.round(f, 3)
print(np.round(f_reconst, 3))
#print np.round(f/f_reconst, 3)
print(np.abs(f-f_reconst).sum())
assert np.isclose(np.abs(f-f_reconst).sum(), 0.)
print(np.round(f_hat, 3))
assert np.isclose(f_hat[l ** 2 + l + m], 1.)
#assert False
import numpy as np
import lie_learn.spaces.S2 as S2
import lie_learn.spaces.S3 as S3
import lie_learn.groups.SO3 as SO3
from lie_learn.representations.SO3.spherical_harmonics import sh
from lie_learn.spectral.S2_conv import naive_S2_conv, spectral_S2_conv, naive_S2_conv_v2
def compare_naive_and_spectral_conv():
f1 = lambda t, p: sh(l=2, m=1, theta=t, phi=p, field='real', normalization='quantum', condon_shortley=True)
f2 = lambda t, p: sh(l=2, m=1, theta=t, phi=p, field='real', normalization='quantum', condon_shortley=True)
theta, phi = S2.meshgrid(b=4, grid_type='Gauss-Legendre')
f1_grid = f1(theta, phi)
f2_grid = f2(theta, phi)
alpha, beta, gamma = S3.meshgrid(b=4, grid_type='SOFT') # TODO check convention
f12_grid_spectral = spectral_S2_conv(f1_grid, f2_grid, s2_fft=None, so3_fft=None)
f12_grid = np.zeros_like(alpha)
for i in range(alpha.shape[0]):
for j in range(alpha.shape[1]):
for k in range(alpha.shape[2]):
f12_grid[i, j, k] = naive_S2_conv(f1, f2, alpha[i, j, k], beta[i, j, k], gamma[i, j, k])
print(i, j, k, f12_grid[i, j, k])
return f1_grid, f2_grid, f12_grid, f12_grid_spectral
def naive_conv(l1=1, m1=1, l2=1, m2=1, g_parameterization='EA313'):
f1 = lambda t, p: sh(l=l1, m=m1, theta=t, phi=p, field='real', normalization='quantum', condon_shortley=True)
f2 = lambda t, p: sh(l=l2, m=m2, theta=t, phi=p, field='real', normalization='quantum', condon_shortley=True)
theta, phi = S2.meshgrid(b=3, grid_type='Gauss-Legendre')
f1_grid = f1(theta, phi)
f2_grid = f2(theta, phi)
alpha, beta, gamma = S3.meshgrid(b=3, grid_type='SOFT') # TODO check convention
f12_grid = np.zeros_like(alpha)
for i in range(alpha.shape[0]):
for j in range(alpha.shape[1]):
for k in range(alpha.shape[2]):
f12_grid[i, j, k] = naive_S2_conv_v2(f1, f2, alpha[i, j, k], beta[i, j, k], gamma[i, j, k], g_parameterization)
print(i, j, k, f12_grid[i, j, k])
return f1_grid, f2_grid, f12_grid
import numpy as np
from lie_learn.spectral.SO3FFT_Naive import SO3_FFT_NaiveReal, SO3_FFT_SemiNaive_Complex, SO3_FT_Naive
from lie_learn.representations.SO3.pinchon_hoggan.pinchon_hoggan_dense import Jd, rot_mat
from lie_learn.representations.SO3.irrep_bases import change_of_basis_matrix
# TODO: test if the Fourier transform of a right SO(2)-invariant function is zero except for a column at n=0, and
# test if it is equal to the spherical harmonics transform of the corresponding function on the sphere
def test_SO3_FT_Naive():
"""
Check that the naive complex SO(3) FFT:
- Produces the right Wigner-D function when given a 1-hot input to the synthesis transform
- Produces a 1-hot vector when given a single Wigner-D function to the analysis transform
"""
L_max = 3
f_hat = [np.zeros((2 * ll + 1, 2 * ll + 1)) for ll in range(L_max + 1)]
# TODO: the SO3_FFT_SemiNaive_Complex no longer uses the D convention parameters because of new caching feature
field = 'complex'
order = 'centered'
for normalization in ('quantum', 'seismology'): # Note: the geodesy and nfft wigners are normalized differently
for condon_shortley in ('cs', 'nocs'):
fft = SO3_FT_Naive(L_max=L_max,
field=field, normalization=normalization,
order=order, condon_shortley=condon_shortley)
for l in range(L_max + 1):
for m in range(-l, l + 1):
for n in range(-l, l + 1):
f_hat[l][l + m, l + n] = 1. / (2 * l + 1)
f_hat_flat = np.hstack([fhl.flatten() for fhl in f_hat])
D = fft.synthesize_by_matmul(f_hat_flat)
D2 = make_D_sample_grid(b=L_max + 1, l=l, m=m, n=n,
field=field, normalization=normalization,
order=order, condon_shortley=condon_shortley)
diff = np.sum(np.abs(D - D2.flatten()))
print(l, m, n, 'Synthesize error:', diff)
assert np.isclose(diff, 0.0)
f_hat_2 = fft.analyze_by_matmul(D2)
# f_hat_flat = np.hstack([ff.flatten() for ff in f_hat])
f_hat_2_flat = np.hstack([ff.flatten() for ff in f_hat_2])
# f_hat_2_flat *= (2 * l + 1) # / (4 * np.pi) # apply magic constant TODO fix this
print(f_hat_2_flat)
print(f_hat_flat)
print(np.max(np.abs(f_hat_flat)), np.max(np.abs(f_hat_2_flat)))
diff = np.sum(np.abs(f_hat_flat - f_hat_2_flat))
print(l, m, n, 'Analyze error:', diff)
assert np.isclose(diff, 0.0)
f_hat[l][l + m, l + n] = 0.
def test_SO3_FFT_SemiNaiveComplex():
"""
Check that the naive complex SO(3) FFT:
- Produces the right Wigner-D function when given a 1-hot input to the synthesis transform
- Produces a 1-hot vector when given a single Wigner-D function to the analysis transform
"""
L_max = 3
f_hat = [np.zeros((2 * ll + 1, 2 * ll + 1)) for ll in range(L_max + 1)]
# TODO: the SO3_FFT_SemiNaive_Complex no longer uses the D convention parameters because of new caching feature
field = 'complex'
order = 'centered'
for normalization in ('quantum', 'seismology'): # Note: the geodesy and nfft wigners are normalized differently
for condon_shortley in ('cs', 'nocs'):
fft = SO3_FFT_SemiNaive_Complex(L_max=L_max, L2_normalized=False,
field=field, normalization=normalization,
order=order, condon_shortley=condon_shortley)
#fft = SO3_FFT_Naive(L_max=L_max,
# field=field, normalization=normalization,
# order=order, condon_shortley=condon_shortley)
for l in range(L_max + 1):
for m in range(-l, l + 1):
for n in range(-l, l + 1):
f_hat[l][l + m, l + n] = 1.
D = fft.synthesize(f_hat)
D2 = make_D_sample_grid(b=L_max + 1, l=l, m=m, n=n,
field=field, normalization=normalization,
order=order, condon_shortley=condon_shortley)
diff = np.sum(np.abs(D - D2))
print(l, m, n, diff)
assert np.isclose(diff, 0.0)
f_hat_2 = fft.analyze(D2)
f_hat_flat = np.hstack([ff.flatten() for ff in f_hat])
f_hat_2_flat = np.hstack([ff.flatten() for ff in f_hat_2])
f_hat_2_flat *= (2 * l + 1) / (4 * np.pi) # apply magic constant TODO fix this
diff = np.sum(np.abs(f_hat_flat - f_hat_2_flat))
print(l, m, n, diff)
assert np.isclose(diff, 0.0)
f_hat[l][l + m, l + n] = 0.
# TODO: test linearity of FFT
#TODO
def check_SO3_FFT_NaiveComplex_invertible():
L_max = 3
f_hat = [np.zeros((2 * ll + 1, 2 * ll + 1)) for ll in range(L_max + 1)]
fft = SO3_FFT_SemiNaive_Complex(L_max=L_max, L2_normalized=False)
for l in range(L_max + 1):
for m in range(-l, l + 1):
for n in range(-l, l + 1):
f_hat[l][l + m, l + n] = 1.
f = fft.synthesize(f_hat)
f_hat_2 = fft.analyze(f)
diff = np.sum([np.abs(f_hat[ll] - f_hat_2[ll]) for ll in range(L_max + 1)])
f_hat[l][l + m, l + n] = 0.
print(l, m, n, diff) # , D2 / D
assert np.isclose(diff, 0.0)
def test_SO3_FFT_NaiveReal():
"""
Testing if the real Naive SO(3) FFT synthesis works correctly for 1-hot input vectors
"""
L_max = 3
f_hat = [np.zeros((2 * ll + 1, 2 * ll + 1)) for ll in range(L_max + 1)]
fft = SO3_FFT_NaiveReal(L_max=L_max, L2_normalized=False)
for l in range(L_max + 1):
for m in range(-l, l + 1):
for n in range(-l, l + 1):
f_hat[l][l + m, l + n] = 1.
D = fft.synthesize(f_hat)
f_hat[l][l + m, l + n] = 0.
D2 = make_D_sample_grid(b=L_max + 1, l=l, m=m, n=n,
field='real', normalization='quantum', order='centered', condon_shortley='cs')
print(l, m, n, np.sum(np.abs(D - D2)))
assert np.isclose(np.sum(np.abs(D - D2)), 0.0)
def make_D_sample_grid(b=4, l=0, m=0, n=0,
field='complex', normalization='seismology', order='centered', condon_shortley='cs'):
from lie_learn.representations.SO3.wigner_d import wigner_D_function
# Evaluate the Wigner-D function D^l_mn at the Euler angles (alpha, beta, gamma)
D = lambda alpha, beta, gamma: wigner_D_function(l, m, n, alpha, beta, gamma,
field=field, normalization=normalization,
order=order, condon_shortley=condon_shortley)
f = np.zeros((2 * b, 2 * b, 2 * b), dtype='complex')
for j1 in range(f.shape[0]):
alpha = 2 * np.pi * j1 / (2. * b)
for k in range(f.shape[1]):
beta = np.pi * (2 * k + 1) / (4. * b)
for j2 in range(f.shape[2]):
gamma = 2 * np.pi * j2 / (2. * b)
f[j1, k, j2] = D(alpha, beta, gamma)
return f
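# A vectorized sketch (added for illustration, not part of the original tests) of the SOFT-style
# sampling grid that make_D_sample_grid loops over:
#     alpha_j1 = 2 pi j1 / (2b),   beta_k = pi (2k + 1) / (4b),   gamma_j2 = 2 pi j2 / (2b),
# for j1, k, j2 = 0, ..., 2b - 1.
def soft_like_grid(b):
    j = np.arange(2 * b)
    alpha = 2 * np.pi * j / (2. * b)
    beta = np.pi * (2 * j + 1) / (4. * b)
    gamma = 2 * np.pi * j / (2. * b)
    return np.meshgrid(alpha, beta, gamma, indexing='ij')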
import lie_learn.spaces.S2 as S2
from lie_learn.spaces.S3 import change_coordinates
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * This file contains a Convolutional LSTM module enhanced with an attention mechanism.
# * Prototype code from https://github.com/ndrplz/ConvLSTM_pytorch
# * Attention is added on top of the original code.
# * Returns the context image instead of the standard LSTM output (output, (hidden, cell)).
#=====================================================================================
import torch.nn as nn
import torch
import torchsnooper
from torch.nn.parameter import Parameter
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
"""
Initialize ConvLSTM cell.
Parameters
----------
input_dim: int
Number of channels of input tensor.
hidden_dim: int
Number of channels of hidden state.
kernel_size: (int, int)
Size of the convolutional kernel.
bias: bool
Whether or not to add the bias.
"""
super(ConvLSTMCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
self.bias = bias
print(self.input_dim + self.hidden_dim, 4 * self.hidden_dim)
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias)
def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state
combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
combined_conv = self.conv(combined)
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
c_next = f * c_cur + i * g
h_next = o * torch.tanh(c_next)
return h_next, c_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
class ConvLSTM(nn.Module):
"""
Parameters:
input_dim: Number of channels in input
hidden_dim: Number of hidden channels
kernel_size: Size of kernel in convolutions
num_layers: Number of LSTM layers stacked on each other
batch_first: Whether dimension 0 is the batch dimension
bias: Bias or no bias in Convolution
return_all_layers: Return the list of computations for all layers
Note: Will do same padding.
Input:
A tensor of shape (B, T, C, H, W) or (T, B, C, H, W), with B = batch size, T = time steps,
C = input channels, H = height, W = width
Output:
The attention-weighted context image of shape (B, hidden_dim[-1], H, W).
If return_hidden_and_context is True, the context is concatenated with the last hidden state
along the channel axis; if att=True, the (input_array, output_score) attention summary is returned.
Example:
>> x = torch.rand((32, 10, 64, 128, 128))
>> convlstm = ConvLSTM(input_dim=64, hidden_dim=16, kernel_size=(3, 3), num_layers=1,
>>                     time_channel=(10, 128), batch_first=True)
>> context = convlstm(x)  # context image of shape (32, 16, 128, 128)
"""
def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, time_channel,
batch_first=False, bias=True, return_all_layers=False, return_hidden_and_context = False, fill_value=0.1):
super(ConvLSTM, self).__init__()
self._check_kernel_size_consistency(kernel_size)
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError('Inconsistent list length.')
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
self.return_hidden_and_context = return_hidden_and_context
# Initialize the attention weight of the attention mechanism, filled with random values in (-fill_value, fill_value).
# The dimension of the attention weight is (T, C, H, W) (remember that H = W in our case).
self.attention_weight = Parameter(torch.empty(time_channel[0], hidden_dim[-1], time_channel[1],time_channel[1]).uniform_(-fill_value, fill_value))
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias))
self.cell_list = nn.ModuleList(cell_list)
self.input_array = 0
self.output_score = 0
#@torchsnooper.snoop()
def forward(self, input_tensor, hidden_state=None, att=False):
"""
Parameters
----------
input_tensor:
5-D Tensor of shape (t, b, c, h, w) or (b, t, c, h, w)
hidden_state:
None. todo: implement stateful operation
Returns
-------
The attention-weighted context image (or the attention scores if att=True)
"""
if not self.batch_first:
# (t, b, c, h, w) -> (b, t, c, h, w)
input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
b, _, _, h, w = input_tensor.size()
# Implement stateful ConvLSTM
if hidden_state is not None:
raise NotImplementedError()
else:
# Since the init is done in forward, we can pass the image size here
hidden_state = self._init_hidden(batch_size=b,
image_size=(h, w))
layer_output_list = []
last_state_list = []
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
for layer_idx in range(self.num_layers):
h, c = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
cur_state=[h, c])
output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append([h, c])
if not self.return_all_layers:
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
# Using the attention mechanism to produce context images
hs = layer_output_list[-1] # (B, T, C, H, W)
ht = last_state_list[-1][0] # (B, C, H, W)
w_attention = self.attention_weight.unsqueeze(0).expand(*hs.size()) # (T, C, H, W) -> (B, T, C, H, W)
ht_input = ht.unsqueeze(1).expand(*hs.size()) # (B, C, H, W) -> (B, T, C, H, W)
score_input = hs*w_attention*ht_input # (B, T, C, H, W)
score_input = score_input.permute(0,2,1,3,4) # (B, T, C, H, W) -> (B, C, T, H, W)
hs = hs.permute(0,2,1,3,4)
score = torch.softmax(score_input,dim=2) # softmax over the time axis; shape stays (B, C, T, H, W)
context = torch.sum(score * hs,dim=2) # (B, C, T, H, W) -> (B, C, H, W)
output = context
if self.return_hidden_and_context:
'''
Return context vector for training/validation
'''
output = torch.cat([context,ht],dim=1)
if att:
'''
Return the attention score for network interpretability study
'''
self.input_array, self.output_score = self.return_attention_score(input_tensor, score)
self.input_array = self.input_array.detach()
self.output_score = self.output_score.detach()
return self.input_array, self.output_score
return output
def _init_hidden(self, batch_size, image_size):
init_states = []
for i in range(self.num_layers):
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
return init_states
# @torchsnooper.snoop()
def return_attention_score(self,input_event,score):
'''
Read out the total attention score for given input events
'''
b, t, _, _, _ = input_event.size()
input_array = input_event.view(b,t,-1)
input_array = torch.sum(input_array,dim=-1)
dim0, dim1, dim2 = (score.size(0),score.size(1),score.size(2))
output_score = score.view(dim0,dim1, dim2,-1)
output_score = torch.sum(output_score,dim=-1)
return input_array, output_score
@staticmethod
def _check_kernel_size_consistency(kernel_size):
if not (isinstance(kernel_size, tuple) or
(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
raise ValueError('`kernel_size` must be tuple or list of tuples')
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
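# A minimal usage sketch (added for illustration, not part of the original module). The numbers
# here are hypothetical: a batch of 4 sequences with 10 time steps of single-channel 32x32 images,
# a 2-layer cell stack with 8 and 16 hidden channels, and time_channel = (T, H) = (10, 32) for the
# attention weight. The module returns the attention-weighted context image.
if __name__ == '__main__':
    x = torch.rand(4, 10, 1, 32, 32)                       # (B, T, C, H, W)
    model = ConvLSTM(input_dim=1, hidden_dim=[8, 16], kernel_size=(3, 3), num_layers=2,
                     time_channel=(10, 32), batch_first=True)
    context = model(x)
    print(context.shape)                                   # expected: torch.Size([4, 16, 32, 32])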
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * KamNet is a deep learning model developed for KamLAND-Zen and
# other spherical liquid scintillator detectors.
# * It attempts to harness all of the inherent symmetries to produce a
# state-of-the-art algorithm for a spherical liquid scintillator detector.
#=====================================================================================
# pylint: disable=E1101,R,C
import numpy as np
import os
import argparse
import time
import math
import random
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
from s2cnn import SO3Convolution
from s2cnn import S2Convolution
from s2cnn import so3_integrate
from s2cnn import so3_near_identity_grid, so3_equatorial_grid
from s2cnn import s2_near_identity_grid, s2_equatorial_grid
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import pickle
import numpy as np
import copy
from torch.autograd import Variable
from torchsummary import summary
from scipy import sparse
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torchsnooper
import pytorch_warmup as warmup
from matplotlib import cm
colormap_normal = cm.get_cmap("cool")
from torch.cuda.amp import autocast
from tqdm import tqdm
from AttentionConvLSTM import ConvLSTM
from KamNetDataset import DetectorDataset, DetectorDataset_Nhit, DetectorDataset_NonUniform, DetectorDatasetRep
from settings import SEED, NUM_EPOCHS, BATCH_SIZE, FILE_UPPERLIM, KAMNET_PARAMS, LEARNING_RATE, EV_SUFFIX, DSIZE
from tool import get_roc, get_rej, roc_nhit, cd
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if SEED:
'''
Set up reproducibility. If SEED=True, training the neural network with
the same configuration will produce exactly the same output.
'''
manualSeed = 7
np.random.seed(manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# if you are using GPU
torch.cuda.manual_seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def load_data(batch_size):
'''
Load datasets from the various pickle lists
'''
parser = argparse.ArgumentParser()
parser.add_argument("--signallist", type = str, default = "/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-Solar.pickle%s.dat"%(EV_SUFFIX))
parser.add_argument("--bglist", type = str, default = "/projectnb/snoplus/machine_learning/data/training_log/kamdata/Bi214_210320.pickle%s.dat"%(EV_SUFFIX))
parser.add_argument("--datalist", type = str, default = "/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-Solar.pickle%s.dat"%(EV_SUFFIX))
parser.add_argument("--datablist", type = str, default = "/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-I132.pickle%s.dat"%(EV_SUFFIX))
parser.add_argument("--signal", type = str, default = "Te130")
parser.add_argument("--bg", type = str, default = "C10")
parser.add_argument("--outdir", type = str, default = "/projectnb/snoplus/sphere_data/Xe136_C10_torch_new/")
parser.add_argument("--time_index", type = int, default = 8)
parser.add_argument("--qe_index", type = int, default = 10)
parser.add_argument("--elow", type = float, default = 2.0)
parser.add_argument("--ehi", type = float, default = 3.0)
args = parser.parse_args()
save_prefix = "/project/snoplus/ml2/network/"
# This is used when we perform pressume map study
time_index = args.time_index
qe_index = args.qe_index
json_name = str(time_index) + '_' + str(qe_index)
# This is used when we train KamNet with MC
json_name = "event"
# Read out each pickle list as a list of file paths
signal_images_list = [str(filename.strip()) for filename in list(open(args.signallist, 'r')) if filename != '']
bkg_image_list = [str(filename.strip()) for filename in list(open(args.bglist, 'r')) if filename != '']
data_list = [str(filename.strip()) for filename in list(open(args.datalist, 'r')) if filename != '']
datab_list = [str(filename.strip()) for filename in list(open(args.datablist, 'r')) if filename != '']
signal_images_list = signal_images_list[:FILE_UPPERLIM]
bkg_image_list = bkg_image_list[:FILE_UPPERLIM]
data_list = data_list[:FILE_UPPERLIM]
datab_list = datab_list[:FILE_UPPERLIM]
# Add different types of backgrounds to the bkg_image_dict, used for verifying KamNet on other background events
bkg_image_dict = {
"Sb118":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-Sb118.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"I122":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-I122.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"I124":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-I124.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"I130":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-I130.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"I132":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-I132.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"Bi214-MC":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-Bi214m.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
"Bi214-film":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/film-Bi214m.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
# "C10p":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-C10p.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
# "C10OP":[str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/XeLS-C10OP.pickle%s.dat"%(EV_SUFFIX), 'r'))][:FILE_UPPERLIM],
}
# Read out detector events to verify KamNet's performance
event_list = [str(filename.strip()) for filename in list(open("/projectnb/snoplus/machine_learning/data/training_log/kamdata/DB_untagged.pickle%s.dat"%(EV_SUFFIX), 'r')) if filename != '']
dataset = DetectorDataset_Nhit(data_list[:FILE_UPPERLIM], datab_list[:FILE_UPPERLIM], str(json_name),dsize=DSIZE,bootstrap=False)
validation_split = .3
shuffle_dataset = True
random_seed= 42222
division = 2
dataset_size = int(len(dataset)/division)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
# Shuffle the dataset
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_indices += list(division*dataset_size - 1-np.array(train_indices))
val_indices += list(division*dataset_size- 1-np.array(val_indices))
np.random.shuffle(train_indices)
np.random.shuffle(val_indices)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
rtq_dataset = DetectorDatasetRep(signal_images_list[:FILE_UPPERLIM], bkg_image_dict, str(json_name))
test_dataset = DetectorDataset_NonUniform(event_list[:FILE_UPPERLIM], bkg_image_list[:FILE_UPPERLIM], str(json_name))
# Convert dataset to data loader
train_loader = data_utils.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, drop_last=True)
eval_loader = data_utils.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler, drop_last=True)
test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, drop_last=True)
data_loader = data_utils.DataLoader(rtq_dataset, batch_size=batch_size, drop_last=False)
return train_loader,eval_loader, test_loader,data_loader, dataset.return_time_channel(), save_prefix, args.outdir
class KamNet(nn.Module):
def __init__(self, time_channel):
super(KamNet, self).__init__()
param_dict = KAMNET_PARAMS # Store the hyperparameters for KamNet
# Initialize the grid for spherical CNN
grid_dict = {'s2_eq': s2_equatorial_grid, 's2_ni': s2_near_identity_grid, "so3_eq":so3_equatorial_grid, 'so3_ni':so3_near_identity_grid}
s2_grid_type = param_dict["s2gridtype"]
so3_grid_type = param_dict["so3gridtype"]
grid_s2 = grid_dict[s2_grid_type]()
grid_so3 = grid_dict[so3_grid_type]()
self.ftype = param_dict["ftype"]
# Number of neurons in spherical CNN
s2_1 = param_dict["s2_1"]
so3_2 = param_dict["so3_2"]
so3_3 = param_dict["so3_3"]
so3_4 = param_dict["so3_4"]
# Number of neurons in fully connected NN
fc1 = int(param_dict["fc_max"])
fc2 = int(param_dict["fc_max"] * 0.8)
fc3 = int(param_dict["fc_max"] * 0.4)
fc4 = int(param_dict["fc_max"] * 0.2)
fc5 = int(param_dict["fc_max"] * 0.05)
do1r = param_dict["do"]
do2r = param_dict["do"]
do3r = param_dict["do"]
do4r = param_dict["do"]
do5r = param_dict["do"]
do1r = min(max(do1r,0.0),1.0)
do2r = min(max(do2r,0.0),1.0)
do3r = min(max(do3r,0.0),1.0)
do4r = min(max(do4r,0.0),1.0)
do5r = min(max(do5r,0.0),1.0)
# Number of neurons in AttentionConvLSTM
s1 = param_dict["s1"]
s2 = param_dict["s2"]
# Last output of spherical CNN
last_entry = so3_4
# Last output of fully connected NN
last_fc_entry = fc5
# The spherical CNN bandwidth
last_bw = int(param_dict["last_bw"])
bw = np.linspace(int(time_channel[1]/2), last_bw, 5).astype(int)
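# Illustration (added): with a hypothetical 38x38 hit map (time_channel[1] = 38) and last_bw = 6,
# bw = np.linspace(19, 6, 5).astype(int) = [19, 15, 12, 9, 6], so each spherical convolution below
# steps the bandwidth down toward last_bw.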
#. Spherical CNN part of KamNet
self.conv1 = S2Convolution(
nfeature_in=s2,
nfeature_out=s2_1,
b_in=bw[0],
b_out=bw[1],
grid=grid_s2)
self.conv2 = SO3Convolution(
nfeature_in=s2_1,
nfeature_out=so3_2,
b_in=bw[1],
b_out=bw[2],
grid=grid_so3)
self.conv3 = SO3Convolution(
nfeature_in=so3_2,
nfeature_out=so3_3,
b_in=bw[2],
b_out=bw[3],
grid=grid_so3)
self.conv4 = SO3Convolution(
nfeature_in=so3_3,
nfeature_out=so3_4,
b_in=bw[3],
b_out=bw[4],
grid=grid_so3)
#. AttentionConvLSTM part of KamNet
self.convlstm1=ConvLSTM(1, [s1,s2], [(param_dict["first_filter"],param_dict["first_filter"]),(param_dict["second_filter"],param_dict["second_filter"])],2, time_channel,batch_first=True,fill_value=0.1)
if self.ftype == "SO3I":
# This means integrating the last spherical CNN output using the Haar measure as provided in the paper
self.fc_layer = nn.Linear(so3_4, fc1)
else:
# This means flattening the last spherical CNN output into a 1D vector (batch_size,flattened_dimension)
self.fc_layer = nn.Linear(so3_4*(2*last_bw)**3, fc1)
# Fully connected part of KamNet
self.fc_layer_2 = nn.Linear(fc1, fc2)
self.fc_layer_3 = nn.Linear(fc2, fc3)
self.fc_layer_4 = nn.Linear(fc3, fc4)
self.fc_layer_5 = nn.Linear(fc4, fc5)
self.norm_layer_3d_1 = nn.BatchNorm3d(s2_1)
self.norm_layer_3d_2 = nn.BatchNorm3d(so3_2)
self.norm_layer_3d_3 = nn.BatchNorm3d(so3_3)
self.norm_layer_3d_4 = nn.BatchNorm3d(so3_4)
self.norm_1d_1 = nn.BatchNorm1d(fc1)
self.norm_1d_2 = nn.BatchNorm1d(fc2)
self.norm_1d_3 = nn.BatchNorm1d(fc3)
self.norm_1d_4 = nn.BatchNorm1d(fc4)
self.norm_1d_5 = nn.BatchNorm1d(fc5)
self.norm_1d_6 = nn.BatchNorm1d(1)
self.fc_layer_6 = nn.Linear(fc5, 1)
self.do1 = nn.Dropout(do1r)
self.do2 = nn.Dropout(do2r)
self.do3 = nn.Dropout(do3r)
self.do4 = nn.Dropout(do4r)
self.do5 = nn.Dropout(do5r)
self.sdo1 = nn.Dropout(param_dict["sdo"])
self.sdo2 = nn.Dropout(param_dict["sdo"])
self.sdo3 = nn.Dropout(param_dict["sdo"])
self.sdo4 = nn.Dropout(param_dict["sdo"])
def forward(self, x):
x = x.unsqueeze(2)
with autocast():
x = self.convlstm1(x)
x = self.conv1(x)
x = self.norm_layer_3d_1(x)
x = torch.relu(x)
x = self.sdo1(x)
x = self.conv2(x)
x = self.norm_layer_3d_2(x)
x = torch.relu(x)
x = self.sdo2(x)
x = self.conv3(x)
x = self.norm_layer_3d_3(x)
x = torch.relu(x)
x = self.sdo3(x)
x = self.conv4(x)
x = self.norm_layer_3d_4(x)
x = torch.relu(x)
x = self.sdo4(x)
if self.ftype == "SO3I":
x = so3_integrate(x)
else:
x = x.view(x.size(0),-1)
with autocast():
x = self.fc_layer(x)
x = self.norm_1d_1(x)
x = torch.relu(x)
x = self.do1(x)
x = self.fc_layer_2(x)
x = self.norm_1d_2(x)
x = torch.relu(x)
x = self.do2(x)
x = self.fc_layer_3(x)
x = self.norm_1d_3(x)
x = torch.relu(x)
x = self.do3(x)
x = self.fc_layer_4(x)
x = self.norm_1d_4(x)
x = torch.relu(x)
x = self.do4(x)
x = self.fc_layer_5(x)
x = self.norm_1d_5(x)
x = torch.relu(x)
x = self.do5(x)
x = self.fc_layer_6(x)
return x
def plot_result(test_loader, data_loader, classifier,suffix=EV_SUFFIX):
'''
Plot the training results of KamNet
'''
sigmoid_signal = []
sigmoid_bkg = []
energy_signal = []
energy_bkg = []
for images, labels, energies in test_loader:
classifier.eval()
with torch.no_grad():
energy_data = energies.cpu().data.numpy().flatten()
images = images.to(DEVICE)
labels = labels.view(-1,1)
labels = labels.to(DEVICE).float()
outputs = classifier(images).view(-1,1)
image_data = images.cpu().data.numpy().reshape(BATCH_SIZE,-1)
lb_data = labels.cpu().data.numpy().flatten()
outpt_data = outputs.cpu().data.numpy().flatten()
energy_data = energies.cpu().data.numpy().flatten()
signal = np.argwhere(lb_data == 1)
bkg = np.argwhere(lb_data == 0)
sigmoid_signal += list(outpt_data[signal].flatten())
sigmoid_bkg += list(outpt_data[bkg].flatten())
energy_signal += list(energy_data[signal].flatten())
energy_bkg += list(energy_data[bkg].flatten())
sigmoid_s = []
sigmoid_b_dict = {}
for images, labels in tqdm(data_loader):
classifier.eval()
with torch.no_grad():
images = images.to(DEVICE)
outputs = classifier(images)
outputs = outputs.view(-1,1)
lb_data = np.array(labels)
outpt_data = outputs.cpu().data.numpy().flatten()
signal = np.argwhere(lb_data == "Xe136")
sigmoid_s += list(outpt_data[signal].flatten())
bkg_name_list = np.unique(lb_data[lb_data != "Xe136"])
for bkg_name in bkg_name_list:
if bkg_name not in sigmoid_b_dict:
sigmoid_b_dict[bkg_name] = []
sigmoid_b_dict[bkg_name] += list(outpt_data[lb_data == bkg_name].flatten())
# Plot KamNet output spectra with various backgrounds
# Calculate the rejection % at the signal acceptance threshold defined by the thresh variable
thresh = 0.9
metric_list = []
# rg=np.linspace(0.0,1.0,100)
rg=np.linspace(min(sigmoid_s),max(sigmoid_s),100)
plt.hist(sigmoid_s, label ="Xe136", bins=rg, color="magenta", zorder=0,density=True,alpha=0.3)
for bkgname in sigmoid_b_dict.keys():
fpr,tpr,thr,auc = get_roc(sigmoid_s, sigmoid_b_dict[bkgname])
effindex = np.abs(tpr-thresh).argmin()
effpurity = 1.-fpr[effindex]
plt.hist(sigmoid_b_dict[bkgname], label = "%s(%.3f)"%(bkgname,effpurity), bins=rg, histtype="step",density=True, linewidth=1)
metric_list.append(auc)
# plt.ylim(0.0,8)
plt.ylabel("% of event/0.02 bins")
plt.xlabel('KamNet output')
plt.legend()
plt.savefig("llhist_%s.png"%(suffix),dpi=200)
plt.cla()
plt.clf()
plt.close()
# Plot the data energy spectrum and spectrum removed by KamNet
energy_signal = np.array(energy_signal)
sigmoid_signal = np.array(sigmoid_signal)
effindex = np.abs(tpr-0.9).argmin()
effthr = thr[effindex]
rg2 = np.arange(2.0,3.0,0.05)
plt.yscale("log")
plt.hist(energy_signal, bins=rg2,histtype = "step",label="All Data")
plt.hist(energy_signal[sigmoid_signal<=effthr], bins=rg2,histtype = "step",label="Cut Data")
plt.xlabel('Energy[MeV]')
plt.legend()
plt.savefig("cutevent_%s.png"%(suffix),dpi=200)
plt.cla()
plt.clf()
plt.close()
# Plot the KamNet spectrum of Bi214 MC and Bi214 data to check for data/MC agreement
# The factor 0.85/0.15 corresponds to the XeLS/film Bi214 fractions in the tagged Bi214 dataset
sigmoid_mcbi = np.concatenate([np.random.permutation(sigmoid_b_dict["Bi214-MC"])[:int(len(sigmoid_b_dict["Bi214-MC"])*0.85)],np.random.permutation(sigmoid_b_dict["Bi214-film"])[:int(len(sigmoid_b_dict["Bi214-film"])*0.15)]])
plt.hist(sigmoid_signal, label = r'DB_untagged', histtype='step',bins=rg, color=colormap_normal(0.9), density=True)
plt.hist(sigmoid_bkg, label = r'$^{214}Bi$ Data(%.3f)'%(get_rej(sigmoid_s, sigmoid_bkg)), histtype='step',bins=rg,color=colormap_normal(0.1),density=True)
plt.hist(sigmoid_mcbi, label = r'$^{214}Bi$ MC(%.3f)'%(get_rej(sigmoid_s, sigmoid_mcbi)), histtype='step',bins=rg,color=colormap_normal(0.4),density=True)
plt.xlabel('Sigmoid Output')
plt.ylabel('Counts')
plt.legend(loc='upper center')
plt.savefig('test_log_%s.png'%(suffix))
plt.cla()
plt.clf()
plt.close()
def main():
'''
Training KamNet
'''
train_loader,eval_loader, test_loader, data_loader, time_channel, save_prefix, outdir = load_data(BATCH_SIZE)
classifier = KamNet(time_channel)
#=====================================================================================
'''
This part allows loading a previously trained KamNet from a '.pt' model file
'''
# pretrained_dict = torch.load('pretrain_data.pt')
# model_dict = classifier.state_dict()
# model_dict.update(pretrained_dict)
# classifier.load_state_dict(pretrained_dict)
#=====================================================================================
classifier.to(DEVICE)
print("#params", sum(x.numel() for x in classifier.parameters()))
'''
Define the loss function
'''
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(DEVICE)
param_dict = KAMNET_PARAMS
#=====================================================================================
'''
Set up optimizer with varying learning rate:
Ramp Up : Gradually ramp up the learning rate over the first 5 epochs; this allows the attention mechanism to learn a proper attention score
Flat : Fix the learning rate at the nominal value
Ramp Down : Ramp the learning rate down to 10% of the nominal value between the 10th-to-last and 5th-to-last epochs
Flat : Fix the learning rate at 10% of the nominal value for the last 5 epochs
'''
step_length = len(train_loader)
total_step = int(NUM_EPOCHS * step_length)
ramp_up = np.linspace(1e-4, 1.0, 5*step_length)
ramp_down = list(np.linspace(1.0, 0.1, 5*step_length).flatten()) + [0.1]* 5*step_length
ramp_down_start = total_step - len(ramp_down)
lmbda = lambda epoch: ramp_up[epoch] if epoch<len(ramp_up) else ramp_down[epoch-ramp_down_start-1] if epoch > ramp_down_start else 1.0
optimizer = torch.optim.RMSprop(classifier.parameters(),lr=param_dict["lr"], momentum=param_dict["momentum"])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lmbda)
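# Illustration (added): with e.g. step_length = 100 and NUM_EPOCHS = 30 (total_step = 3000), lmbda
# ramps from 1e-4 to 1.0 over the first 500 steps, stays at 1.0 through step 2000, decays to 0.1
# over the next 500 steps, and holds at 0.1 for the remaining steps.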
#=====================================================================================
for epoch in tqdm(range(NUM_EPOCHS)):
print(scheduler.get_lr())
for i, (images, labels) in enumerate(train_loader):
classifier.train()
images = images.to(DEVICE)
labels = labels.view(-1,1)
labels = labels.to(DEVICE).float()
outputs = classifier(images)
loss = criterion(outputs,labels)
loss.backward() # compute gradients
optimizer.step() # update parameters of net
optimizer.zero_grad() # reset gradient
scheduler.step()
print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
epoch+1, NUM_EPOCHS, i+1, len(train_loader),
loss.item()), end="")
del images
torch.cuda.empty_cache()
plot_result(test_loader, data_loader, classifier)
torch.save(classifier.state_dict(), 'KamNet%s.pt'%(EV_SUFFIX)) # Save KamNet parameters in KamNet.pt file
main()
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * The PyTorch dataset classes for KamNet
#=====================================================================================
import numpy as np
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tool import label_data, create_table, create_table_zpos, get_roc, create_table_energy, look_table
from settings import FILE_UPPERLIM
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class DetectorDataset(Dataset):
def __init__(self, json_name):
"""
Base class for all KamNet datasets
"""
self.json_name = json_name
def __len__(self):
return self.size
def __getitem__(self, idx):
image = np.zeros(self.image_shape, dtype=np.float32)
for time_index, time in enumerate(self.trainX[idx]):
image[time_index] = time.todense()
return image, self.trainY[idx]
def return_time_channel(self):
'''
This method returns the number of time channels and the hit map dimension of the input.
E.g. if it returns (28, 38), the input has 28 time channels, where
each channel contains a 38*38 hit map.
'''
return (self.__getitem__(0)[0].shape[0], self.image_shape[1])
def cap_resample(self,input,cap=5000):
'''
This method randomly subsamples the dataset down to at most `cap` events
'''
if input.shape[0] < cap:
return input
signal_samples = np.random.choice(np.arange(input.shape[0]), cap, replace=False)
return input[signal_samples]
def get_sparse_nhit(self, sparse_dict):
'''
This method gets the Nhit of each event in the given event dict as an array.
It reads out Nhit directly if Nhit is stored in the dict;
otherwise it calculates Nhit from the sparse matrices.
'''
if "Nhit" in sparse_dict.keys():
return np.array(sparse_dict["Nhit"], dtype=int).flatten()
else:
sparsem = np.array(sparse_dict[self.json_name], dtype=object)
sparse_nhit = []
for i in tqdm(range(len(sparsem))):
sparse_nhit.append(np.sum([len(slice.nonzero()[0]) for slice in sparsem[i]]))
return np.array(sparse_nhit)
def match_nhit(self, signal_dict, background_dict, multiplier=1.0):
'''
Perform Nhit matching between the signal and background samples
'''
signal_images = np.array(signal_dict[self.json_name], dtype=object)
background_images = np.array(background_dict[self.json_name], dtype=object)
nhit_range = np.arange(0,2000,1)
signal_nhit = np.array(self.get_sparse_nhit(signal_dict))
bkg_nhit = np.array(self.get_sparse_nhit(background_dict))
signal_list = []
bkg_list = []
for (nlow, nhi) in tqdm(zip(list(nhit_range[:-1]), list(nhit_range[1:])),0):
signal_index = np.where((signal_nhit >= nlow) & (signal_nhit <nhi))[0]
bkg_index = np.where((bkg_nhit >= nlow) & (bkg_nhit <nhi))[0]
if (len(signal_index) != 0) and (len(bkg_index) != 0):
sampled_amount = min(len(signal_index), len(bkg_index))
signal_list += list(np.random.choice(list(signal_index), sampled_amount, replace=False))
bkg_list += list(np.random.choice(list(bkg_index), min(len(bkg_index), int(sampled_amount*multiplier)), replace=False))
rg = np.arange(0,1000,1)
plt.hist(signal_nhit[signal_list],label="Signal",histtype="step",bins=rg)
plt.hist(bkg_nhit[bkg_list],label="Bkg",histtype="step",bins=rg)
plt.legend()
plt.savefig("Nhit.png")
plt.cla()
plt.clf()
plt.close()
return signal_images[signal_list], background_images[bkg_list]
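# Illustration (added): match_nhit builds per-Nhit-bin index lists; e.g. if a bin contains 120
# signal and 80 background events, 80 events are drawn (without replacement) from each side, so the
# matched samples share the same Nhit histogram (up to the multiplier factor on the background side).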
def match_nhit_bootstrap(self, signal_dict, background_dict, multiplier=1.0):
'''
Perform Nhit matching between the signal and background samples, with bootstrapping allowed
Bootstrap: sample with replacement in each dataset
'''
signal_images = np.array(signal_dict[self.json_name], dtype=object)
background_images = np.array(background_dict[self.json_name], dtype=object)
nhit_range = np.arange(0,38**2,1)
signal_nhit = np.array(self.get_sparse_nhit(signal_dict))
bkg_nhit = np.array(self.get_sparse_nhit(background_dict))
signal_samples = np.random.choice(signal_images.shape[0], signal_images.shape[0], replace=True)##NO RESAMPLE
signal_images = signal_images[signal_samples]
signal_nhit = signal_nhit[signal_samples]
bkg_samples = np.random.choice(background_images.shape[0], background_images.shape[0], replace=True)
background_images = background_images[bkg_samples]
bkg_nhit = bkg_nhit[bkg_samples]
signal_list = []
bkg_list = []
for (nlow, nhi) in tqdm(zip(list(nhit_range[:-1]), list(nhit_range[1:])),0):
signal_index = np.where((signal_nhit >= nlow) & (signal_nhit <nhi))[0]
bkg_index = np.where((bkg_nhit >= nlow) & (bkg_nhit <nhi))[0]
if (len(signal_index) != 0) and (len(bkg_index) != 0):
sampled_amount = min(len(signal_index), len(bkg_index))
signal_list += list(np.random.choice(list(signal_index), sampled_amount, replace=False))
bkg_list += list(np.random.choice(list(bkg_index), min(len(bkg_index), int(sampled_amount*multiplier)), replace=False))
return signal_images[signal_list], background_images[bkg_list]
def label_data(self, signal_images, background_images):
signal_labels = np.ones(len(signal_images), dtype=np.float32)
background_labels = np.zeros(len(background_images), dtype=np.float32)
size = len(signal_images) + len(background_images)
trainX = np.concatenate((signal_images, background_images), axis=0)
trainY = np.concatenate((signal_labels, background_labels), axis=0)
image_shape = (trainX.shape[-1], *trainX[0,0].shape)
return trainX, trainY, image_shape, size
class DetectorDataset_Nhit(DetectorDataset):
def __init__(self, signal_images_list, bkg_image_list, json_name, bootstrap=False, dsize=-1, elow=2.0,ehi=3.0):
super(DetectorDataset_Nhit, self).__init__(json_name)
"""
KamNet dataset with Nhit matching. Nhit matching removes the Nhit dependence of signal/background events.
Used for training the neural network.
elow and ehi indicate the min/max energy of the events we'd like to read out.
"""
signal_dict = create_table_energy(signal_images_list, (json_name, 'Nhit','energy','zpos'), low=elow, high=ehi)
background_dict = create_table_energy(bkg_image_list, (json_name, 'Nhit','energy','zpos'), low=elow, high=ehi)
if bootstrap:
signal_images, background_images = self.match_nhit_bootstrap(signal_dict, background_dict)
else:
signal_images, background_images = self.match_nhit(signal_dict, background_dict)
if dsize != -1:
signal_images = self.cap_resample(signal_images,dsize)
background_images = self.cap_resample(background_images,dsize)
self.trainX, self.trainY, self.image_shape, self.size = self.label_data(signal_images, background_images)
class DetectorDataset_NonUniform(DetectorDataset):
def __init__(self, signal_images_list, bkg_image_list, json_name, elow=2.0,ehi=3.0):
super(DetectorDataset_NonUniform, self).__init__(json_name)
"""
KamNet dataset which do not require the signal/bkg dataset to follow the same size
"""
signal_dict = create_table_energy(signal_images_list, (json_name, 'Nhit','energy','zpos'), low=elow, high=ehi)
background_dict = create_table_energy(bkg_image_list, (json_name, 'Nhit','energy','zpos'), low=elow, high=ehi)
signal_images = np.array(signal_dict[json_name], dtype=object)
background_images = np.array(background_dict[json_name], dtype=object)
signal_images = self.cap_resample(signal_images)
background_images = self.cap_resample(background_images)
signal_labels = np.ones(len(signal_images), dtype=np.float32)
background_labels = np.zeros(len(background_images), dtype=np.float32)
print(signal_images.shape, background_images.shape)
self.trainX = np.concatenate((signal_images, background_images), axis=0)
print(self.trainX.shape)
self.size = self.trainX.shape[0]
self.trainY = np.concatenate((signal_labels, background_labels), axis=0)
self.image_shape = (self.trainX.shape[-1], *self.trainX[0,0].shape)
signal_ene = np.array(signal_dict["energy"]).flatten()
background_ene = np.array(background_dict["energy"]).flatten()
self.energy = np.concatenate((signal_ene, background_ene), axis=0)
    def __getitem__(self, idx):
        # Densify each sparse time frame into a single (time, height, width) image block.
        image = np.ndarray(self.image_shape, dtype=np.float32)
        for time_index, time in enumerate(self.trainX[idx]):
            image[time_index] = time.todense()
        return image, self.trainY[idx], self.energy[idx]
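# Hedged usage note: each item from DetectorDataset_NonUniform is a (dense image, label, energy)
# triple, so it can be wrapped directly in a DataLoader. The names "signal_files"/"bkg_files"
# and the "cnn_image" key below are placeholders, not names from this repository.
#
#   dataset = DetectorDataset_NonUniform(signal_files, bkg_files, "cnn_image")
#   image, label, energy = dataset[0]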
class DetectorDatasetRep(DetectorDataset):
def __init__(self, signal_images_list, bkg_image_dict, json_name, dsize = -1, elow=2.0,ehi=3.0):
super(DetectorDatasetRep, self).__init__(json_name)
"""
KamNet dataset outputing multiple isotopes for validation purpose
"""
self.trainX = []
self.trainY = []
signal_dict = create_table_energy(signal_images_list[:FILE_UPPERLIM], (json_name,"Nhit"), low=elow, high=ehi)
sigim = np.array(signal_dict[json_name], dtype=object)
sigim = self.cap_resample(sigim, 2000)
self.trainX.append(sigim)
self.trainY += ["Xe136"] * len(sigim)
for bkgn,bkglist in bkg_image_dict.items():
bkgev = create_table_energy(bkglist[:FILE_UPPERLIM], (json_name, 'id'), low=elow, high=ehi)
sigim = np.array(bkgev[json_name], dtype=object)
if len(sigim) == 0:
continue
sigim = self.cap_resample(sigim, 2000)
print(bkgn)
self.trainX.append(sigim)
self.trainY += [bkgn] * len(sigim)
self.trainX = np.concatenate(self.trainX,axis=0)
self.trainY = np.array(self.trainY)
self.image_shape = (self.trainX.shape[-1], *self.trainX[0,0].shape)
self.size = len(self.trainY)
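# Hedged usage sketch for the dataset classes above; the pickle-list names and the
# "cnn_image" key are placeholders for the real inputs.
#
#   signal_files = ["signal_0.pkl", "signal_1.pkl"]   # hypothetical paths
#   bkg_files = ["bkg_0.pkl", "bkg_1.pkl"]            # hypothetical paths
#   train_set = DetectorDataset_Nhit(signal_files, bkg_files, "cnn_image",
#                                    bootstrap=False, dsize=-1, elow=2.0, ehi=3.0)
#   print(train_set.size, train_set.image_shape)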
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * constants and file addresses for KamNet
#=====================================================================================
SEED = False # Reproducibility of KamNet
NUM_EPOCHS = 30 # Number of training epochs
BATCH_SIZE = 32 # Batch size
FILE_UPPERLIM = 5 # Number of files read from the pickle list; a small value allows faster training of the model.
# Setting FILE_UPPERLIM to a very large number (e.g. 999999) uses the entire dataset.
EV_SUFFIX = "_17and20good_nocharge" # Suffix of pickle list
DSIZE = 20000 # Number of signal/background events in training dataset. The final training size is DSIZE*2
KAMNET_PARAMS = {"momentum": 0.7806697572271865,# Hyperparameters of KamNet
"lr": 7.729560386535045e-05, "first_filter": 5,
"second_filter": 3,
"do": 0.08623589261579744,
"s2_1": 36, "so3_2": 63,
"so3_3": 74,
"so3_4": 124,
"fc_max": 1080,
"s1": 12, "s2": 22,
"sdo": 0.406085565931351,
"optimizer": "RMSprop",
"s2gridtype": "s2_ni",
"so3gridtype": "so3_eq",
"ftype": "SO3I",
"BATCH_SIZE": 32,
"last_bw": 2,
"do2": 0.0}
LEARNING_RATE = 0.000018675460538381732 # Learning rate
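# Hedged note: KAMNET_PARAMS mixes optimizer settings (lr, momentum, optimizer) with
# architecture settings (filter sizes, bandwidths, dropout). A typical consumer might do
# something like the illustrative line below; the model object itself is not defined here.
#
#   optimizer = torch.optim.RMSprop(model.parameters(), lr=KAMNET_PARAMS["lr"],
#                                   momentum=KAMNET_PARAMS["momentum"])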
import argparse
import os
import sys
import time
import json
import pickle
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import sparse
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
# from keras.callbacks import LearningRateScheduler  # needed by step_decay_schedule below
DIM1 = 50
DIM2 = 25
DIM3 = 34
def get_roc(sig,bkg):
testY = np.array([1]*len(sig) + [0]*len(bkg))
predY = np.array(sig+bkg)
print(testY,predY)
auc = roc_auc_score(testY, predY)
fpr, tpr, thr = roc_curve(testY, predY)
return fpr,tpr,thr,auc
def label_data(signal_images, background_images):
labels = np.array([1] * len(signal_images) + [0] * len(background_images))
data = np.concatenate((signal_images, background_images))
return data, labels
class cd:
'''
Context manager for changing the current working directory
'''
def __init__(self, newPath):
self.newPath = newPath
def __enter__(self):
self.savedPath = os.getcwd()
os.chdir(self.newPath)
def __exit__(self, etype, value, traceback):
os.chdir(self.savedPath)
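# Hedged usage note for the cd context manager above: it temporarily switches the working
# directory and restores the original one on exit (the path below is illustrative only).
#
#   with cd("/tmp"):
#       ...  # work inside /tmp, then return to the original directory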
def shrink_image(input_image):
    '''
    Remove frames that contain no hits (all-zero images) from the input image stack.
    '''
    shrink_list = []
    for index, image in enumerate(input_image):
        if np.count_nonzero(image.flatten()) == 0:
            shrink_list.append(index)
    output_image = np.delete(input_image, shrink_list, 0)
    return output_image
def plot_loss(history, save_prefix=''):
# Loss Curves
plt.figure(figsize=[8,6])
plt.plot(history.history['loss'],'r',linewidth=3.0)
plt.plot(history.history['val_loss'],'b',linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.title('Loss Curves',fontsize=16)
    plt.savefig(save_prefix + "loss.png")
def plot_accuracy(history, save_prefix=''):
# Accuracy Curves
plt.figure(figsize=[8,6])
plt.plot(history.history['acc'],'r',linewidth=3.0)
plt.plot(history.history['val_acc'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves',fontsize=16)
plt.savefig(save_prefix + "acc.png")
def plot_roc(my_network, testX, testY, save_prefix):
predY = my_network.predict_proba(testX)
print('\npredY.shape = ',predY.shape)
print(predY[0:10])
print(testY[0:10])
auc = roc_auc_score(testY, predY)
print('\nauc:', auc)
fpr, tpr, thr = roc_curve(testY, predY)
plt.plot(fpr, tpr, label = 'auc = ' + str(auc) )
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
print('False positive rate:',fpr[1], '\nTrue positive rate:',tpr[1])
plt.savefig(save_prefix + "roc.png")
def load_data(npy_filename):
    # Despite the parameter name, this reads a JSON file into a pandas DataFrame.
    startTime = datetime.now()
    with open(npy_filename) as json_data:
        data = pd.read_json(json_data)
    # print(datetime.now() - startTime)
    return data.values.tolist()
def create_table(file_list, load_strings, dense=False):
event_dict = {el:[] for el in load_strings}
for file in tqdm(file_list):
# print(file)
try:
with open(file, 'rb') as f:
while True:
try:
event = pickle.load(f, encoding='latin1')
for load in load_strings:
event_dict[load].append(event[load])
                    except:
                        # reached the end of this pickle file (or an unreadable event)
                        break
        except:
            # skip files that cannot be opened
            pass
return event_dict
def create_table_zpos(file_list, load_strings, upper=True):
event_dict = {el:[] for el in load_strings}
for file in tqdm(file_list):
try:
with open(file, 'rb') as f:
while True:
try:
event = pickle.load(f, encoding='latin1')
if (upper and event["zpos"] < 0) or ((not upper) and event["zpos"] >= 0):
continue
for load in load_strings:
event_dict[load].append(event[load])
                    except:
                        # reached the end of this pickle file (or an unreadable event)
                        break
        except:
            # skip files that cannot be opened
            pass
return event_dict
def create_table_energy(file_list, load_strings, low=2.0,high=3.0):
event_dict = {el:[] for el in load_strings}
for file in tqdm(file_list):
try:
with open(file, 'rb') as f:
while True:
try:
event = pickle.load(f, encoding='latin1')
if (event["energy"] > high) or (event["energy"] < low):
continue
for load in load_strings:
event_dict[load].append(event[load])
                    except EOFError:
                        # reached the end of this pickle file
                        break
        except:
            # skip files that cannot be opened
            pass
return event_dict
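# Hedged usage sketch for create_table_energy: "event_files" and the key names are
# placeholders for a real pickle list and per-event fields; the 2.0-3.0 MeV window
# matches the defaults used by the dataset classes.
#
#   events = create_table_energy(event_files, ("cnn_image", "Nhit", "energy"),
#                                low=2.0, high=3.0)
#   nhit = np.array(events["Nhit"])   # one entry per accepted event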
def look_table(file_list, load_strings):
    '''
    Debug helper: print the first event of the first file, then abort via assert 0.
    '''
    event_dict = {el:[] for el in load_strings}
for file in tqdm(file_list):
with open(file, 'rb') as f:
event = pickle.load(f, encoding='latin1')
print(event["event"])
assert 0
break
return event_dict
def step_decay_schedule(initial_lr=1e-3, decay_factor=0.75, step_size=10):
'''
Wrapper function to create a LearningRateScheduler with step decay schedule.
'''
def schedule(epoch):
return initial_lr * (decay_factor ** np.floor(epoch/step_size))
return LearningRateScheduler(schedule)
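# Worked example of the decay formula above (illustrative numbers): with
# initial_lr=1e-3, decay_factor=0.75 and step_size=10 the returned schedule gives
#   epochs  0-9  -> 1.0e-3
#   epochs 10-19 -> 7.5e-4
#   epochs 20-29 -> 5.625e-4
# Note that LearningRateScheduler comes from the (commented-out) keras import above.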
def get_rej(sig, bkg):
    '''
    Background rejection (1 - FPR) at the threshold closest to 90% signal efficiency.
    '''
    testY = np.array([1]*len(sig) + [0]*len(bkg))
predY = np.concatenate([np.array(sig).flatten(),np.array(bkg).flatten()],axis=0)
auc = roc_auc_score(testY, predY)
fpr, tpr, thr = roc_curve(testY, predY)
effindex = np.abs(tpr-0.9).argmin()
return 1 - fpr[effindex]
def roc_nhit(nhits, nhitb):
    '''
    Build an ROC curve directly from the Nhit distributions by scanning every integer
    Nhit cut; the means of the two distributions decide which side of the cut is "signal".
    '''
    nhit_tot = nhits + nhitb
    nhits = np.array(nhits)
    nhitb = np.array(nhitb)
fpr = []
tpr = []
for nhit_cut in range(min(nhit_tot), max(nhit_tot)):
if np.average(nhits) > np.average(nhitb):
tpr.append(len(nhits[nhits>nhit_cut])/len(nhits))
fpr.append(len(nhitb[nhitb>nhit_cut])/len(nhitb))
else:
tpr.append(len(nhits[nhits<nhit_cut])/len(nhits))
fpr.append(len(nhitb[nhitb<nhit_cut])/len(nhitb))
return fpr, tpr
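# Hedged usage sketch: "sig_nhit" and "bkg_nhit" are placeholder per-event Nhit lists;
# roc_nhit scans every integer cut between the overall minimum and maximum and returns
# the corresponding false/true positive rates.
#
#   fpr, tpr = roc_nhit(sig_nhit, bkg_nhit)
#   plt.plot(fpr, tpr, label="Nhit-only ROC")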