Commit 524a1b6e by mashun: particle
import torch
import torch.nn as nn


class ResNetUnit(nn.Module):
    r"""Parameters
    ----------
    in_channels : int
        Number of channels in the input vectors.
    out_channels : int
        Number of channels in the output vectors.
    strides : tuple
        Strides of the two convolutional layers, in the form of (stride0, stride1).
    """

    def __init__(self, in_channels, out_channels, strides=(1, 1), **kwargs):
        super(ResNetUnit, self).__init__(**kwargs)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=strides[0], padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=strides[1], padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.dim_match = True
        if in_channels != out_channels or strides != (1, 1):  # dimensions do not match
            self.dim_match = False
            self.conv_sc = nn.Conv1d(in_channels, out_channels, kernel_size=1,
                                     stride=strides[0] * strides[1], bias=False)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        # print('resnet unit', identity.shape, x.shape, self.dim_match)
        if self.dim_match:
            return identity + x
        else:
            return self.conv_sc(identity) + x
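
# Illustrative usage (an added sketch, not part of the original file): when the
# channel count or stride changes, the 1x1 `conv_sc` projection is applied to the
# identity branch so both terms of the residual sum have equal shapes.
#
#   unit = ResNetUnit(64, 128, strides=(2, 1))
#   x = torch.randn(4, 64, 100)   # (N, C, P)
#   unit(x).shape                 # torch.Size([4, 128, 50]); dim_match is False here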

class ResNet(nn.Module):
    r"""Parameters
    ----------
    features_dims : int
        Input feature dimensions.
    num_classes : int
        Number of output classes.
    conv_params : list
        List of the convolution layer parameters.
        The first element is a tuple of size 1, defining the transformed feature size for the initial feature convolution layer.
        The following are tuples of feature sizes for the successive stages of ResNet units; each number defines an individual ResNet unit.
    fc_params : list
        List of fully connected layer parameters after all ResNet stages, each element in the format of
        (n_feat, drop_rate).
    """

    def __init__(self, features_dims, num_classes,
                 conv_params=[(32,), (64, 64), (64, 64), (128, 128)],
                 fc_params=[(512, 0.2)],
                 for_inference=False,
                 **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.conv_params = conv_params
        self.num_stages = len(conv_params) - 1
        self.fts_conv = nn.Sequential(
            nn.BatchNorm1d(features_dims),
            nn.Conv1d(
                in_channels=features_dims, out_channels=conv_params[0][0],
                kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(conv_params[0][0]),
            nn.ReLU())

        # define the ResNet units for each stage; each stage is a sequence of ResNetUnit blocks
        self.resnet_units = nn.ModuleDict()
        for i in range(self.num_stages):
            # stack len(conv_params[i + 1]) units in this stage
            unit_layers = []
            for j in range(len(conv_params[i + 1])):
                in_channels, out_channels = (conv_params[i][-1], conv_params[i + 1][0]) if j == 0 \
                    else (conv_params[i + 1][j - 1], conv_params[i + 1][j])
                strides = (2, 1) if (j == 0 and i > 0) else (1, 1)
                unit_layers.append(ResNetUnit(in_channels, out_channels, strides))
            self.resnet_units.add_module('resnet_unit_%d' % i, nn.Sequential(*unit_layers))

        # define the fully connected layers
        fcs = []
        for idx, layer_param in enumerate(fc_params):
            channels, drop_rate = layer_param
            in_chn = conv_params[-1][-1] if idx == 0 else fc_params[idx - 1][0]
            fcs.append(nn.Sequential(nn.Linear(in_chn, channels), nn.ReLU(), nn.Dropout(drop_rate)))
        fcs.append(nn.Linear(fc_params[-1][0], num_classes))
        if for_inference:
            fcs.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*fcs)

    def forward(self, points, features, lorentz_vectors, mask):
        # features: the particle feature vectors, (N, C, P)
        if mask is not None:
            features = features * mask
        x = self.fts_conv(features)
        for i in range(self.num_stages):
            x = self.resnet_units['resnet_unit_%d' % i](x)  # (N, C', P'), P' < P when stride > 1
        # global average pooling
        x = x.mean(dim=-1)  # (N, C')
        # fully connected
        x = self.fc(x)  # (N, out_chn)
        return x

def get_model(data_config, **kwargs):
    conv_params = [(32,), (64, 64), (64, 64), (128, 128)]
    fc_params = [(512, 0.2)]
    pf_features_dims = len(data_config.input_dicts['pf_features'])
    num_classes = len(data_config.label_value)
    model = ResNet(pf_features_dims, num_classes,
                   conv_params=conv_params,
                   fc_params=fc_params)

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
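
# The `model_info` dict returned by `get_model` mirrors the keyword arguments of
# `torch.onnx.export`. A minimal export sketch (an assumption about downstream
# use, not code from this commit):
#
#   model, model_info = get_model(data_config, for_inference=True)
#   dummy_inputs = tuple(torch.zeros(s) for s in model_info['input_shapes'].values())
#   torch.onnx.export(model, dummy_inputs, 'model.onnx',
#                     input_names=model_info['input_names'],
#                     output_names=model_info['output_names'],
#                     dynamic_axes=model_info['dynamic_axes'])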
import torch
import torch.nn as nn


class ParticleFlowNetwork(nn.Module):
    r"""Parameters
    ----------
    input_dims : int
        Input feature dimensions.
    num_classes : int
        Number of output classes.
    Phi_sizes : tuple
        Feature sizes of the per-particle Phi layers.
    F_sizes : tuple
        Feature sizes of the global F layers.
    """

    def __init__(self, input_dims, num_classes,
                 Phi_sizes=(100, 100, 128),
                 F_sizes=(100, 100, 100),
                 use_bn=True,
                 for_inference=False,
                 **kwargs):
        super(ParticleFlowNetwork, self).__init__(**kwargs)
        # input batch normalization
        self.input_bn = nn.BatchNorm1d(input_dims) if use_bn else nn.Identity()
        # per-particle functions (Phi), implemented as 1x1 convolutions
        phi_layers = []
        for i in range(len(Phi_sizes)):
            phi_layers.append(nn.Sequential(
                nn.Conv1d(input_dims if i == 0 else Phi_sizes[i - 1], Phi_sizes[i], kernel_size=1),
                nn.BatchNorm1d(Phi_sizes[i]) if use_bn else nn.Identity(),
                nn.ReLU())
            )
        self.phi = nn.Sequential(*phi_layers)
        # global functions (F)
        f_layers = []
        for i in range(len(F_sizes)):
            f_layers.append(nn.Sequential(
                nn.Linear(Phi_sizes[-1] if i == 0 else F_sizes[i - 1], F_sizes[i]),
                nn.ReLU())
            )
        f_layers.append(nn.Linear(F_sizes[-1], num_classes))
        if for_inference:
            f_layers.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*f_layers)

    def forward(self, points, features, lorentz_vectors, mask):
        # features: the feature vector initially read from the data structure, in dimension (N, C, P)
        x = self.input_bn(features)
        x = self.phi(x)
        if mask is not None:
            x = x * mask.bool().float()
        x = x.sum(-1)
        return self.fc(x)

def get_model(data_config, **kwargs):
    Phi_sizes = (128, 128, 128)
    F_sizes = (128, 128, 128)
    input_dims = len(data_config.input_dicts['pf_features'])
    num_classes = len(data_config.label_value)
    model = ParticleFlowNetwork(input_dims, num_classes, Phi_sizes=Phi_sizes,
                                F_sizes=F_sizes, use_bn=kwargs.get('use_bn', False))

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
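
# Because Phi acts per particle and the results are summed before F, the output
# is invariant under particle permutations (the Deep Sets / Energy Flow Network
# construction). A quick check (illustrative sketch, not part of this file):
#
#   net = ParticleFlowNetwork(input_dims=4, num_classes=10).eval()
#   x = torch.randn(2, 4, 30)                      # (N, C, P)
#   perm = torch.randperm(30)
#   assert torch.allclose(net(None, x, None, None),
#                         net(None, x[:, :, perm], None, None), atol=1e-5)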
import torch

from weaver.nn.model.ParticleNet import ParticleNet

'''
Link to the full model implementation:
https://github.com/hqucms/weaver-core/blob/main/weaver/nn/model/ParticleNet.py
'''


class ParticleNetWrapper(torch.nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.mod = ParticleNet(**kwargs)

    def forward(self, points, features, lorentz_vectors, mask):
        return self.mod(points, features, mask)


def get_model(data_config, **kwargs):
    conv_params = [
        (16, (64, 64, 64)),
        (16, (128, 128, 128)),
        (16, (256, 256, 256)),
    ]
    fc_params = [(256, 0.1)]

    pf_features_dims = len(data_config.input_dicts['pf_features'])
    num_classes = len(data_config.label_value)
    model = ParticleNetWrapper(
        input_dims=pf_features_dims,
        num_classes=num_classes,
        conv_params=kwargs.get('conv_params', conv_params),
        fc_params=kwargs.get('fc_params', fc_params),
        use_fusion=kwargs.get('use_fusion', False),
        use_fts_bn=kwargs.get('use_fts_bn', True),
        use_counts=kwargs.get('use_counts', True),
        for_inference=kwargs.get('for_inference', False)
    )

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
import torch
import torch.nn as nn

from weaver.nn.model.ParticleNet import ParticleNet

'''
Link to the full model implementation:
https://github.com/hqucms/weaver-core/blob/main/weaver/nn/model/ParticleNet.py
'''


class ParticleNetWrapper(nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        in_dim = kwargs['fc_params'][-1][0]
        num_classes = kwargs['num_classes']
        self.for_inference = kwargs['for_inference']

        # fine-tune a freshly initialized last FC layer
        self.fc_out = nn.Linear(in_dim, num_classes)

        kwargs['for_inference'] = False
        self.mod = ParticleNet(**kwargs)
        # drop the original final classification layer; `fc_out` replaces it
        self.mod.fc = self.mod.fc[:-1]

    def forward(self, points, features, lorentz_vectors, mask):
        x_cls = self.mod(points, features, mask)
        output = self.fc_out(x_cls)
        if self.for_inference:
            output = torch.softmax(output, dim=1)
        return output
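
# `self.mod.fc[:-1]` relies on the fact that slicing an nn.Sequential returns a
# new nn.Sequential sharing the same (pretrained) submodules. A toy illustration
# (added for exposition, not part of this file):
#
#   head = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
#   trunk = head[:-1]         # nn.Sequential(nn.Linear(8, 4), nn.ReLU())
#   trunk[0] is head[0]       # True: the weights are shared, not copied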

def get_model(data_config, **kwargs):
    conv_params = [
        (16, (64, 64, 64)),
        (16, (128, 128, 128)),
        (16, (256, 256, 256)),
    ]
    fc_params = [(256, 0.1)]

    pf_features_dims = len(data_config.input_dicts['pf_features'])
    num_classes = len(data_config.label_value)
    model = ParticleNetWrapper(
        input_dims=pf_features_dims,
        num_classes=num_classes,
        conv_params=kwargs.get('conv_params', conv_params),
        fc_params=kwargs.get('fc_params', fc_params),
        use_fusion=kwargs.get('use_fusion', False),
        use_fts_bn=kwargs.get('use_fts_bn', True),
        use_counts=kwargs.get('use_counts', True),
        for_inference=kwargs.get('for_inference', False)
    )

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
import torch

from weaver.nn.model.ParticleTransformer import ParticleTransformer
from weaver.utils.logger import _logger

'''
Link to the full model implementation:
https://github.com/hqucms/weaver-core/blob/main/weaver/nn/model/ParticleTransformer.py
'''


class ParticleTransformerWrapper(torch.nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.mod = ParticleTransformer(**kwargs)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'mod.cls_token', }

    def forward(self, points, features, lorentz_vectors, mask):
        return self.mod(features, v=lorentz_vectors, mask=mask)


def get_model(data_config, **kwargs):
    cfg = dict(
        input_dim=len(data_config.input_dicts['pf_features']),
        num_classes=len(data_config.label_value),
        # network configurations
        pair_input_dim=4,
        use_pre_activation_pair=False,
        embed_dims=[128, 512, 128],
        pair_embed_dims=[64, 64, 64],
        num_heads=8,
        num_layers=8,
        num_cls_layers=2,
        block_params=None,
        cls_block_params={'dropout': 0, 'attn_dropout': 0, 'activation_dropout': 0},
        fc_params=[],
        activation='gelu',
        # misc
        trim=True,
        for_inference=False,
    )
    cfg.update(**kwargs)
    _logger.info('Model config: %s' % str(cfg))

    model = ParticleTransformerWrapper(**cfg)

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
import torch
import torch.nn as nn

from weaver.nn.model.ParticleTransformer import ParticleTransformer
from weaver.utils.logger import _logger

'''
Link to the full model implementation:
https://github.com/hqucms/weaver-core/blob/main/weaver/nn/model/ParticleTransformer.py
'''


class ParticleTransformerWrapper(nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        in_dim = kwargs['embed_dims'][-1]
        fc_params = kwargs.pop('fc_params')
        num_classes = kwargs.pop('num_classes')
        self.for_inference = kwargs['for_inference']

        # build a new classification head to be fine-tuned
        fcs = []
        for out_dim, drop_rate in fc_params:
            fcs.append(nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(drop_rate)))
            in_dim = out_dim
        fcs.append(nn.Linear(in_dim, num_classes))
        self.fc = nn.Sequential(*fcs)

        # create the backbone without its own head (and without softmax)
        kwargs['num_classes'] = None
        kwargs['fc_params'] = None
        self.mod = ParticleTransformer(**kwargs)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'mod.cls_token', }

    def forward(self, points, features, lorentz_vectors, mask):
        x_cls = self.mod(features, v=lorentz_vectors, mask=mask)
        output = self.fc(x_cls)
        if self.for_inference:
            output = torch.softmax(output, dim=1)
        return output
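
# For example (illustrative values, not from this commit): with
# embed_dims[-1] == 128, fc_params=[(256, 0.1)] and num_classes == 2, the head
# built above is
#
#   nn.Sequential(
#       nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.1)),
#       nn.Linear(256, 2),
#   )
#
# With the default fc_params=[] in get_model below, it reduces to a single
# nn.Linear(128, num_classes).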

def get_model(data_config, **kwargs):
    cfg = dict(
        input_dim=len(data_config.input_dicts['pf_features']),
        num_classes=len(data_config.label_value),
        # network configurations
        pair_input_dim=4,
        use_pre_activation_pair=False,
        embed_dims=[128, 512, 128],
        pair_embed_dims=[64, 64, 64],
        num_heads=8,
        num_layers=8,
        num_cls_layers=2,
        block_params=None,
        cls_block_params={'dropout': 0, 'attn_dropout': 0, 'activation_dropout': 0},
        fc_params=[],
        activation='gelu',
        # misc
        trim=True,
        for_inference=False,
    )
    cfg.update(**kwargs)
    _logger.info('Model config: %s' % str(cfg))

    model = ParticleTransformerWrapper(**cfg)

    model_info = {
        'input_names': list(data_config.input_names),
        'input_shapes': {k: ((1,) + s[1:]) for k, s in data_config.input_shapes.items()},
        'output_names': ['softmax'],
        'dynamic_axes': {**{k: {0: 'N', 2: 'n_' + k.split('_')[0]} for k in data_config.input_names}, **{'softmax': {0: 'N'}}},
    }

    return model, model_info


def get_loss(data_config, **kwargs):
    return torch.nn.CrossEntropyLoss()
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import awkward as ak\n",
"import uproot\n",
"import vector\n",
"vector.register_awkward()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"import zipfile\n",
"import tarfile\n",
"import urllib\n",
"import requests\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def _download(url, fname, chunk_size=1024):\n",
" '''https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51'''\n",
" resp = requests.get(url, stream=True)\n",
" total = int(resp.headers.get('content-length', 0))\n",
" with open(fname, 'wb') as file, tqdm(\n",
" desc=fname,\n",
" total=total,\n",
" unit='iB',\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" ) as bar:\n",
" for data in resp.iter_content(chunk_size=chunk_size):\n",
" size = file.write(data)\n",
" bar.update(size)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Download the example file\n",
"example_file = 'JetClass_example_100k.root'\n",
"if not os.path.exists(example_file):\n",
" _download('https://hqu.web.cern.ch/datasets/JetClass/example/JetClass_example_100k.root', example_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exploring the file"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load the content from the file\n",
"tree = uproot.open(example_file)['tree']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name | typename | interpretation \n",
"---------------------+--------------------------+-------------------------------\n",
"part_px | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_py | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_pz | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_energy | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_deta | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_dphi | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_d0val | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_d0err | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_dzval | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_dzerr | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_charge | std::vector<float> | AsJagged(AsDtype('>f4'), he...\n",
"part_isChargedHadron | std::vector<int32_t> | AsJagged(AsDtype('>i4'), he...\n",
"part_isNeutralHadron | std::vector<int32_t> | AsJagged(AsDtype('>i4'), he...\n",
"part_isPhoton | std::vector<int32_t> | AsJagged(AsDtype('>i4'), he...\n",
"part_isElectron | std::vector<int32_t> | AsJagged(AsDtype('>i4'), he...\n",
"part_isMuon | std::vector<int32_t> | AsJagged(AsDtype('>i4'), he...\n",
"label_QCD | float | AsDtype('>f4')\n",
"label_Hbb | bool | AsDtype('bool')\n",
"label_Hcc | bool | AsDtype('bool')\n",
"label_Hgg | bool | AsDtype('bool')\n",
"label_H4q | bool | AsDtype('bool')\n",
"label_Hqql | bool | AsDtype('bool')\n",
"label_Zqq | int32_t | AsDtype('>i4')\n",
"label_Wqq | int32_t | AsDtype('>i4')\n",
"label_Tbqq | int32_t | AsDtype('>i4')\n",
"label_Tbl | int32_t | AsDtype('>i4')\n",
"jet_pt | float | AsDtype('>f4')\n",
"jet_eta | float | AsDtype('>f4')\n",
"jet_phi | float | AsDtype('>f4')\n",
"jet_energy | float | AsDtype('>f4')\n",
"jet_nparticles | float | AsDtype('>f4')\n",
"jet_sdmass | float | AsDtype('>f4')\n",
"jet_tau1 | float | AsDtype('>f4')\n",
"jet_tau2 | float | AsDtype('>f4')\n",
"jet_tau3 | float | AsDtype('>f4')\n",
"jet_tau4 | float | AsDtype('>f4')\n",
"aux_genpart_eta | float | AsDtype('>f4')\n",
"aux_genpart_phi | float | AsDtype('>f4')\n",
"aux_genpart_pid | float | AsDtype('>f4')\n",
"aux_genpart_pt | float | AsDtype('>f4')\n",
"aux_truth_match | float | AsDtype('>f4')\n"
]
}
],
"source": [
"# Display the content of the \"tree\"\n",
"tree.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Load all arrays in the tree\n",
"# Each array is a column of the table\n",
"table = tree.arrays()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Arrays of a scalar type (bool/int/float) can be converted to a numpy array directly, e.g.\n",
"table['label_QCD'].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Array [[-125, -91.1, ... -0.735, -0.694]] type='100000 * var * float32'>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Arrays of a vector type are loaded as a JaggedArray that has varying elements per row\n",
"table['part_px']\n",
"\n",
"# A JaggedArray can be (zero-) padded to become a regular numpy array (see later)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Construct a Lorentz 4-vector from the (px, py, pz, energy) arrays\n",
"p4 = vector.zip({'px': table['part_px'], 'py': table['part_py'], 'pz': table['part_pz'], 'energy': table['part_energy']})"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Array [[140, 95.3, 87.8, ... 1.3, 0.919]] type='100000 * var * float32'>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the transverse momentum (pt)\n",
"p4.pt"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Array [[-0.254, -0.403, ... -0.857, -0.935]] type='100000 * var * float32'>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the pseudorapidity (eta)\n",
"p4.eta"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Array [[2.67, 2.84, 2.81, ... -2.17, -2.43]] type='100000 * var * float32'>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the azimuth angle (phi)\n",
"p4.phi"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def _pad(a, maxlen, value=0, dtype='float32'):\n",
" if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:\n",
" return a\n",
" elif isinstance(a, ak.Array):\n",
" if a.ndim == 1:\n",
" a = ak.unflatten(a, 1)\n",
" a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)\n",
" return ak.values_astype(a, dtype)\n",
" else:\n",
" x = (np.ones((len(a), maxlen)) * value).astype(dtype)\n",
" for idx, s in enumerate(a):\n",
" if not len(s):\n",
" continue\n",
" trunc = s[:maxlen].astype(dtype)\n",
" x[idx, :len(trunc)] = trunc\n",
" return x\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[140.19296 , 95.284584, 87.84807 , ..., 0. , 0. ,\n",
" 0. ],\n",
" [244.67009 , 62.332603, 45.159416, ..., 0. , 0. ,\n",
" 0. ],\n",
" [143.15791 , 91.48589 , 25.372644, ..., 0. , 0. ,\n",
" 0. ],\n",
" ...,\n",
" [157.69547 , 101.245445, 79.816284, ..., 0. , 0. ,\n",
" 0. ],\n",
" [ 88.65814 , 80.69194 , 79.14036 , ..., 0. , 0. ,\n",
" 0. ],\n",
" [171.13641 , 121.71926 , 59.68036 , ..., 0. , 0. ,\n",
" 0. ]], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Apply zero-padding and convert to a numpy array\n",
"_pad(p4.pt, maxlen=128).to_numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Constructing features and labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you see previously with `tree.show()`, there are four groups of arrays with different prefixes:\n",
" - `part_*`: JaggedArrays with features for each particle in a jet. These (and features constrcuted from them) are what we use for training in the Particle Transformer paper.\n",
" - `label_*`: 1D numpy arrays one-hot truth labels for each jet. These are the target of the training.\n",
" - *[Not used in the Particle Transformer paper]* `jet_*`: 1D numpy array with (high-level) features for each jet. These can also be used in the training, but since they are constructed from the particle-level features, it is not expected that they bring additional performance improvement.\n",
" - *[Not used in the Particle Transformer paper]* `aux_*`: auxiliary truth information about the simulated particles for additional studies / interpretations. **SHOULD NOT be used in the training of any classifier.**"
]
},
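{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, since the `label_*` columns are one-hot, the integer class index of each jet can be recovered with `argmax`. (The next cell is an added illustration, not part of the original notebook.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# stack the one-hot label columns and take the argmax to get one class index per jet\n",
"label_list = ['label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q', 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl']\n",
"labels = np.stack([table[n].to_numpy().astype('int') for n in label_list], axis=1)\n",
"labels.argmax(axis=1)"
]
},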
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below illustrates how the input features and labels are constructed in the Particle Transformer paper.\n",
"\n",
"(See also the yaml configuration: https://github.com/jet-universe/particle_transformer/blob/main/data/JetClass/JetClass_full.yaml)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def _clip(a, a_min, a_max):\n",
" try:\n",
" return np.clip(a, a_min, a_max)\n",
" except ValueError:\n",
" return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def build_features_and_labels(tree, transform_features=True):\n",
" \n",
" # load arrays from the tree\n",
" a = tree.arrays(filter_name=['part_*', 'jet_pt', 'jet_energy', 'label_*'])\n",
"\n",
" # compute new features\n",
" a['part_mask'] = ak.ones_like(a['part_energy'])\n",
" a['part_pt'] = np.hypot(a['part_px'], a['part_py'])\n",
" a['part_pt_log'] = np.log(a['part_pt'])\n",
" a['part_e_log'] = np.log(a['part_energy'])\n",
" a['part_logptrel'] = np.log(a['part_pt']/a['jet_pt'])\n",
" a['part_logerel'] = np.log(a['part_energy']/a['jet_energy'])\n",
" a['part_deltaR'] = np.hypot(a['part_deta'], a['part_dphi'])\n",
" a['part_d0'] = np.tanh(a['part_d0val'])\n",
" a['part_dz'] = np.tanh(a['part_dzval'])\n",
"\n",
" # apply standardization\n",
" if transform_features:\n",
" a['part_pt_log'] = (a['part_pt_log'] - 1.7) * 0.7\n",
" a['part_e_log'] = (a['part_e_log'] - 2.0) * 0.7\n",
" a['part_logptrel'] = (a['part_logptrel'] - (-4.7)) * 0.7\n",
" a['part_logerel'] = (a['part_logerel'] - (-4.7)) * 0.7\n",
" a['part_deltaR'] = (a['part_deltaR'] - 0.2) * 4.0\n",
" a['part_d0err'] = _clip(a['part_d0err'], 0, 1)\n",
" a['part_dzerr'] = _clip(a['part_dzerr'], 0, 1)\n",
"\n",
" feature_list = {\n",
" 'pf_points': ['part_deta', 'part_dphi'], # not used in ParT\n",
" 'pf_features': [\n",
" 'part_pt_log', \n",
" 'part_e_log',\n",
" 'part_logptrel',\n",
" 'part_logerel',\n",
" 'part_deltaR',\n",
" 'part_charge',\n",
" 'part_isChargedHadron',\n",
" 'part_isNeutralHadron',\n",
" 'part_isPhoton',\n",
" 'part_isElectron',\n",
" 'part_isMuon',\n",
" 'part_d0',\n",
" 'part_d0err',\n",
" 'part_dz',\n",
" 'part_dzerr',\n",
" 'part_deta',\n",
" 'part_dphi',\n",
" ],\n",
" 'pf_vectors': [\n",
" 'part_px',\n",
" 'part_py',\n",
" 'part_pz',\n",
" 'part_energy',\n",
" ],\n",
" 'pf_mask': ['part_mask']\n",
" }\n",
"\n",
" out = {}\n",
" for k, names in feature_list.items():\n",
" out[k] = np.stack([_pad(a[n], maxlen=128).to_numpy() for n in names], axis=1)\n",
"\n",
" label_list = ['label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q', 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl']\n",
" out['label'] = np.stack([a[n].to_numpy().astype('int') for n in label_list], axis=1)\n",
" \n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'pf_points': array([[[-0.07242048, 0.07607916, 0.08601749, ..., 0. ,\n",
" 0. , 0. ],\n",
" [-0.08581114, 0.09253383, 0.06340456, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[ 0.01535046, -0.00294232, 0.03290105, ..., 0. ,\n",
" 0. , 0. ],\n",
" [-0.04896092, -0.04723394, -0.06385756, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[-0.13630104, -0.14928365, -0.17035806, ..., 0. ,\n",
" 0. , 0. ],\n",
" [-0.01668766, -0.02666983, -0.01680285, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" ...,\n",
" \n",
" [[ 0.07362503, 0.09081 , -0.15697095, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 0.01177871, 0.02063447, -0.01410705, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[ 0.03936064, 0.04921722, 0.03772318, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 0.04601908, 0.04276264, 0.04541695, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[-0.08016402, 0.19889337, 0.13402718, ..., 0. ,\n",
" 0. , 0. ],\n",
" [-0.00388956, 0.10787535, 0.15420079, ..., 0. ,\n",
" 0. , 0. ]]], dtype=float32),\n",
" 'pf_features': array([[[ 2.27011395e+00, 1.99980760e+00, 1.94292617e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.08252525e+00, 1.84514868e+00, 1.79095721e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.06154871e+00, 1.79124260e+00, 1.73436105e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 3.18000019e-02, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-7.24204779e-02, 7.60791600e-02, 8.60174894e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-8.58111382e-02, 9.25338268e-02, 6.34045601e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n",
" \n",
" [[ 2.65993762e+00, 1.70273900e+00, 1.47713912e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.98029804e+00, 2.01181531e+00, 1.80837512e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.75114131e+00, 1.79394329e+00, 1.56834316e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 3.08999997e-02, 3.42999995e-02, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 1.53504610e-02, -2.94232368e-03, 3.29010487e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-4.89609241e-02, -4.72339392e-02, -6.38575554e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n",
" \n",
" [[ 2.28476381e+00, 1.97132933e+00, 1.07357013e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.64549851e+00, 2.32392597e+00, 1.41300690e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.40925336e+00, 2.09581900e+00, 1.19805980e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-1.36301041e-01, -1.49283648e-01, -1.70358062e-01, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-1.66876614e-02, -2.66698301e-02, -1.68028474e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n",
" \n",
" ...,\n",
" \n",
" [[ 2.35246587e+00, 2.04228354e+00, 1.87580907e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.28663683e+00, 1.98351204e+00, 1.72960627e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.36275625e+00, 2.05257368e+00, 1.88609958e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 7.36250281e-02, 9.08100009e-02, -1.56970948e-01, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 1.17787123e-02, 2.06344724e-02, -1.41070485e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n",
" \n",
" [[ 1.94935155e+00, 1.88344717e+00, 1.86985600e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 1.86683977e+00, 1.80477130e+00, 1.78671205e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.06553054e+00, 1.99962616e+00, 1.98603511e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 2.83000004e-02, 3.15999985e-02, 3.15999985e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 3.93606424e-02, 4.92172241e-02, 3.77231836e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 4.60190773e-02, 4.27626371e-02, 4.54169512e-02, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n",
" \n",
" [[ 2.40972257e+00, 2.17120194e+00, 1.67230213e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.42246366e+00, 2.33064985e+00, 1.79561615e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [ 2.24142742e+00, 2.00290704e+00, 1.50400686e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" ...,\n",
" [ 0.00000000e+00, 3.42999995e-02, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-8.01640153e-02, 1.98893368e-01, 1.34027183e-01, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [-3.88956070e-03, 1.07875347e-01, 1.54200792e-01, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n",
" dtype=float32),\n",
" 'pf_vectors': array([[[-124.57671 , -91.08083 , -83.18519 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 64.3017 , 27.989893, 28.240173, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ -36.05099 , -39.437183, -37.305996, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 144.75407 , 103.123436, 95.441185, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[ 110.01701 , 27.931944, 20.90474 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 218.53995 , 55.72396 , 40.02955 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [-461.0496 , -115.04492 , -86.80113 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 521.9484 , 130.84612 , 97.84585 , ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[ 128.54211 , 81.7395 , 22.78092 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ -63.016777, -41.089207, -11.171425, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 290.1306 , 182.74101 , 49.498 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 323.52737 , 204.36229 , 55.622147, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" ...,\n",
" \n",
" [[ 96.53474 , 61.266937, 50.477467, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 124.695244, 80.60399 , 61.8277 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [-112.58496 , -74.43169 , -35.690136, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 193.76076 , 125.66112 , 87.4324 , ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[ -32.93902 , -30.22315 , -29.447128, ..., 0. ,\n",
" 0. , 0. ],\n",
" [ -82.31213 , -74.818115, -73.4579 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ -58.771416, -54.44747 , -52.30668 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 106.369 , 97.34339 , 94.864136, ..., 0. ,\n",
" 0. , 0. ]],\n",
" \n",
" [[-121.577835, -76.37734 , -35.256603, ..., 0. ,\n",
" 0. , 0. ],\n",
" [-120.44294 , -94.773834, -48.15306 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [-161.4199 , -166.60855 , -75.29501 , ..., 0. ,\n",
" 0. , 0. ],\n",
" [ 235.25317 , 206.3347 , 96.07854 , ..., 0. ,\n",
" 0. , 0. ]]], dtype=float32),\n",
" 'pf_mask': array([[[1., 1., 1., ..., 0., 0., 0.]],\n",
" \n",
" [[1., 1., 1., ..., 0., 0., 0.]],\n",
" \n",
" [[1., 1., 1., ..., 0., 0., 0.]],\n",
" \n",
" ...,\n",
" \n",
" [[1., 1., 1., ..., 0., 0., 0.]],\n",
" \n",
" [[1., 1., 1., ..., 0., 0., 0.]],\n",
" \n",
" [[1., 1., 1., ..., 0., 0., 0.]]], dtype=float32),\n",
" 'label': array([[0, 1, 0, ..., 0, 0, 0],\n",
" [0, 1, 0, ..., 0, 0, 0],\n",
" [0, 1, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]])}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"build_features_and_labels(tree)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
#!/bin/bash
set -x
source env.sh
echo "args: $@"
# set the dataset dir via `DATADIR_QuarkGluon`
DATADIR=${DATADIR_QuarkGluon}
[[ -z $DATADIR ]] && DATADIR='./datasets/QuarkGluon'
# set a comment via `COMMENT`
suffix=${COMMENT}
model=$1
extraopts=""
modelopts="networks/example_ParticleTransformer.py" # ParT
# /public/home/mashun/gaonengsuo/particle_transformer/training/QuarkGluon/ParT/20241121-093112_example_ParticleTransformer_ranger_lr0.001_batch512/net_best_epoch_state.pt
weaver --predict \
--data-test "${DATADIR}/test_file_*.parquet" \
--data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml \
--network-config $modelopts \
--model-prefix training/QuarkGluon/ParT/20241121-093112_example_ParticleTransformer_ranger_lr0.001_batch512/net_best_epoch_state.pt \
--batch-size 512 \
--gpus 0 \
--predict-output pred.root
#!/bin/bash
set -x
source env.sh
echo "args: $@"
# set the dataset dir via `DATADIR_JetClass`
DATADIR=${DATADIR_JetClass}
[[ -z $DATADIR ]] && DATADIR='./datasets/JetClass'
# set a comment via `COMMENT`
suffix=${COMMENT}
# set the number of gpus for DDP training via `DDP_NGPUS`
NGPUS=${DDP_NGPUS}
[[ -z $NGPUS ]] && NGPUS=1
if ((NGPUS > 1)); then
CMD="torchrun --standalone --nnodes=1 --nproc_per_node=$NGPUS -- $(which weaver) --backend nccl"
else
CMD="weaver"
fi
epochs=1
samples_per_epoch=$((10000 * 1024 / $NGPUS))
samples_per_epoch_val=$((10000 * 128))
dataopts="--num-workers 2 --fetch-step 0.01"
# PN, PFN, PCNN, ParT
model=$1
if [[ "$model" == "ParT" ]]; then
modelopts="networks/example_ParticleTransformer.py --use-amp"
batchopts="--batch-size 512 --start-lr 1e-3"
elif [[ "$model" == "PN" ]]; then
modelopts="networks/example_ParticleNet.py"
batchopts="--batch-size 512 --start-lr 1e-2"
elif [[ "$model" == "PFN" ]]; then
modelopts="networks/example_PFN.py"
batchopts="--batch-size 4096 --start-lr 2e-2"
elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
batchopts="--batch-size 4096 --start-lr 2e-2"
else
echo "Invalid model $model!"
exit 1
fi
# "kin", "kinpid", "full"
FEATURE_TYPE=$2
[[ -z ${FEATURE_TYPE} ]] && FEATURE_TYPE="full"
if ! [[ "${FEATURE_TYPE}" =~ ^(full|kin|kinpid)$ ]]; then
echo "Invalid feature type ${FEATURE_TYPE}!"
exit 1
fi
# currently only Pythia
SAMPLE_TYPE=Pythia
$CMD \
--data-train \
"HToBB:${DATADIR}/${SAMPLE_TYPE}/train_100M/HToBB_*.root" \
"HToCC:${DATADIR}/${SAMPLE_TYPE}/train_100M/HToCC_*.root" \
"HToGG:${DATADIR}/${SAMPLE_TYPE}/train_100M/HToGG_*.root" \
"HToWW2Q1L:${DATADIR}/${SAMPLE_TYPE}/train_100M/HToWW2Q1L_*.root" \
"HToWW4Q:${DATADIR}/${SAMPLE_TYPE}/train_100M/HToWW4Q_*.root" \
"TTBar:${DATADIR}/${SAMPLE_TYPE}/train_100M/TTBar_*.root" \
"TTBarLep:${DATADIR}/${SAMPLE_TYPE}/train_100M/TTBarLep_*.root" \
"WToQQ:${DATADIR}/${SAMPLE_TYPE}/train_100M/WToQQ_*.root" \
"ZToQQ:${DATADIR}/${SAMPLE_TYPE}/train_100M/ZToQQ_*.root" \
"ZJetsToNuNu:${DATADIR}/${SAMPLE_TYPE}/train_100M/ZJetsToNuNu_*.root" \
--data-val "${DATADIR}/${SAMPLE_TYPE}/val_5M/*.root" \
--data-test \
"HToBB:${DATADIR}/${SAMPLE_TYPE}/test_20M/HToBB_*.root" \
"HToCC:${DATADIR}/${SAMPLE_TYPE}/test_20M/HToCC_*.root" \
"HToGG:${DATADIR}/${SAMPLE_TYPE}/test_20M/HToGG_*.root" \
"HToWW2Q1L:${DATADIR}/${SAMPLE_TYPE}/test_20M/HToWW2Q1L_*.root" \
"HToWW4Q:${DATADIR}/${SAMPLE_TYPE}/test_20M/HToWW4Q_*.root" \
"TTBar:${DATADIR}/${SAMPLE_TYPE}/test_20M/TTBar_*.root" \
"TTBarLep:${DATADIR}/${SAMPLE_TYPE}/test_20M/TTBarLep_*.root" \
"WToQQ:${DATADIR}/${SAMPLE_TYPE}/test_20M/WToQQ_*.root" \
"ZToQQ:${DATADIR}/${SAMPLE_TYPE}/test_20M/ZToQQ_*.root" \
"ZJetsToNuNu:${DATADIR}/${SAMPLE_TYPE}/test_20M/ZJetsToNuNu_*.root" \
--data-config data/JetClass/JetClass_${FEATURE_TYPE}.yaml --network-config $modelopts \
--model-prefix training/JetClass/${SAMPLE_TYPE}/${FEATURE_TYPE}/${model}/{auto}${suffix}/net \
$dataopts $batchopts \
--samples-per-epoch ${samples_per_epoch} --samples-per-epoch-val ${samples_per_epoch_val} --num-epochs $epochs --gpus 0 \
--optimizer ranger \
--log logs/JetClass_${SAMPLE_TYPE}_${FEATURE_TYPE}_${model}_{auto}${suffix}.log \
--predict-output pred.root \
--tensorboard JetClass_${SAMPLE_TYPE}_${FEATURE_TYPE}_${model}${suffix} \
"${@:3}"
#!/bin/bash
set -x
source env.sh
echo "args: $@"
# set the dataset dir via `DATADIR_QuarkGluon`
DATADIR=${DATADIR_QuarkGluon}
[[ -z $DATADIR ]] && DATADIR='./datasets/QuarkGluon'
# set a comment via `COMMENT`
suffix=${COMMENT}
# PN, PFN, PCNN, ParT
model=$1
extraopts=""
if [[ "$model" == "ParT" ]]; then
modelopts="networks/example_ParticleTransformer.py --use-amp --optimizer-option weight_decay 0.01"
lr="1e-3"
elif [[ "$model" == "ParT-FineTune" ]]; then
modelopts="networks/example_ParticleTransformer_finetune.py --use-amp --optimizer-option weight_decay 0.01"
lr="1e-4"
extraopts="--optimizer-option lr_mult (\"fc.*\",50) --lr-scheduler none"
elif [[ "$model" == "PN" ]]; then
modelopts="networks/example_ParticleNet.py"
lr="1e-2"
elif [[ "$model" == "PN-FineTune" ]]; then
modelopts="networks/example_ParticleNet_finetune.py"
lr="1e-3"
extraopts="--optimizer-option lr_mult (\"fc_out.*\",50) --lr-scheduler none"
elif [[ "$model" == "PFN" ]]; then
modelopts="networks/example_PFN.py"
lr="2e-2"
extraopts="--batch-size 4096"
elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
lr="2e-2"
extraopts="--batch-size 4096"
else
echo "Invalid model $model!"
exit 1
fi
# "kin", "kinpid", "kinpidplus"
FEATURE_TYPE=$2
[[ -z ${FEATURE_TYPE} ]] && FEATURE_TYPE="kinpid"
if [[ "${FEATURE_TYPE}" == "kin" ]]; then
pretrain_type="kin"
elif [[ "${FEATURE_TYPE}" =~ ^(kinpid|kinpidplus)$ ]]; then
pretrain_type="kinpid"
else
echo "Invalid feature type ${FEATURE_TYPE}!"
exit 1
fi
if [[ "$model" == "ParT-FineTune" ]]; then
modelopts+=" --load-model-weights models/ParT_${pretrain_type}.pt"
fi
if [[ "$model" == "PN-FineTune" ]]; then
modelopts+=" --load-model-weights models/ParticleNet_${pretrain_type}.pt"
fi
weaver \
--data-train "${DATADIR}/train_file_*.parquet" \
--data-test "${DATADIR}/test_file_*.parquet" \
--data-config data/QuarkGluon/qg_${FEATURE_TYPE}.yaml --network-config $modelopts \
--model-prefix training/QuarkGluon/${model}/{auto}${suffix}/net \
--num-workers 1 --fetch-step 1 --in-memory --train-val-split 0.8889 \
--batch-size 512 --samples-per-epoch 1600000 --samples-per-epoch-val 200000 --num-epochs 1 --gpus 0 \
--start-lr $lr --optimizer ranger --log logs/QuarkGluon_${model}_{auto}${suffix}.log --predict-output pred.root \
--tensorboard QuarkGluon_${FEATURE_TYPE}_${model}${suffix} \
${extraopts} "${@:3}"
#!/bin/bash
set -x
source env.sh
echo "args: $@"
# set the dataset dir via `DATADIR_TopLandscape`
DATADIR=${DATADIR_TopLandscape}
[[ -z $DATADIR ]] && DATADIR='./datasets/TopLandscape'
# set a comment via `COMMENT`
suffix=${COMMENT}
# PN, PFN, PCNN, ParT
model=$1
extraopts=""
if [[ "$model" == "ParT" ]]; then
modelopts="networks/example_ParticleTransformer.py --use-amp --optimizer-option weight_decay 0.01"
lr="1e-3"
elif [[ "$model" == "ParT-FineTune" ]]; then
modelopts="networks/example_ParticleTransformer_finetune.py --use-amp --optimizer-option weight_decay 0.01"
lr="1e-4"
extraopts="--optimizer-option lr_mult (\"fc.*\",50) --lr-scheduler none --load-model-weights models/ParT_kin.pt"
elif [[ "$model" == "PN" ]]; then
modelopts="networks/example_ParticleNet.py"
lr="1e-2"
elif [[ "$model" == "PN-FineTune" ]]; then
modelopts="networks/example_ParticleNet_finetune.py"
lr="1e-3"
extraopts="--optimizer-option lr_mult (\"fc_out.*\",50) --lr-scheduler none --load-model-weights models/ParticleNet_kin.pt"
elif [[ "$model" == "PFN" ]]; then
modelopts="networks/example_PFN.py"
lr="2e-2"
extraopts="--batch-size 4096"
elif [[ "$model" == "PCNN" ]]; then
modelopts="networks/example_PCNN.py"
lr="2e-2"
extraopts="--batch-size 4096"
else
echo "Invalid model $model!"
exit 1
fi
# "kin"
FEATURE_TYPE=$2
[[ -z ${FEATURE_TYPE} ]] && FEATURE_TYPE="kin"
if [[ "${FEATURE_TYPE}" != "kin" ]]; then
echo "Invalid feature type ${FEATURE_TYPE}!"
exit 1
fi
weaver \
--data-train "${DATADIR}/train_file.parquet" \
--data-val "${DATADIR}/val_file.parquet" \
--data-test "${DATADIR}/test_file.parquet" \
--data-config data/TopLandscape/top_${FEATURE_TYPE}.yaml --network-config $modelopts \
--model-prefix training/TopLandscape/${model}/{auto}${suffix}/net \
--num-workers 1 --fetch-step 1 --in-memory \
--batch-size 512 --samples-per-epoch $((2400 * 512)) --samples-per-epoch-val $((800 * 512)) --num-epochs 20 --gpus 0 \
--start-lr $lr --optimizer ranger --log logs/TopLandscape_${model}_{auto}${suffix}.log --predict-output pred.root \
--tensorboard TopLandscape_${FEATURE_TYPE}_${model}${suffix} \
${extraopts} "${@:3}"