{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import awkward" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def stack_arrays(a, keys, axis=-1):\n", " flat_arr = np.stack([a[k].flatten() for k in keys], axis=axis)\n", " return awkward.JaggedArray.fromcounts(a[keys[0]].counts, flat_arr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def pad_array(a, maxlen, value=0., dtype='float32'):\n", " x = (np.ones((len(a), maxlen)) * value).astype(dtype)\n", " for idx, s in enumerate(a):\n", " if not len(s):\n", " continue\n", " trunc = s[:maxlen].astype(dtype)\n", " x[idx, :len(trunc)] = trunc\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Dataset(object):\n", "\n", " def __init__(self, filepath, feature_dict = {}, label='label', pad_len=100, data_format='channel_first'):\n", " self.filepath = filepath\n", " self.feature_dict = feature_dict\n", " if len(feature_dict)==0:\n", " feature_dict['points'] = ['part_etarel', 'part_phirel']\n", " feature_dict['features'] = ['part_pt_log', 'part_e_log', 'part_etarel', 'part_phirel']\n", " feature_dict['mask'] = ['part_pt_log']\n", " self.label = label\n", " self.pad_len = pad_len\n", " assert data_format in ('channel_first', 'channel_last')\n", " self.stack_axis = 1 if data_format=='channel_first' else -1\n", " self._values = {}\n", " self._label = None\n", " self._load()\n", "\n", " def _load(self):\n", " logging.info('Start loading file %s' % self.filepath)\n", " counts = None\n", " with awkward.load(self.filepath) as a:\n", " self._label = a[self.label]\n", " for k in self.feature_dict:\n", " cols = self.feature_dict[k]\n", " if not isinstance(cols, (list, tuple)):\n", " cols = [cols]\n", " arrs = []\n", " for col in cols:\n", " if counts is None:\n", " counts = a[col].counts\n", " else:\n", " assert np.array_equal(counts, a[col].counts)\n", " arrs.append(pad_array(a[col], self.pad_len))\n", " self._values[k] = np.stack(arrs, axis=self.stack_axis)\n", " logging.info('Finished loading file %s' % self.filepath)\n", "\n", "\n", " def __len__(self):\n", " return len(self._label)\n", "\n", " def __getitem__(self, key):\n", " if key==self.label:\n", " return self._label\n", " else:\n", " return self._values[key]\n", " \n", " @property\n", " def X(self):\n", " return self._values\n", " \n", " @property\n", " def y(self):\n", " return self._label\n", "\n", " def shuffle(self, seed=None):\n", " if seed is not None:\n", " np.random.seed(seed)\n", " shuffle_indices = np.arange(self.__len__())\n", " np.random.shuffle(shuffle_indices)\n", " for k in self._values:\n", " self._values[k] = self._values[k][shuffle_indices]\n", " self._label = self._label[shuffle_indices]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset = Dataset('converted/train_file_0.awkd', data_format='channel_last')\n", "val_dataset = Dataset('converted/val_file_0.awkd', data_format='channel_last')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "from tf_keras_model import get_particle_net, get_particle_net_lite" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model_type = 'particle_net_lite' # choose between 'particle_net' and 'particle_net_lite'\n", "num_classes = train_dataset.y.shape[1]\n", "input_shapes = {k:train_dataset[k].shape[1:] for k in train_dataset.X}\n", "if 'lite' in model_type:\n", " model = get_particle_net_lite(num_classes, input_shapes)\n", "else:\n", " model = get_particle_net(num_classes, input_shapes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Training parameters\n", "batch_size = 1024 if 'lite' in model_type else 384\n", "epochs = 30" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def lr_schedule(epoch):\n", " lr = 1e-3\n", " if epoch > 10:\n", " lr *= 0.1\n", " elif epoch > 20:\n", " lr *= 0.01\n", " logging.info('Learning rate: %f'%lr)\n", " return lr" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model.compile(loss='categorical_crossentropy',\n", " optimizer=keras.optimizers.Adam(learning_rate=lr_schedule(0)),\n", " metrics=['accuracy'])\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare model model saving directory.\n", "import os\n", "save_dir = 'model_checkpoints'\n", "model_name = '%s_model.{epoch:03d}.h5' % model_type\n", "if not os.path.isdir(save_dir):\n", " os.makedirs(save_dir)\n", "filepath = os.path.join(save_dir, model_name)\n", "\n", "# Prepare callbacks for model saving and for learning rate adjustment.\n", "checkpoint = keras.callbacks.ModelCheckpoint(filepath=filepath,\n", " monitor='val_acc',\n", " verbose=1,\n", " save_best_only=True)\n", "\n", "lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule)\n", "progress_bar = keras.callbacks.ProgbarLogger()\n", "callbacks = [checkpoint, lr_scheduler, progress_bar]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "train_dataset.shuffle()\n", "model.fit(train_dataset.X, train_dataset.y,\n", " batch_size=batch_size,\n", "# epochs=epochs,\n", " epochs=1, # --- train only for 1 epoch here for demonstration ---\n", " validation_data=(val_dataset.X, val_dataset.y),\n", " shuffle=True,\n", " callbacks=callbacks)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }