{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import awkward as ak\n", "import uproot\n", "import vector\n", "vector.register_awkward()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "import zipfile\n", "import tarfile\n", "import urllib\n", "import requests\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def _download(url, fname, chunk_size=1024):\n", " '''https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51'''\n", " resp = requests.get(url, stream=True)\n", " total = int(resp.headers.get('content-length', 0))\n", " with open(fname, 'wb') as file, tqdm(\n", " desc=fname,\n", " total=total,\n", " unit='iB',\n", " unit_scale=True,\n", " unit_divisor=1024,\n", " ) as bar:\n", " for data in resp.iter_content(chunk_size=chunk_size):\n", " size = file.write(data)\n", " bar.update(size)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Download the example file\n", "example_file = 'JetClass_example_100k.root'\n", "if not os.path.exists(example_file):\n", " _download('https://hqu.web.cern.ch/datasets/JetClass/example/JetClass_example_100k.root', example_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exploring the file" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load the content from the file\n", "tree = uproot.open(example_file)['tree']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name | typename | interpretation \n", "---------------------+--------------------------+-------------------------------\n", "part_px | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_py | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_pz | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_energy | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_deta | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_dphi | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_d0val | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_d0err | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_dzval | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_dzerr | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_charge | std::vector | AsJagged(AsDtype('>f4'), he...\n", "part_isChargedHadron | std::vector | AsJagged(AsDtype('>i4'), he...\n", "part_isNeutralHadron | std::vector | AsJagged(AsDtype('>i4'), he...\n", "part_isPhoton | std::vector | AsJagged(AsDtype('>i4'), he...\n", "part_isElectron | std::vector | AsJagged(AsDtype('>i4'), he...\n", "part_isMuon | std::vector | AsJagged(AsDtype('>i4'), he...\n", "label_QCD | float | AsDtype('>f4')\n", "label_Hbb | bool | AsDtype('bool')\n", "label_Hcc | bool | AsDtype('bool')\n", "label_Hgg | bool | AsDtype('bool')\n", "label_H4q | bool | AsDtype('bool')\n", "label_Hqql | bool | AsDtype('bool')\n", "label_Zqq | int32_t | AsDtype('>i4')\n", "label_Wqq | int32_t | AsDtype('>i4')\n", "label_Tbqq | int32_t | AsDtype('>i4')\n", "label_Tbl | int32_t | AsDtype('>i4')\n", "jet_pt | float | AsDtype('>f4')\n", "jet_eta | float | AsDtype('>f4')\n", "jet_phi | float | AsDtype('>f4')\n", "jet_energy | float | AsDtype('>f4')\n", "jet_nparticles | float | AsDtype('>f4')\n", "jet_sdmass | float | AsDtype('>f4')\n", "jet_tau1 | float | AsDtype('>f4')\n", "jet_tau2 | float | AsDtype('>f4')\n", "jet_tau3 | float | AsDtype('>f4')\n", "jet_tau4 | float | AsDtype('>f4')\n", "aux_genpart_eta | float | AsDtype('>f4')\n", "aux_genpart_phi | float | AsDtype('>f4')\n", "aux_genpart_pid | float | AsDtype('>f4')\n", "aux_genpart_pt | float | AsDtype('>f4')\n", "aux_truth_match | float | AsDtype('>f4')\n" ] } ], "source": [ "# Display the content of the \"tree\"\n", "tree.show()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Load all arrays in the tree\n", "# Each array is a column of the table\n", "table = tree.arrays()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Arrays of a scalar type (bool/int/float) can be converted to a numpy array directly, e.g.\n", "table['label_QCD'].to_numpy()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Arrays of a vector type are loaded as a JaggedArray that has varying elements per row\n", "table['part_px']\n", "\n", "# A JaggedArray can be (zero-) padded to become a regular numpy array (see later)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Construct a Lorentz 4-vector from the (px, py, pz, energy) arrays\n", "p4 = vector.zip({'px': table['part_px'], 'py': table['part_py'], 'pz': table['part_pz'], 'energy': table['part_energy']})" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the transverse momentum (pt)\n", "p4.pt" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the pseudorapidity (eta)\n", "p4.eta" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the azimuth angle (phi)\n", "p4.phi" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def _pad(a, maxlen, value=0, dtype='float32'):\n", " if isinstance(a, np.ndarray) and a.ndim >= 2 and a.shape[1] == maxlen:\n", " return a\n", " elif isinstance(a, ak.Array):\n", " if a.ndim == 1:\n", " a = ak.unflatten(a, 1)\n", " a = ak.fill_none(ak.pad_none(a, maxlen, clip=True), value)\n", " return ak.values_astype(a, dtype)\n", " else:\n", " x = (np.ones((len(a), maxlen)) * value).astype(dtype)\n", " for idx, s in enumerate(a):\n", " if not len(s):\n", " continue\n", " trunc = s[:maxlen].astype(dtype)\n", " x[idx, :len(trunc)] = trunc\n", " return x\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[140.19296 , 95.284584, 87.84807 , ..., 0. , 0. ,\n", " 0. ],\n", " [244.67009 , 62.332603, 45.159416, ..., 0. , 0. ,\n", " 0. ],\n", " [143.15791 , 91.48589 , 25.372644, ..., 0. , 0. ,\n", " 0. ],\n", " ...,\n", " [157.69547 , 101.245445, 79.816284, ..., 0. , 0. ,\n", " 0. ],\n", " [ 88.65814 , 80.69194 , 79.14036 , ..., 0. , 0. ,\n", " 0. ],\n", " [171.13641 , 121.71926 , 59.68036 , ..., 0. , 0. ,\n", " 0. ]], dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Apply zero-padding and convert to a numpy array\n", "_pad(p4.pt, maxlen=128).to_numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Constructing features and labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you see previously with `tree.show()`, there are four groups of arrays with different prefixes:\n", " - `part_*`: JaggedArrays with features for each particle in a jet. These (and features constrcuted from them) are what we use for training in the Particle Transformer paper.\n", " - `label_*`: 1D numpy arrays one-hot truth labels for each jet. These are the target of the training.\n", " - *[Not used in the Particle Transformer paper]* `jet_*`: 1D numpy array with (high-level) features for each jet. These can also be used in the training, but since they are constructed from the particle-level features, it is not expected that they bring additional performance improvement.\n", " - *[Not used in the Particle Transformer paper]* `aux_*`: auxiliary truth information about the simulated particles for additional studies / interpretations. **SHOULD NOT be used in the training of any classifier.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The code below illustrates how the input features and labels are constructed in the Particle Transformer paper.\n", "\n", "(See also the yaml configuration: https://github.com/jet-universe/particle_transformer/blob/main/data/JetClass/JetClass_full.yaml)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def _clip(a, a_min, a_max):\n", " try:\n", " return np.clip(a, a_min, a_max)\n", " except ValueError:\n", " return ak.unflatten(np.clip(ak.flatten(a), a_min, a_max), ak.num(a))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def build_features_and_labels(tree, transform_features=True):\n", " \n", " # load arrays from the tree\n", " a = tree.arrays(filter_name=['part_*', 'jet_pt', 'jet_energy', 'label_*'])\n", "\n", " # compute new features\n", " a['part_mask'] = ak.ones_like(a['part_energy'])\n", " a['part_pt'] = np.hypot(a['part_px'], a['part_py'])\n", " a['part_pt_log'] = np.log(a['part_pt'])\n", " a['part_e_log'] = np.log(a['part_energy'])\n", " a['part_logptrel'] = np.log(a['part_pt']/a['jet_pt'])\n", " a['part_logerel'] = np.log(a['part_energy']/a['jet_energy'])\n", " a['part_deltaR'] = np.hypot(a['part_deta'], a['part_dphi'])\n", " a['part_d0'] = np.tanh(a['part_d0val'])\n", " a['part_dz'] = np.tanh(a['part_dzval'])\n", "\n", " # apply standardization\n", " if transform_features:\n", " a['part_pt_log'] = (a['part_pt_log'] - 1.7) * 0.7\n", " a['part_e_log'] = (a['part_e_log'] - 2.0) * 0.7\n", " a['part_logptrel'] = (a['part_logptrel'] - (-4.7)) * 0.7\n", " a['part_logerel'] = (a['part_logerel'] - (-4.7)) * 0.7\n", " a['part_deltaR'] = (a['part_deltaR'] - 0.2) * 4.0\n", " a['part_d0err'] = _clip(a['part_d0err'], 0, 1)\n", " a['part_dzerr'] = _clip(a['part_dzerr'], 0, 1)\n", "\n", " feature_list = {\n", " 'pf_points': ['part_deta', 'part_dphi'], # not used in ParT\n", " 'pf_features': [\n", " 'part_pt_log', \n", " 'part_e_log',\n", " 'part_logptrel',\n", " 'part_logerel',\n", " 'part_deltaR',\n", " 'part_charge',\n", " 'part_isChargedHadron',\n", " 'part_isNeutralHadron',\n", " 'part_isPhoton',\n", " 'part_isElectron',\n", " 'part_isMuon',\n", " 'part_d0',\n", " 'part_d0err',\n", " 'part_dz',\n", " 'part_dzerr',\n", " 'part_deta',\n", " 'part_dphi',\n", " ],\n", " 'pf_vectors': [\n", " 'part_px',\n", " 'part_py',\n", " 'part_pz',\n", " 'part_energy',\n", " ],\n", " 'pf_mask': ['part_mask']\n", " }\n", "\n", " out = {}\n", " for k, names in feature_list.items():\n", " out[k] = np.stack([_pad(a[n], maxlen=128).to_numpy() for n in names], axis=1)\n", "\n", " label_list = ['label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q', 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl']\n", " out['label'] = np.stack([a[n].to_numpy().astype('int') for n in label_list], axis=1)\n", " \n", " return out" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "{'pf_points': array([[[-0.07242048, 0.07607916, 0.08601749, ..., 0. ,\n", " 0. , 0. ],\n", " [-0.08581114, 0.09253383, 0.06340456, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[ 0.01535046, -0.00294232, 0.03290105, ..., 0. ,\n", " 0. , 0. ],\n", " [-0.04896092, -0.04723394, -0.06385756, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[-0.13630104, -0.14928365, -0.17035806, ..., 0. ,\n", " 0. , 0. ],\n", " [-0.01668766, -0.02666983, -0.01680285, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " ...,\n", " \n", " [[ 0.07362503, 0.09081 , -0.15697095, ..., 0. ,\n", " 0. , 0. ],\n", " [ 0.01177871, 0.02063447, -0.01410705, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[ 0.03936064, 0.04921722, 0.03772318, ..., 0. ,\n", " 0. , 0. ],\n", " [ 0.04601908, 0.04276264, 0.04541695, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[-0.08016402, 0.19889337, 0.13402718, ..., 0. ,\n", " 0. , 0. ],\n", " [-0.00388956, 0.10787535, 0.15420079, ..., 0. ,\n", " 0. , 0. ]]], dtype=float32),\n", " 'pf_features': array([[[ 2.27011395e+00, 1.99980760e+00, 1.94292617e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.08252525e+00, 1.84514868e+00, 1.79095721e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.06154871e+00, 1.79124260e+00, 1.73436105e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 3.18000019e-02, 0.00000000e+00, 0.00000000e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-7.24204779e-02, 7.60791600e-02, 8.60174894e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-8.58111382e-02, 9.25338268e-02, 6.34045601e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n", " \n", " [[ 2.65993762e+00, 1.70273900e+00, 1.47713912e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.98029804e+00, 2.01181531e+00, 1.80837512e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.75114131e+00, 1.79394329e+00, 1.56834316e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 3.08999997e-02, 3.42999995e-02, 0.00000000e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 1.53504610e-02, -2.94232368e-03, 3.29010487e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-4.89609241e-02, -4.72339392e-02, -6.38575554e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n", " \n", " [[ 2.28476381e+00, 1.97132933e+00, 1.07357013e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.64549851e+00, 2.32392597e+00, 1.41300690e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.40925336e+00, 2.09581900e+00, 1.19805980e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-1.36301041e-01, -1.49283648e-01, -1.70358062e-01, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-1.66876614e-02, -2.66698301e-02, -1.68028474e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n", " \n", " ...,\n", " \n", " [[ 2.35246587e+00, 2.04228354e+00, 1.87580907e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.28663683e+00, 1.98351204e+00, 1.72960627e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.36275625e+00, 2.05257368e+00, 1.88609958e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 7.36250281e-02, 9.08100009e-02, -1.56970948e-01, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 1.17787123e-02, 2.06344724e-02, -1.41070485e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n", " \n", " [[ 1.94935155e+00, 1.88344717e+00, 1.86985600e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 1.86683977e+00, 1.80477130e+00, 1.78671205e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.06553054e+00, 1.99962616e+00, 1.98603511e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 2.83000004e-02, 3.15999985e-02, 3.15999985e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 3.93606424e-02, 4.92172241e-02, 3.77231836e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 4.60190773e-02, 4.27626371e-02, 4.54169512e-02, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],\n", " \n", " [[ 2.40972257e+00, 2.17120194e+00, 1.67230213e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.42246366e+00, 2.33064985e+00, 1.79561615e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [ 2.24142742e+00, 2.00290704e+00, 1.50400686e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " ...,\n", " [ 0.00000000e+00, 3.42999995e-02, 0.00000000e+00, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-8.01640153e-02, 1.98893368e-01, 1.34027183e-01, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", " [-3.88956070e-03, 1.07875347e-01, 1.54200792e-01, ...,\n", " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],\n", " dtype=float32),\n", " 'pf_vectors': array([[[-124.57671 , -91.08083 , -83.18519 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 64.3017 , 27.989893, 28.240173, ..., 0. ,\n", " 0. , 0. ],\n", " [ -36.05099 , -39.437183, -37.305996, ..., 0. ,\n", " 0. , 0. ],\n", " [ 144.75407 , 103.123436, 95.441185, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[ 110.01701 , 27.931944, 20.90474 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 218.53995 , 55.72396 , 40.02955 , ..., 0. ,\n", " 0. , 0. ],\n", " [-461.0496 , -115.04492 , -86.80113 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 521.9484 , 130.84612 , 97.84585 , ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[ 128.54211 , 81.7395 , 22.78092 , ..., 0. ,\n", " 0. , 0. ],\n", " [ -63.016777, -41.089207, -11.171425, ..., 0. ,\n", " 0. , 0. ],\n", " [ 290.1306 , 182.74101 , 49.498 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 323.52737 , 204.36229 , 55.622147, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " ...,\n", " \n", " [[ 96.53474 , 61.266937, 50.477467, ..., 0. ,\n", " 0. , 0. ],\n", " [ 124.695244, 80.60399 , 61.8277 , ..., 0. ,\n", " 0. , 0. ],\n", " [-112.58496 , -74.43169 , -35.690136, ..., 0. ,\n", " 0. , 0. ],\n", " [ 193.76076 , 125.66112 , 87.4324 , ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[ -32.93902 , -30.22315 , -29.447128, ..., 0. ,\n", " 0. , 0. ],\n", " [ -82.31213 , -74.818115, -73.4579 , ..., 0. ,\n", " 0. , 0. ],\n", " [ -58.771416, -54.44747 , -52.30668 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 106.369 , 97.34339 , 94.864136, ..., 0. ,\n", " 0. , 0. ]],\n", " \n", " [[-121.577835, -76.37734 , -35.256603, ..., 0. ,\n", " 0. , 0. ],\n", " [-120.44294 , -94.773834, -48.15306 , ..., 0. ,\n", " 0. , 0. ],\n", " [-161.4199 , -166.60855 , -75.29501 , ..., 0. ,\n", " 0. , 0. ],\n", " [ 235.25317 , 206.3347 , 96.07854 , ..., 0. ,\n", " 0. , 0. ]]], dtype=float32),\n", " 'pf_mask': array([[[1., 1., 1., ..., 0., 0., 0.]],\n", " \n", " [[1., 1., 1., ..., 0., 0., 0.]],\n", " \n", " [[1., 1., 1., ..., 0., 0., 0.]],\n", " \n", " ...,\n", " \n", " [[1., 1., 1., ..., 0., 0., 0.]],\n", " \n", " [[1., 1., 1., ..., 0., 0., 0.]],\n", " \n", " [[1., 1., 1., ..., 0., 0., 0.]]], dtype=float32),\n", " 'label': array([[0, 1, 0, ..., 0, 0, 0],\n", " [0, 1, 0, ..., 0, 0, 0],\n", " [0, 1, 0, ..., 0, 0, 0],\n", " ...,\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0]])}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "build_features_and_labels(tree)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }