MIT License
Copyright (c) 2021 aobol
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# KamNet_pytorch
## Project Overview
KamNet is a state-of-the-art neural network model for spherical liquid scintillator detectors. It uses equivariant convolutions on the S² and SO(3) groups, making the model invariant to 3D rotations, and is well suited to research in geometric deep learning and modeling of non-Euclidean data.
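Rotation invariance here means the network's prediction does not change when a detector event is rotated. A toy numpy illustration (not the KamNet architecture) of the underlying idea, that translation/rotation-equivariant features followed by global pooling yield an invariant output:

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.random((8, 16))  # toy (theta, phi) grid of PMT hits

def equivariant_feature(x):
    # A local filter applied identically everywhere on the grid is
    # equivariant to grid shifts (the discrete analogue of rotation about z).
    return np.roll(x, 1, axis=1) + x

def invariant_readout(x):
    # Global pooling over the sphere removes the remaining dependence
    # on orientation, giving a rotation-invariant scalar.
    return equivariant_feature(x).sum()

rotated = np.roll(signal, 5, axis=1)  # rotate the event about the z axis
print(np.isclose(invariant_readout(signal), invariant_readout(rotated)))  # True
```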
---
## Environment Setup
### 1. Pull the image
```bash
docker pull harbor.sourcefind.cn:5443/dcu/admin/base/pytorch:2.5.1-ubuntu22.04-dtk25.04.2-py3.10
```
### 2. Create a container
```bash
docker run -it \
--network=host \
--hostname=localhost \
--name=kamnet \
-v /opt/hyhal:/opt/hyhal:ro \
-v $PWD:/workspace \
--ipc=host \
--device=/dev/kfd \
--device=/dev/mkfd \
--device=/dev/dri \
--shm-size=512G \
--privileged \
--group-add video \
--cap-add=SYS_PTRACE \
-u root \
--security-opt seccomp=unconfined \
harbor.sourcefind.cn:5443/dcu/admin/base/pytorch:2.5.1-ubuntu22.04-dtk25.04.2-py3.10 \
/bin/bash
```
---
## Test Procedure
### 1. Clone the code
```bash
git clone http://developer.sourcefind.cn/codes/bw-bestperf/kamnet_pytorch.git
cd kamnet_pytorch/
```
### 2. Install dependencies
```bash
pip install joblib cython rtree shapely pyembree "trimesh[easy]" -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
cd KamNet/lie_learn && pip install -e .  # may take a while
cd ../s2cnn && python setup.py install
```
### 3. Get the optimization package
```bash
curl -f -C - -o rocblas-install-1117-bug110157.tar.gz https://wuzh01.hpccube.com:65015/efile/s/d/bWFtaW5nMTAx/75c4336061c1be03
tar -xzf rocblas-install-1117-bug110157.tar.gz  # unpack before pointing LD_LIBRARY_PATH at it
# Set environment variables
CURRENT_DIR=$(pwd)
export LD_LIBRARY_PATH=$CURRENT_DIR/rocblas-install/lib:$LD_LIBRARY_PATH
# Verify the environment variable
echo $LD_LIBRARY_PATH | tr ':' '\n' | head -3
```
## Training Script (8 cards)
```bash
cd KamNet_pytorch/s2cnn/examples/mnist
# Generate the spherical MNIST dataset
python3 gendata.py
# Enter the directory with the missing file
cd /usr/local/lib/python3.10/dist-packages/lie_learn/representations/SO3/pinchon_hoggan/
# Manually download the missing J-matrix file from the upstream repository
wget https://raw.githubusercontent.com/AMLab-Amsterdam/lie_learn/master/lie_learn/representations/SO3/pinchon_hoggan/J_dense_0-150.npy
# NUMA binding: check your machine's NUMA affinity first
export HIP_VISIBLE_DEVICES=0
numactl --cpunodebind=3 python run_new.py --network original
```
---
## Contributing
Contributions to KamNet are welcome! Please follow these steps:
1. Fork this repository and create a new branch for your feature or fix.
2. Write clear, well-formed commit messages.
3. Open a Pull Request briefly describing what you changed and why.
4. Follow the project's coding and testing standards.
5. Participate in code review and discuss improvements actively.
---
## License
This project is released under the MIT License; see the [LICENSE](./LICENSE) file.
---
Thank you for your interest and support! For questions, please open an Issue or contact the maintainers.
'''
Author: Aobo Li
History:
June 10, 2022 - First Version
Purpose:
    This code defines the clock of liquid scintillator detector data.
    Simulation files with [pmt_position, hittime, hitcharge] info are
    stored as a spatiotemporal [t, theta, phi] grid; the clock
    controls which t index each hit is stored at.
'''
import argparse
import math
import os
import json
import pickle
from scipy import sparse
from scipy import constants as const
from sklearn.preprocessing import StandardScaler, Normalizer, MinMaxScaler
from random import *
import numpy as np
import time
from ROOT import TFile
from datetime import datetime
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
colormap_normal = cm.get_cmap("OrRd")
class clock:
    def __init__(self, initial_time):
        clock.initiated = True
        self.tick_interval = 1.5
        self.final_time = 22
        self.initiated = False
        self.clock_array = np.arange(-20, self.final_time, self.tick_interval)
        self.clock_array = self.clock_array + initial_time

    def tick(self, time):
        if time <= self.clock_array[0]:
            return 0
        return self.clock_array[self.clock_array <= time].argmax()

    def clock_size(self):
        return len(self.clock_array)

    def get_range_from_tick(self, tick):
        if tick == 0:
            return -9999, self.clock_array[0]
        return self.clock_array[tick - 1], self.clock_array[tick]
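The clock above bins each hit time into a fixed grid of 1.5 ns ticks spanning [-20, 22) ns around `initial_time`. A self-contained sketch of the same binning using `np.searchsorted` (equivalent logic, not the class itself, so it runs standalone):

```python
import numpy as np

# Rebuild the clock's tick grid: 1.5 ns ticks from -20 ns to +22 ns around initial_time.
initial_time = 0.0
ticks = np.arange(-20, 22, 1.5) + initial_time

def tick_index(t):
    """Index of the latest tick <= t, clamped to 0 for early hits."""
    if t <= ticks[0]:
        return 0
    # searchsorted(side='right') - 1 gives the last grid point <= t
    return int(np.searchsorted(ticks, t, side='right') - 1)

print(tick_index(-25.0))  # hit before the window -> 0
print(tick_index(0.1))
print(len(ticks))         # number of time slices in the 4D event tensor
```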
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Sep. 5, 2021
#
# * Batch processing script for KamNet
# * Converting the
#=====================================================================================
#!/usr/bin/python
import json
import time
import datetime
import sys
import argparse
import os
import re
import string
import signal
import subprocess
import shutil
import numpy as np
from settings import OUT_DIR, INPUT_DIR, MACRO_DIR, TIME, PROCESSOR, ISOTOPE
def main(argv):
    refresh = False
    if refresh:
        '''
        If refresh is true, this will delete all files in OUT_DIR, then start re-generating files.
        '''
        print("Delete all files in 5 seconds!")
        time.sleep(5)
    # Create the output directory if it does not exist.
    if not os.path.exists(OUT_DIR):
        os.mkdir(OUT_DIR)
    elif refresh:
        shutil.rmtree(OUT_DIR)
        os.mkdir(OUT_DIR)
    if not os.path.exists(MACRO_DIR):
        os.mkdir(MACRO_DIR)
    elif refresh:
        shutil.rmtree(MACRO_DIR)
        os.mkdir(MACRO_DIR)
    inputfiles = []
    # Read isotopes from the input directory and collect matching file names
    for SIG in ISOTOPE:
        inputfiles += [ifile for ifile in os.listdir(INPUT_DIR) if SIG in ifile and ".root" in ifile]
    for rootfile in inputfiles:
        # Load the template used to run on the Boston University SCC cluster batch queue.
        # To run on your own cluster batch system, modify this .sh template and its inputs.
        macrotemplate = string.Template(open('process_kamland.sh', 'r').read())
        with cd(MACRO_DIR):
            outputstring = str(OUT_DIR)
            timestring = str(TIME)
            inputstring = str(INPUT_DIR + rootfile)
            macrostring = macrotemplate.substitute(TIME=timestring, INPUT=inputstring, OUTPUT=outputstring, PROCESSOR=PROCESSOR, PROCESSING_UNIT=-1)
            macrofilename = 'shell_%s.sh' % str(rootfile)
            with open(macrofilename, 'w') as macro:
                macro.write(macrostring)
            try:
                command = ['qsub', macrofilename]
                subprocess.call(command)
            except Exception:
                return 0
    return 1

class cd:
    '''
    Context manager for changing the current working directory
    '''
    def __init__(self, newPath):
        self.newPath = newPath

    def __enter__(self):
        self.savedPath = os.getcwd()
        os.chdir(self.newPath)

    def __exit__(self, etype, value, traceback):
        os.chdir(self.savedPath)

if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * This code generates the .dat list of all .pickle files.
# * After running processing_kamland_new_mc.py or processing_sparse_time.py,
#   run this code to generate the pickle list. The pickle list is the input to
# KamNet.
#=====================================================================================
#!/usr/bin/python
import json
import time
import datetime
import sys
import argparse
import os
import re
import string
import signal
import subprocess
from settings import OUT_DIR, OUT_PICKLE_DIR, TAIL, ROWS, COLS
from tools import cd, append_file
def main():
    # Create the output directory if it does not exist.
    if not os.path.exists(OUT_PICKLE_DIR):
        os.mkdir(OUT_PICKLE_DIR)
    '''
    training_combo is a python dict mapping each signal isotope to the list of
    backgrounds used when generating the pickle list. Each entry takes the form:
        map[sig] = [bkg1, bkg2, bkg3, ...]
    Note that "sig" and every "bkg" string must be part of the .pickle filename.
    '''
    training_combo = {}
    training_combo['Solar'] = ['Bi214m']
    # Read all .pickle file names
    inputfiles = [ifile for ifile in os.listdir(OUT_DIR) if ".pickle" in ifile]
    inputfiles.sort()
    filename_array = {}
    # Categorize .pickle file names into the corresponding isotope types (sig or bkg)
    for npyfile in inputfiles:
        for sig, bkg in training_combo.items():  # dict.iteritems() is Python 2 only
            if sig in npyfile:
                filename_array = append_file(sig, str(OUT_DIR + npyfile), filename_array)
            else:
                for single_bkg in bkg:
                    if single_bkg in npyfile:
                        filename_array = append_file(single_bkg, str(OUT_DIR + npyfile), filename_array)
    # Generate the .dat pickle list
    for key in filename_array.keys():
        with cd(OUT_PICKLE_DIR):
            with open(str(key + TAIL + '.dat'), "w") as writefile:
                for filename in filename_array[key]:
                    if os.stat(filename).st_size == 0:
                        # Skip zero-size files
                        continue
                    writefile.write(filename + '\n')

if __name__ == "__main__":
    main()
#!/bin/bash -l
#$$ -P snoplus
#$$ -l h_rt=${TIME}
#$$ -j y
#$$ -V
source /projectnb/snoplus/Mo_work_place/RAT/rat-pac/env.sh
python ${PROCESSOR} --input ${INPUT} --outputdir ${OUTPUT} --process_index ${PROCESSING_UNIT}
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Aug. 29, 2021
#
# * This code converts an MC-simulated .root file into a 2D square grid and
# * saves each event and other variables as CSR sparse matrices in .pickle format.
# * Only applicable to the KLGSim simulation by the KamLAND-Zen group. To use this on your
#   own experiment, please modify this code to adapt to your own MC data structures.
#=====================================================================================
import argparse
import math
import os
import json
import pickle
from scipy import sparse
from scipy import constants as const
from sklearn.preprocessing import StandardScaler, Normalizer, MinMaxScaler
from random import *
import numpy as np
import time
from ROOT import TFile
from datetime import datetime
from tqdm import tqdm
import matplotlib.gridspec as gridspec
from clock import clock
from tools import *
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from settings import COLS, FV_CUT_LOW, FV_CUT_HI, good_hit, only_17inch, use_charge, ELOW, EHI, PLOT_HITMAP
colormap_normal = cm.get_cmap("cool")
FSIZE = 20
plt.rcParams['font.size'] = FSIZE
# Transcribe hits to a 4D tensor
def transcribe_hits(input, outputdir, PMT_POSITION, elow, ehi):
    current_clock = clock(0)
    f1 = TFile(input)
    tree = f1.Get("nt")  # Read the ROOT tree
    start_evt = 0
    end_evt = tree.GetEntries()
    tz = []
    if PLOT_HITMAP:
        end_evt = 10000
    input_name = os.path.basename(input).split('.')[0]
    event_map = []
    for evt_index in tqdm(range(start_evt, end_evt)):
        tree.GetEntry(evt_index)
        # FV/ROI cut
        try:
            energy = tree.EnergyA2  # These are the
            position = tree.r
            hit = tree.NhitID
            if (energy < ELOW) or (energy > EHI) or (position > FV_CUT_HI) or (position < FV_CUT_LOW):
                continue
        except:
            print("error")
            continue
        '''
        Read out the PMT hit list, time and charge
        '''
        if good_hit:
            good_pmt_list = np.array(tree.pmtlist_good)
            good_pmt_time_list = np.array(tree.pmtt_good)
            good_pmt_charge_list = np.array(tree.pmtq_good)
        else:
            # Get PMT information for an event
            good_pmt_list = np.array(tree.pmtlist)
            good_pmt_time_list = np.array(tree.pmtt)
            good_pmt_charge_list = np.array(tree.pmtq)
        '''
        Read out only 17-inch PMTs (that is, PMT index < 1325)
        '''
        event = np.zeros((current_clock.clock_size(), ROWS, COLS))
        if only_17inch:
            good_index = good_pmt_list < 1325
            good_pmt_list = good_pmt_list[good_index]
            good_pmt_time_list = good_pmt_time_list[good_index]
            good_pmt_charge_list = good_pmt_charge_list[good_index]
        vertex = np.array([tree.x / 100.0, tree.y / 100.0, tree.z / 100.0])
        # Calculate time of flight. In the KLZ simulation the TOF is already subtracted, so just set it to 0.
        good_pmt_tof = np.zeros(len(good_pmt_list))
        tzero = tree.T0
        total_charge = np.sum(good_pmt_charge_list)
        stacked_pmt_info = np.dstack((good_pmt_list, good_pmt_time_list, good_pmt_charge_list, good_pmt_tof))[0]
        for pmtinfo in stacked_pmt_info:
            if pmtinfo[-2] == 0.0:
                # Skip PMTs with 0 charge
                continue
            col, row = xyz_to_row_col(pmtinfo[0], PMT_POSITION)
            t_center = pmtinfo[1] - tzero
            tz.append(t_center)
            time = current_clock.tick(t_center)
            if use_charge:
                event[time][row][col] += pmtinfo[-2]
            else:
                event[time][row][col] += 1.0
        event_dic = {}
        event_dic['id'] = tree.EventNumber
        event_dic['run'] = tree.run
        event_dic['Nhit'] = np.count_nonzero(event)
        event_dic['energy'] = energy
        event_dic['vertex'] = tree.r
        event_dic['zpos'] = tree.z
        event_dic['event'] = event
        event_map.append(event_dic)
    if PLOT_HITMAP:
        '''
        Plot a few selected hit maps of the given dataset for demonstration purposes.
        '''
        plt.figure(figsize=(15, 15))
        spec = gridspec.GridSpec(ncols=4, nrows=2, height_ratios=[1, 2])
        plt.subplot(spec[1, :])
        idx_pool = [5, 11, 14, 18]
        plt.hist(tz, bins=np.arange(-20, 40, 1.5), density=True, color=colormap_normal(0.2))
        plt.axvline(x=-20, color="red", label="KamNet Window")
        plt.axvline(x=22, color="red")
        for idxc in idx_pool:
            begin, end = current_clock.get_range_from_tick(idxc)
            plt.axvspan(xmin=begin, xmax=end, color=colormap_normal(0.7), alpha=0.5)
        plt.ylim(0, 0.08)
        plt.legend(frameon=False)
        plt.xlabel("Proper Hit Time [ns]", fontsize=25, labelpad=20)
        plt.ylabel("Normalized Amplitude", fontsize=25, labelpad=20)
        # plt.savefig("th.png", dpi=600)
    with open(os.path.join(outputdir, "eventfile_%s_%.2f_%.2f.pickle" % (input_name, elow, ehi)), 'wb') as handle:
        numev = 0
        print(len(event_map))
        for eventd in event_map:
            evnt = eventd['event']
            eventd['nhit'] = np.count_nonzero(evnt)
            numev += 1
            time_sequence = []
            subplot_index = 0
            for idx, maps in enumerate(evnt):
                if PLOT_HITMAP and (idx in idx_pool):
                    ax = plt.subplot(spec[0, subplot_index])
                    begin, end = current_clock.get_range_from_tick(idx)
                    if begin == -9999:
                        plt.title("(Past, %.1f ns)" % end, fontsize=FSIZE)
                    else:
                        plt.title("(%s ns, %.1f ns)" % (begin, end), fontsize=FSIZE)
                    subplot_index += 1
                    ax.axes.get_xaxis().set_visible(False)
                    ax.axes.get_yaxis().set_visible(False)
                    ax.imshow(maps, cmap=colormap_normal, norm=matplotlib.colors.LogNorm(vmin=0.3, vmax=10.0))
                if subplot_index > 49:
                    break
                time_sequence.append(sparse.csr_matrix(maps))  # Save each time slice as a CSR sparse matrix
            if PLOT_HITMAP:
                plt.tight_layout()
                plt.savefig("event.png", dpi=600)
                plt.show()
                assert 1 == 0  # stop after plotting the first event
            eventd['event'] = time_sequence
            pickle.dump(eventd, handle, protocol=pickle.HIGHEST_PROTOCOL)  # Dump the event into the .pickle file
    print("Number of Events: ", numev)
    return 0

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="/projectnb2/snoplus/KLZ_NEW2/machine_learning/CDM/CDM_deltaM16.5_XeLS_8.root")
    parser.add_argument("--outputdir", default="/projectnb/snoplus/sphere_data/c10_2MeV")
    parser.add_argument("--pmt_file_index", default="/project/snoplus/ml2/data/pmt_xyz.dat")
    parser.add_argument("--pmt_file_size", default="/projectnb/snoplus/machine_learning/prototype/pmt.txt")
    parser.add_argument("--process_index", type=int, default=-1)
    parser.add_argument("--elow", type=float, default=2.0)
    parser.add_argument("--ehi", type=float, default=3.0)
    args = parser.parse_args()
    position = PMT_setup(args.pmt_file_index)
    fmc = transcribe_hits(input=args.input, outputdir=args.outputdir, PMT_POSITION=position, elow=args.elow, ehi=args.ehi)

if __name__ == "__main__":
    main()
#=====================================================================================
# Author: Aobo Li
# Contact: liaobo77@gmail.com
#
# Last Modified: Sep. 5, 2021
#
# * setting files for data batch processing scripts
#=====================================================================================
import os
TAIL = "_trial1" # The user-defined suffix of the pickle list, used to distinguish different input pickle lists
OUT_DIR = os.path.join("/projectnb2/snoplus/KLZ_NEWFINAL/machine_learning/", "KamLAND" + TAIL + "/") # Location to store the .pickle files
OUT_PICKLE_DIR = "/projectnb/snoplus/KLZ_NEWFINAL/machine_learning_plist/" # Location of the .dat pickle list
INPUT_DIR = "/projectnb/snoplus/KLZ_NEWFINAL/new_ml_data/data-root-KamNET/" # Location of the input .root files
MACRO_DIR = "/projectnb/snoplus/machine_learning/data/shell" # A place to store all the generated shell scripts
TIME = "0:5:00" # Processing time of each shell script
PROCESSOR = "/project/snoplus/ml2/preprocessing/processing_kamland_new.py" # The processor we'd like to use
ISOTOPE = ["Solar", "C10p"] # Name of the isotopes we'd like to process, this name must be part of the file name of the .root file
# Define the size of each hit maps
COLS = 38
ROWS = COLS
# Define the fiducial volume cutting threshold for event selection
FV_CUT_LOW = 0.0
FV_CUT_HI = 167.0
# Define the energy range for event selection
ELOW = 2.0
EHI = 3.0
good_hit = True # If true, use only good PMT hits; otherwise use all PMT hits
only_17inch = False # If true, use only 17-inch PMTs; otherwise use both 17- and 20-inch PMTs
use_charge = False # If true, register each PMT's charge in the hit map for each hit; otherwise register 1.0 per hit
PLOT_HITMAP = False # Plot flag. If true, plot hit maps of an input event. Note that when this flag is True, the processing script won't process any file.
import math
import numpy as np
import json
import re
from settings import *
from clock import clock  # needed by set_clock below
def PMT_setup(pmt_file_with_index):
    '''
    Reads in PMT positional information from the PMT location file.
    This file is internal to the KamLAND collaboration and thus cannot be provided here;
    it should follow the format:
        PMTID PMT_X[cm] PMT_Y[cm] PMT_Z[cm]
    Each line represents one PMT, with fields separated by whitespace.
    '''
    PMT_POSITION = {}
    for pmt in np.loadtxt(pmt_file_with_index):
        current_pmt_pos = pmt[1:] / 100.0
        PMT_POSITION[int(pmt[0])] = current_pmt_pos
    return PMT_POSITION

def xyz_to_phi_theta(x, y, z):
    phi = math.atan2(y, x)
    r = (x**2 + y**2 + z**2)**.5
    theta = math.acos(z / r)
    return phi, theta

# Change directory
class cd:
    '''
    Context manager for changing the current working directory
    '''
    def __init__(self, newPath):
        self.newPath = newPath

    def __enter__(self):
        self.savedPath = os.getcwd()
        os.chdir(self.newPath)

    def __exit__(self, etype, value, traceback):
        os.chdir(self.savedPath)

def tof(pmt_position, vertex_position):
    # ceff = 16.95 cm/ns, from the KatLTVertex.cc source code in Kat
    return np.linalg.norm(pmt_position - vertex_position, 2) * 100.0 / 16.95

# Convert (phi, theta) to row and column indices on the 2D grid
def phi_theta_to_row_col(phi, theta, rows=ROWS, cols=COLS):
    # phi is in [-pi, pi], theta is in [0, pi]
    row = min(rows / 2 + math.floor((rows / 2) * phi / math.pi), rows - 1)
    row = max(row, 0)
    col = min(math.floor(cols * theta / math.pi), cols - 1)
    col = max(col, 0)
    return int(row), int(col)

# Calculate the angle between two input vectors
def calculate_angle(vec1, vec2):
    x1, y1, z1 = vec1
    x2, y2, z2 = vec2
    inner_product = x1 * x2 + y1 * y2 + z1 * z2
    len1 = (x1**2 + y1**2 + z1**2)**0.5
    len2 = (x2**2 + y2**2 + z2**2)**0.5
    return math.acos(float(inner_product) / float(len1 * len2))

# Convert a Cartesian PMT position to 2D grid indices
def xyz_to_row_col(pmt_index, PMT_POSITION, rows=ROWS, cols=COLS):
    x, y, z = tuple(PMT_POSITION[pmt_index])
    return phi_theta_to_row_col(*xyz_to_phi_theta(x, y, z), rows=rows, cols=cols)

# Set up the clock to start ticking on the first incoming photon of a given event.
def set_clock(tree, evt):
    tree.GetEntry(evt)
    time_array = []
    for i in range(tree.N_phot):
        time_array.append(tree.PE_time[i])
    return clock(np.array(time_array).min())

# Save the input object as a .json file.
def savefile(saved_file, appendix, filename, pathname):
    if not os.path.exists(pathname):
        os.mkdir(pathname)
    with cd(pathname):
        with open(filename, 'w') as datafile:
            json.dump(saved_file, datafile)

def calculate_tzero(t, tof, charge):
    return np.sum((t - tof) * charge) / np.sum(charge)

def append_file(key, val, filename_array):
    if key not in filename_array.keys():
        filename_array[key] = [val]
    else:
        filename_array[key].append(val)
    return filename_array

def get_name(file):
    '''
    Retrieve the run number from a file name
    '''
    zdab_regex = re.compile(r"^eventfile_sph_out_(.+)_[a-z].*_1k_(\d+)\.\d+\.\d+.pickle$")
    matches = zdab_regex.match(file)
    if matches:
        return str(matches.group(1)), str(matches.group(2))
    else:
        return 0
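As a quick sanity check of the grid mapping above, a self-contained sketch (xyz_to_phi_theta and phi_theta_to_row_col are reproduced with the default 38×38 grid from settings.py, so the snippet runs without the package):

```python
import math

ROWS = COLS = 38  # default grid size from settings.py

def xyz_to_phi_theta(x, y, z):
    phi = math.atan2(y, x)            # azimuth in [-pi, pi]
    r = (x**2 + y**2 + z**2) ** 0.5
    theta = math.acos(z / r)          # polar angle in [0, pi]
    return phi, theta

def phi_theta_to_row_col(phi, theta, rows=ROWS, cols=COLS):
    row = min(rows / 2 + math.floor((rows / 2) * phi / math.pi), rows - 1)
    row = max(row, 0)
    col = min(math.floor(cols * theta / math.pi), cols - 1)
    col = max(col, 0)
    return int(row), int(col)

# The north pole (theta = 0) lands in column 0; a point on the equator
# (theta = pi/2) lands in the middle column.
print(phi_theta_to_row_col(*xyz_to_phi_theta(0.0, 0.0, 1.0)))  # (19, 0)
print(phi_theta_to_row_col(*xyz_to_phi_theta(1.0, 0.0, 0.0)))  # (19, 19)
```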
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-py3.10-dtk24.04.3-ubuntu20.04
name: "Build and Publish"
on:
  release:
    types: [published]
  workflow_dispatch:
jobs:
  build_wheels:
    name: Build wheels on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    environment: release
    permissions:
      id-token: write
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-13, macos-14]
    steps:
      - uses: actions/checkout@v4
      # Used to host cibuildwheel
      - uses: actions/setup-python@v5
      - name: Install cibuildwheel
        run: python -m pip install cibuildwheel
      - name: Build wheels
        run: python -m cibuildwheel --output-dir dist
        # to supply options, put them in 'env', like:
        env:
          CIBW_BUILD: "cp39-* cp310-* cp311-* cp312-*"
          CIBW_SKIP: "*_i686 *-musllinux_*"
      - uses: actions/upload-artifact@v4
        with:
          name: lie_learn-${{ matrix.os }}-${{ strategy.job-index }}
          path: ./dist/*.whl
      - name: Publish package
        if: matrix.os == 'ubuntu-latest'
        uses: pypa/gh-action-pypi-publish@release/v1
# Current project only:
#######################
# ignore cython-generated files
*.c
# Compiled source #
###################
*.com
*.class
*.dll
*.exe
*.o
*.so
*.pyc
# LaTeX #
#########
*.aux
*.glo
*.idx
*.log
*.toc
*.ist
*.acn
*.acr
*.alg
*.bbl
*.blg
*.dvi
*.glg
*.gls
*.ilg
*.ind
*.lof
*.lot
*.maf
*.mtc
*.mtc1
*.out
*.synctex.gz
# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip
# Logs and databases #
######################
*.log
*.sql
*.sqlite
# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
Icon?
ehthumbs.db
Thumbs.db
# Emacs #
#########
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
.elc
auto-save-list
tramp
.\#*
# Others #
##########
# Python-pickled data files
*.pkl
*.npy
*.imc
*.mat
*.idea
.eggs
*.egg-info
dist
build
*.html
Copyright (c) 2017, Taco Cohen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
global-exclude tests/*
global-exclude test_*.py
[Install docker](https://docs.docker.com/get-docker/).
Get the [manylinux docker environment](https://github.com/pypa/manylinux).
At the time of this writing, `manylinux1` is compatible; however, I used `manylinux2014`.
There is a tool which will label the manylinux binary with the oldest compatible standard.
```bash
docker pull quay.io/pypa/manylinux2014_x86_64
```
Run an interactive bash shell in the manylinux docker environment.
```bash
docker run -it quay.io/pypa/manylinux2014_x86_64 /bin/bash
```
Inside the interactive bash shell for the docker environment, download lie_learn and change to the source directory.
```bash
git clone https://github.com/AMLab-Amsterdam/lie_learn.git
cd lie_learn
```
Create wheels. You have to determine which versions of python are appropriate.
```bash
/opt/python/cp35-cp35m/bin/python setup.py bdist_wheel
/opt/python/cp36-cp36m/bin/python setup.py bdist_wheel
/opt/python/cp37-cp37m/bin/python setup.py bdist_wheel
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel
```
Use auditwheel to check for success and modify the binaries to be labeled with the oldest compatible standard (lowest
priority).
```bash
auditwheel repair ./dist/lie_learn-0.0.1.post1-cp35-cp35m-linux_x86_64.whl -w ./manylinux
auditwheel repair ./dist/lie_learn-0.0.1.post1-cp36-cp36m-linux_x86_64.whl -w ./manylinux
auditwheel repair ./dist/lie_learn-0.0.1.post1-cp37-cp37m-linux_x86_64.whl -w ./manylinux
auditwheel repair ./dist/lie_learn-0.0.1.post1-cp38-cp38-linux_x86_64.whl -w ./manylinux
```
Open a new terminal window (host environment) and get the running docker `CONTAINER ID`.
```bash
docker ps
```
yields
```
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
8e2b2c3baa8e quay.io/pypa/manylinux2014_x86_64 "/bin/bash" 30 minutes ago Up 30 minutes charming_shannon
```
In my case, the `CONTAINER ID` is `8e2b2c3baa8e`.
In the new terminal window, copy the manylinux wheels from the running container to a folder you'll remember.
```bash
mkdir ~/manylinux
docker cp 8e2b2c3baa8e:/lie_learn/manylinux/lie_learn-0.0.1.post1-cp35-cp35m-manylinux1_x86_64.whl ~/manylinux/
docker cp 8e2b2c3baa8e:/lie_learn/manylinux/lie_learn-0.0.1.post1-cp36-cp36m-manylinux1_x86_64.whl ~/manylinux/
docker cp 8e2b2c3baa8e:/lie_learn/manylinux/lie_learn-0.0.1.post1-cp37-cp37m-manylinux1_x86_64.whl ~/manylinux/
docker cp 8e2b2c3baa8e:/lie_learn/manylinux/lie_learn-0.0.1.post1-cp38-cp38-manylinux1_x86_64.whl ~/manylinux/
```
First do a test by uploading to test pypi.
```bash
twine upload --repository-url https://test.pypi.org/legacy/ ~/manylinux/*
```
Try downloading and testing `lie_learn` from there before proceeding.
This is easier said than done: you will need to install all of the dependencies manually, then install from test
PyPI without any dependencies using
`pip install --no-cache-dir --index-url https://test.pypi.org/simple/ --no-deps lie_learn`.
Once you know it's working, upload the wheels to pypi with twine.
```bash
twine upload ~/manylinux/*
```
For a bit more info, another useful resource is https://opensource.com/article/19/2/manylinux-python-wheels.
lie_learn is a python package that knows how to do various tricky computations related to Lie groups and manifolds (mainly the sphere S2 and rotation group SO3). This package was written to support various machine learning projects, such as Harmonic Exponential Families [2], (continuous) Group Equivariant Networks [3], Steerable CNNs [4] and Spherical CNNs [5].
# What this code can do
- Reparameterize rotations, e.g. matrix to Euler angles to quaternions, etc. (see groups & spaces modules)
- Compute the Wigner-d and Wigner-D matrices (the irreducible representations of SO(3)), and spherical harmonics, using the method developed by Pinchon & Hoggan [1] (see pinchon_hoggan_dense.py). This is a very fast and stable method, but requires a fairly large "J matrix", which we have precomputed up to order 278 using a Maple script. The code will automatically download it from Google Drive during installation.
Note: There are many normalization and phase conventions for both the real and complex versions of the D-matrices and spherical harmonics, and the code can convert between a lot of them (irrep_bases.pyx).
- Compute generalized / non-commutative FFTs for the sphere S2, rotation group SO3, and special Euclidean group SE2 (see spectral module).
- Fit Harmonic Exponential Families on the sphere (probability module; this code may no longer work)
# Installation
lie_learn can be installed from pypi using:
```
pip install lie_learn
```
Although cython is not a necessary dependency, if you have cython installed, it will regenerate the `*.c` files before compiling them into `*.so` during installation. To use lie_learn, you will need a C compiler that is available to python setuptools.
# Feedback
For questions and comments, feel free to contact Taco Cohen (http://ta.co.nl).
# References
[1] Pinchon, D., & Hoggan, P. E. (2007). Rotation matrices for real spherical harmonics: general rotations of atomic orbitals in space-fixed axes. Journal of Physics A: Mathematical and Theoretical, 40(7), 1597–1610.
[2] Cohen, T. S., & Welling, M. (2015). Harmonic Exponential Families on Manifolds. In Proceedings of the 32nd International Conference on Machine Learning (ICML) (pp. 1757–1765).
[3] Cohen, T. S., & Welling, M. (2016). Group equivariant convolutional networks. In Proceedings of The 33rd International Conference on Machine Learning (ICML) (Vol. 48, pp. 2990–2999).
[4] Cohen, T. S., & Welling, M. (2017). Steerable CNNs. In ICLR.
[5] T.S. Cohen, M. Geiger, J. Koehler, M. Welling (2017). Convolutional Networks for Spherical Signals. In ICML Workshop on Principled Approaches to Deep Learning.
import numpy as np
from numpy.lib.index_tricks import as_strided

def generalized_broadcast(arrays):
    """
    Broadcast X and Y, while ignoring the last axis of X and Y.
    If X.shape = xs + (i,)
    and Y.shape = ys + (j,)
    then the output arrays have shapes
    Xb.shape = zs + (i,)
    Yb.shape = zs + (j,)
    where zs is the shape of the broadcast of xs- and ys-shaped arrays.
    :param arrays: a list of numpy arrays to be broadcast while ignoring the last axis.
    :return: a list of arrays whose shapes have been broadcast
    """
    arrays1 = np.broadcast_arrays(*[A[..., 0] for A in arrays])
    shapes_b = [A1.shape + (A.shape[-1],) for A1, A in zip(arrays1, arrays)]
    strides_b = [A1.strides + (A.strides[-1],) for A1, A in zip(arrays1, arrays)]
    arrays_b = [as_strided(A, shape=shape_Ab, strides=strides_Ab)
                for A, shape_Ab, strides_Ab in zip(arrays, shapes_b, strides_b)]
    return arrays_b
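To make the shape contract concrete, a small self-contained check (the function is reproduced in condensed form so the snippet runs on its own):

```python
import numpy as np
from numpy.lib.index_tricks import as_strided

def generalized_broadcast(arrays):
    # Broadcast all arrays against each other while ignoring each one's last axis.
    arrays1 = np.broadcast_arrays(*[A[..., 0] for A in arrays])
    shapes_b = [A1.shape + (A.shape[-1],) for A1, A in zip(arrays1, arrays)]
    strides_b = [A1.strides + (A.strides[-1],) for A1, A in zip(arrays1, arrays)]
    return [as_strided(A, shape=s, strides=st)
            for A, s, st in zip(arrays, shapes_b, strides_b)]

X = np.arange(10).reshape(2, 1, 5)   # xs = (2, 1), i = 5
Y = np.arange(21).reshape(1, 3, 7)   # ys = (1, 3), j = 7
Xb, Yb = generalized_broadcast([X, Y])
print(Xb.shape)  # (2, 3, 5): zs = broadcast((2, 1), (1, 3)) = (2, 3)
print(Yb.shape)  # (2, 3, 7)
```

The broadcast views share memory with the originals (zero strides on the broadcast axes), so no data is copied.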
def make_gufunc(f, core_dims_in, core_dims_out):
    """
    Automatically turn a function f into a generalized universal function (gufunc).
    :param f: function mapping a list of 2D arrays to a 2D output array
    :param core_dims_in: tuple of core-dimension tuples, one per input
    :param core_dims_out: tuple of core dimensions of the output
    :return: the wrapped gufunc
    """
    def gufunc(args):
        args = generalized_broadcast(args)
        # Everything except each input's core dimensions is the broadcast "data" shape
        data_shape = args[0].shape[:-len(core_dims_in[0])]
        args = [A.reshape(-1, A.shape[-1]) for A in args]
        #if X_out is None:
        #    X_out = np.empty_like(X)
        #X_out = X_out.reshape(-1, X.shape[-1])
        out = f(args)
        # Restore the data shape in front of the output core dimensions.
        # (The original left reshape() unspecified; this is the assumed intent.)
        return out.reshape(data_shape + tuple(core_dims_out))
    return gufunc