{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "524620c1",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp tsdataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15392f6f",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "12fa25a4",
"metadata": {},
"source": [
"# PyTorch Dataset/Loader\n",
"> Torch Dataset for Time Series\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2508f7a9-1433-4ad8-8f2f-0078c6ed6c3c",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from fastcore.test import test_eq\n",
"from nbdev.showdoc import show_doc\n",
"from neuralforecast.utils import generate_series"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44065066-e72a-431f-938f-1528adef9fe8",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"import warnings\n",
"from collections.abc import Mapping\n",
"from typing import List, Optional, Union\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"import utilsforecast.processing as ufp\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from utilsforecast.compat import DataFrame, pl_Series"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "323a7a6e-38c3-496d-8f1e-cad05f643d41",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class TimeSeriesLoader(DataLoader):\n",
" \"\"\"TimeSeriesLoader DataLoader.\n",
" [Source code](https://github.com/Nixtla/neuralforecast1/blob/main/neuralforecast/tsdataset.py).\n",
"\n",
" Small change to PyTorch's Data loader. \n",
" Combines a dataset and a sampler, and provides an iterable over the given dataset.\n",
"\n",
" The class `~torch.utils.data.DataLoader` supports both map-style and\n",
" iterable-style datasets with single- or multi-process loading, customizing\n",
" loading order and optional automatic batching (collation) and memory pinning. \n",
" \n",
" **Parameters:**
\n",
" `batch_size`: (int, optional): how many samples per batch to load (default: 1).
\n",
" `shuffle`: (bool, optional): set to `True` to have the data reshuffled at every epoch (default: `False`).
\n",
" `sampler`: (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset.
\n",
" Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified.
\n",
" \"\"\"\n",
" def __init__(self, dataset, **kwargs):\n",
" if 'collate_fn' in kwargs:\n",
" kwargs.pop('collate_fn')\n",
" kwargs_ = {**kwargs, **dict(collate_fn=self._collate_fn)}\n",
" DataLoader.__init__(self, dataset=dataset, **kwargs_)\n",
" \n",
" def _collate_fn(self, batch):\n",
" elem = batch[0]\n",
" elem_type = type(elem)\n",
"\n",
" if isinstance(elem, torch.Tensor):\n",
" out = None\n",
" if torch.utils.data.get_worker_info() is not None:\n",
" # If we're in a background process, concatenate directly into a\n",
" # shared memory tensor to avoid an extra copy\n",
" numel = sum(x.numel() for x in batch)\n",
" storage = elem.storage()._new_shared(numel, device=elem.device)\n",
" out = elem.new(storage).resize_(len(batch), *list(elem.size()))\n",
" return torch.stack(batch, 0, out=out)\n",
"\n",
" elif isinstance(elem, Mapping):\n",
" if elem['static'] is None:\n",
" return dict(temporal=self.collate_fn([d['temporal'] for d in batch]),\n",
" temporal_cols = elem['temporal_cols'],\n",
" y_idx=elem['y_idx'])\n",
" \n",
" return dict(static=self.collate_fn([d['static'] for d in batch]),\n",
" static_cols = elem['static_cols'],\n",
" temporal=self.collate_fn([d['temporal'] for d in batch]),\n",
" temporal_cols = elem['temporal_cols'],\n",
" y_idx=elem['y_idx'])\n",
"\n",
" raise TypeError(f'Unknown {elem_type}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93e94050-0290-43ad-9a73-c4626bba9541",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TimeSeriesLoader)"
]
},
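{
 "cell_type": "markdown",
 "id": "b3f1c2a7",
 "metadata": {},
 "source": [
  "A minimal sketch of the custom collation, using hypothetical toy items shaped like the dictionaries that `TimeSeriesDataset.__getitem__` (defined below) returns: the loader stacks each item's `temporal` tensor along a new batch dimension and carries the column metadata through unchanged."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c8d9e0f1",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical toy items mimicking `TimeSeriesDataset.__getitem__` outputs\n",
  "toy_items = [\n",
  "    dict(temporal=torch.rand(2, 10),\n",
  "         temporal_cols=pd.Index(['y', 'available_mask']),\n",
  "         static=None, static_cols=None, y_idx=0)\n",
  "    for _ in range(4)\n",
  "]\n",
  "toy_loader = TimeSeriesLoader(toy_items, batch_size=4)\n",
  "toy_batch = next(iter(toy_loader))\n",
  "# `_collate_fn` stacks the temporal tensors into (batch, channels, time)\n",
  "test_eq(toy_batch['temporal'].shape, (4, 2, 10))\n",
  "test_eq(toy_batch['temporal_cols'], ['y', 'available_mask'])"
 ]
},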
{
"cell_type": "code",
"execution_count": null,
"id": "05687429-c139-44c0-adb9-097c616908cc",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class TimeSeriesDataset(Dataset):\n",
"\n",
" def __init__(self,\n",
" temporal,\n",
" temporal_cols,\n",
" indptr,\n",
" max_size: int,\n",
" min_size: int,\n",
" y_idx: int,\n",
" static=None,\n",
" static_cols=None,\n",
" sorted=False,\n",
" ):\n",
" super().__init__()\n",
" self.temporal = self._as_torch_copy(temporal)\n",
" self.temporal_cols = pd.Index(list(temporal_cols))\n",
"\n",
" if static is not None:\n",
" self.static = self._as_torch_copy(static)\n",
" self.static_cols = static_cols\n",
" else:\n",
" self.static = static\n",
" self.static_cols = static_cols\n",
"\n",
" self.indptr = indptr\n",
" self.n_groups = self.indptr.size - 1\n",
" self.max_size = max_size\n",
" self.min_size = min_size\n",
" self.y_idx = y_idx\n",
"\n",
" # Upadated flag. To protect consistency, dataset can only be updated once\n",
" self.updated = False\n",
" self.sorted = sorted\n",
"\n",
" def __getitem__(self, idx):\n",
" if isinstance(idx, int):\n",
" # Parse temporal data and pad its left\n",
" temporal = torch.zeros(size=(len(self.temporal_cols), self.max_size),\n",
" dtype=torch.float32)\n",
" ts = self.temporal[self.indptr[idx] : self.indptr[idx + 1], :]\n",
" temporal[:len(self.temporal_cols), -len(ts):] = ts.permute(1, 0)\n",
"\n",
" # Add static data if available\n",
" static = None if self.static is None else self.static[idx,:]\n",
"\n",
" item = dict(temporal=temporal, temporal_cols=self.temporal_cols,\n",
" static=static, static_cols=self.static_cols,\n",
" y_idx=self.y_idx)\n",
"\n",
" return item\n",
" raise ValueError(f'idx must be int, got {type(idx)}')\n",
"\n",
" def __len__(self):\n",
" return self.n_groups\n",
"\n",
" def __repr__(self):\n",
" return f'TimeSeriesDataset(n_data={self.temporal.shape[0]:,}, n_groups={self.n_groups:,})'\n",
"\n",
" def __eq__(self, other):\n",
" if not hasattr(other, 'data') or not hasattr(other, 'indptr'):\n",
" return False\n",
" return np.allclose(self.data, other.data) and np.array_equal(self.indptr, other.indptr)\n",
"\n",
" def _as_torch_copy(\n",
" self,\n",
" x: Union[np.ndarray, torch.Tensor],\n",
" dtype: torch.dtype = torch.float32,\n",
" ) -> torch.Tensor:\n",
" if isinstance(x, np.ndarray):\n",
" x = torch.from_numpy(x)\n",
" return x.to(dtype, copy=False).clone()\n",
"\n",
" def align(self, df: DataFrame, id_col: str, time_col: str, target_col: str) -> 'TimeSeriesDataset':\n",
" # Protect consistency\n",
" df = ufp.copy_if_pandas(df, deep=False)\n",
"\n",
" # Add Nones to missing columns (without available_mask)\n",
" temporal_cols = self.temporal_cols.copy()\n",
" for col in temporal_cols:\n",
" if col not in df.columns:\n",
" df = ufp.assign_columns(df, col, np.nan)\n",
" if col == 'available_mask':\n",
" df = ufp.assign_columns(df, col, 1.0)\n",
" \n",
" # Sort columns to match self.temporal_cols (without available_mask)\n",
" df = df[ [id_col, time_col] + temporal_cols.tolist() ]\n",
"\n",
" # Process future_df\n",
" dataset, *_ = TimeSeriesDataset.from_df(\n",
" df=df,\n",
" sort_df=self.sorted,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" target_col=target_col,\n",
" )\n",
" return dataset\n",
"\n",
" def append(self, futr_dataset: 'TimeSeriesDataset') -> 'TimeSeriesDataset':\n",
" \"\"\"Add future observations to the dataset. Returns a copy\"\"\"\n",
" if self.indptr.size != futr_dataset.indptr.size:\n",
" raise ValueError('Cannot append `futr_dataset` with different number of groups.')\n",
" # Define and fill new temporal with updated information\n",
" len_temporal, col_temporal = self.temporal.shape\n",
" len_futr = futr_dataset.temporal.shape[0]\n",
" new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))\n",
" new_sizes = np.diff(self.indptr) + np.diff(futr_dataset.indptr)\n",
" new_indptr = np.append(0, new_sizes.cumsum()).astype(np.int32)\n",
" new_max_size = np.max(new_sizes)\n",
"\n",
" for i in range(self.n_groups):\n",
" curr_slice = slice(self.indptr[i], self.indptr[i + 1])\n",
" curr_size = curr_slice.stop - curr_slice.start\n",
" futr_slice = slice(futr_dataset.indptr[i], futr_dataset.indptr[i + 1])\n",
" new_temporal[new_indptr[i] : new_indptr[i] + curr_size] = self.temporal[curr_slice]\n",
" new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]\n",
" \n",
" # Define new dataset\n",
" updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n",
" temporal_cols=self.temporal_cols.copy(),\n",
" indptr=new_indptr,\n",
" max_size=new_max_size,\n",
" min_size=self.min_size,\n",
" static=self.static,\n",
" y_idx=self.y_idx,\n",
" static_cols=self.static_cols,\n",
" sorted=self.sorted)\n",
"\n",
" return updated_dataset\n",
"\n",
" @staticmethod\n",
" def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):\n",
" futr_dataset = dataset.align(\n",
" futr_df, id_col=id_col, time_col=time_col, target_col=target_col\n",
" )\n",
" return dataset.append(futr_dataset)\n",
" \n",
" @staticmethod\n",
" def trim_dataset(dataset, left_trim: int = 0, right_trim: int = 0):\n",
" \"\"\"\n",
" Trim temporal information from a dataset.\n",
" Returns temporal indexes [t+left:t-right] for all series.\n",
" \"\"\"\n",
" if dataset.min_size <= left_trim + right_trim:\n",
" raise Exception(f'left_trim + right_trim ({left_trim} + {right_trim}) \\\n",
" must be lower than the shorter time series ({dataset.min_size})')\n",
"\n",
" # Define and fill new temporal with trimmed information \n",
" len_temporal, col_temporal = dataset.temporal.shape\n",
" total_trim = (left_trim + right_trim) * dataset.n_groups\n",
" new_temporal = torch.zeros(size=(len_temporal-total_trim, col_temporal))\n",
" new_indptr = [0]\n",
"\n",
" acum = 0\n",
" for i in range(dataset.n_groups):\n",
" series_length = dataset.indptr[i + 1] - dataset.indptr[i]\n",
" new_length = series_length - left_trim - right_trim\n",
" new_temporal[acum:(acum+new_length), :] = dataset.temporal[dataset.indptr[i]+left_trim : \\\n",
" dataset.indptr[i + 1]-right_trim, :]\n",
" acum += new_length\n",
" new_indptr.append(acum)\n",
"\n",
" new_max_size = dataset.max_size-left_trim-right_trim\n",
" new_min_size = dataset.min_size-left_trim-right_trim\n",
" \n",
" # Define new dataset\n",
" updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n",
" temporal_cols= dataset.temporal_cols.copy(),\n",
" indptr=np.array(new_indptr, dtype=np.int32),\n",
" max_size=new_max_size,\n",
" min_size=new_min_size,\n",
" y_idx=dataset.y_idx,\n",
" static=dataset.static,\n",
" static_cols=dataset.static_cols,\n",
" sorted=dataset.sorted)\n",
"\n",
" return updated_dataset\n",
"\n",
" @staticmethod\n",
" def from_df(df, static_df=None, sort_df=False, id_col='unique_id', time_col='ds', target_col='y'):\n",
" # TODO: protect on equality of static_df + df indexes\n",
" if isinstance(df, pd.DataFrame) and df.index.name == id_col:\n",
" warnings.warn(\n",
" \"Passing the id as index is deprecated, please provide it as a column instead.\",\n",
" FutureWarning,\n",
" )\n",
" df = df.reset_index(id_col)\n",
" # Define indexes if not given\n",
" if static_df is not None:\n",
" if isinstance(static_df, pd.DataFrame) and static_df.index.name == id_col:\n",
" warnings.warn(\n",
" \"Passing the id as index is deprecated, please provide it as a column instead.\",\n",
" FutureWarning,\n",
" )\n",
" if sort_df:\n",
" static_df = ufp.sort(static_df, by=id_col)\n",
"\n",
" ids, times, data, indptr, sort_idxs = ufp.process_df(df, id_col, time_col, target_col)\n",
" # processor sets y as the first column\n",
" temporal_cols = pd.Index(\n",
" [target_col] + [c for c in df.columns if c not in (id_col, time_col, target_col)]\n",
" )\n",
" temporal = data.astype(np.float32, copy=False)\n",
" indices = ids\n",
" if isinstance(df, pd.DataFrame):\n",
" dates = pd.Index(times, name=time_col)\n",
" else:\n",
" dates = pl_Series(time_col, times)\n",
" sizes = np.diff(indptr)\n",
" max_size = max(sizes)\n",
" min_size = min(sizes)\n",
"\n",
" # Add Available mask efficiently (without adding column to df)\n",
" if 'available_mask' not in df.columns:\n",
" available_mask = np.ones((len(temporal),1), dtype=np.float32)\n",
" temporal = np.append(temporal, available_mask, axis=1)\n",
" temporal_cols = temporal_cols.append(pd.Index(['available_mask']))\n",
"\n",
" # Static features\n",
" if static_df is not None:\n",
" static_cols = [col for col in static_df.columns if col != id_col]\n",
" static = ufp.to_numpy(static_df[static_cols])\n",
" static_cols = pd.Index(static_cols)\n",
" else:\n",
" static = None\n",
" static_cols = None\n",
"\n",
" dataset = TimeSeriesDataset(\n",
" temporal=temporal,\n",
" temporal_cols=temporal_cols,\n",
" static=static,\n",
" static_cols=static_cols,\n",
" indptr=indptr,\n",
" max_size=max_size,\n",
" min_size=min_size,\n",
" sorted=sort_df,\n",
" y_idx=0,\n",
" )\n",
" ds = df[time_col].to_numpy()\n",
" if sort_idxs is not None:\n",
" ds = ds[sort_idxs]\n",
" return dataset, indices, dates, ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61a818bf-28d2-4561-8036-475f6fe78d0a",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TimeSeriesDataset)"
]
},
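{
 "cell_type": "markdown",
 "id": "d2e4f6a8",
 "metadata": {},
 "source": [
  "A short usage sketch with a hypothetical two-series dataframe: `from_df` appends an `available_mask` column, and `__getitem__` left-pads every series with zeros up to `max_size`, so shorter series come out end-aligned."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "f0a2b4c6",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical toy dataframe: two series with lengths 3 and 2\n",
  "example_df = pd.DataFrame({\n",
  "    'unique_id': [0, 0, 0, 1, 1],\n",
  "    'ds': pd.to_datetime(['2000-01-01', '2000-01-02', '2000-01-03',\n",
  "                          '2000-01-01', '2000-01-02']),\n",
  "    'y': [1.0, 2.0, 3.0, 4.0, 5.0],\n",
  "})\n",
  "example_ds, *_ = TimeSeriesDataset.from_df(df=example_df, sort_df=True)\n",
  "example_item = example_ds[1]  # the shorter series, padded to max_size=3\n",
  "test_eq(example_item['temporal'].shape, (2, 3))  # (y + available_mask, max_size)\n",
  "test_eq(example_item['temporal'][0].tolist(), [0.0, 4.0, 5.0])  # left-padded target\n",
  "test_eq(example_item['temporal'][1].tolist(), [0.0, 1.0, 1.0])  # mask is 0 on the padding\n",
  "test_eq(example_item['y_idx'], 0)"
 ]
},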
{
"cell_type": "code",
"execution_count": null,
"id": "52c07552-b6fa-4d10-8792-71743dcdfd1d",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Testing sort_df=True functionality\n",
"temporal_df = generate_series(n_series=1000, \n",
" n_temporal_features=0, equal_ends=False)\n",
"sorted_temporal_df = temporal_df.sort_values(['unique_id', 'ds'])\n",
"unsorted_temporal_df = sorted_temporal_df.sample(frac=1.0)\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=unsorted_temporal_df,\n",
" sort_df=True)\n",
"\n",
"np.testing.assert_allclose(dataset.temporal[:,:-1], \n",
" sorted_temporal_df.drop(columns=['unique_id', 'ds']).values)\n",
"test_eq(indices, pd.Series(sorted_temporal_df['unique_id'].unique()))\n",
"test_eq(dates, temporal_df.groupby('unique_id')['ds'].max().values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24e51cf3",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class _FilesDataset:\n",
" def __init__(\n",
" self,\n",
" files: List[str],\n",
" temporal_cols: List[str],\n",
" static_cols: Optional[List[str]],\n",
" id_col: str,\n",
" time_col: str,\n",
" target_col: str,\n",
" min_size: int,\n",
" ):\n",
" self.files = files\n",
" self.temporal_cols = pd.Index(temporal_cols)\n",
" self.static_cols = pd.Index(static_cols) if static_cols is not None else None\n",
" self.id_col = id_col\n",
" self.time_col = time_col\n",
" self.target_col = target_col\n",
" self.min_size = min_size"
]
},
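{
 "cell_type": "markdown",
 "id": "e5a7b9c1",
 "metadata": {},
 "source": [
  "`_FilesDataset` does not read any data itself: it is a lightweight description of a pre-partitioned dataset (one parquet file per worker plus schema metadata) that `_DistributedTimeSeriesDataModule`, defined at the end of this notebook, consumes in its `setup`. A small sketch with hypothetical paths:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "a9c1d3e5",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical shard paths: nothing is read here, the object only\n",
  "# records file locations and schema for the distributed data module\n",
  "files_ds = _FilesDataset(\n",
  "    files=['data/shard_rank0.parquet', 'data/shard_rank1.parquet'],\n",
  "    temporal_cols=['y'],\n",
  "    static_cols=None,\n",
  "    id_col='unique_id',\n",
  "    time_col='ds',\n",
  "    target_col='y',\n",
  "    min_size=1,\n",
  ")\n",
  "test_eq(files_ds.temporal_cols, pd.Index(['y']))"
 ]
},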
{
"cell_type": "code",
"execution_count": null,
"id": "c4dae43c-4d11-4bbc-a431-ac33b004859a",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class TimeSeriesDataModule(pl.LightningDataModule):\n",
" \n",
" def __init__(\n",
" self, \n",
" dataset: TimeSeriesDataset,\n",
" batch_size=32, \n",
" valid_batch_size=1024,\n",
" num_workers=0,\n",
" drop_last=False,\n",
" shuffle_train=True,\n",
" ):\n",
" super().__init__()\n",
" self.dataset = dataset\n",
" self.batch_size = batch_size\n",
" self.valid_batch_size = valid_batch_size\n",
" self.num_workers = num_workers\n",
" self.drop_last = drop_last\n",
" self.shuffle_train = shuffle_train\n",
" \n",
" def train_dataloader(self):\n",
" loader = TimeSeriesLoader(\n",
" self.dataset,\n",
" batch_size=self.batch_size, \n",
" num_workers=self.num_workers,\n",
" shuffle=self.shuffle_train,\n",
" drop_last=self.drop_last\n",
" )\n",
" return loader\n",
" \n",
" def val_dataloader(self):\n",
" loader = TimeSeriesLoader(\n",
" self.dataset, \n",
" batch_size=self.valid_batch_size, \n",
" num_workers=self.num_workers,\n",
" shuffle=False,\n",
" drop_last=self.drop_last\n",
" )\n",
" return loader\n",
" \n",
" def predict_dataloader(self):\n",
" loader = TimeSeriesLoader(\n",
" self.dataset,\n",
" batch_size=self.valid_batch_size, \n",
" num_workers=self.num_workers,\n",
" shuffle=False\n",
" )\n",
" return loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8535a15f-b5cf-4ca1-bfa2-e53a9e8c3bc0",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TimeSeriesDataModule)"
]
},
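{
 "cell_type": "markdown",
 "id": "b7d9e1f3",
 "metadata": {},
 "source": [
  "A quick sketch of how the loaders differ, reusing the `dataset` built above: training batches use `batch_size` (optionally shuffled), while validation and prediction use `valid_batch_size` with `shuffle=False`, so a `valid_batch_size` of at least `n_groups` yields the whole dataset in a single, deterministically ordered batch."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c5e7f9a1",
 "metadata": {},
 "outputs": [],
 "source": [
  "example_module = TimeSeriesDataModule(dataset=dataset, batch_size=32, valid_batch_size=1024)\n",
  "predict_batches = list(example_module.predict_dataloader())\n",
  "test_eq(len(predict_batches), 1)  # 1,000 series < valid_batch_size\n",
  "test_eq(predict_batches[0]['temporal'].shape,\n",
  "        (dataset.n_groups, len(dataset.temporal_cols), dataset.max_size))"
 ]
},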
{
"cell_type": "code",
"execution_count": null,
"id": "b534d29d-eecc-43ba-8468-c23305fa24a2",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"batch_size = 128\n",
"data = TimeSeriesDataModule(dataset=dataset, \n",
" batch_size=batch_size, drop_last=True)\n",
"for batch in data.train_dataloader():\n",
" test_eq(batch['temporal'].shape, (batch_size, 2, 500))\n",
" test_eq(batch['temporal_cols'], ['y', 'available_mask'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4481272a-ea3a-4b63-8f14-9445d8f41338",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"batch_size = 128\n",
"n_static_features = 2\n",
"n_temporal_features = 4\n",
"temporal_df, static_df = generate_series(n_series=1000,\n",
" n_static_features=n_static_features,\n",
" n_temporal_features=n_temporal_features, \n",
" equal_ends=False)\n",
"\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n",
" static_df=static_df,\n",
" sort_df=True)\n",
"data = TimeSeriesDataModule(dataset=dataset,\n",
" batch_size=batch_size, drop_last=True)\n",
"\n",
"for batch in data.train_dataloader():\n",
" test_eq(batch['temporal'].shape, (batch_size, n_temporal_features + 2, 500))\n",
" test_eq(batch['temporal_cols'],\n",
" ['y'] + [f'temporal_{i}' for i in range(n_temporal_features)] + ['available_mask'])\n",
" \n",
" test_eq(batch['static'].shape, (batch_size, n_static_features))\n",
" test_eq(batch['static_cols'], [f'static_{i}' for i in range(n_static_features)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "252b59f6",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Testing sort_df=True functionality\n",
"temporal_df = generate_series(n_series=2,\n",
" n_temporal_features=2, equal_ends=True)\n",
"temporal_df = temporal_df.groupby('unique_id').tail(10)\n",
"temporal_df = temporal_df.reset_index()\n",
"temporal_full_df = temporal_df.sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"temporal_full_df.loc[temporal_full_df.ds > '2001-05-11', ['y', 'temporal_0']] = None\n",
"\n",
"split1_df = temporal_full_df.loc[temporal_full_df.ds <= '2001-05-11']\n",
"split2_df = temporal_full_df.loc[temporal_full_df.ds > '2001-05-11']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6eab7367",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Testing available mask\n",
"temporal_df_w_mask = temporal_df.copy()\n",
"temporal_df_w_mask['available_mask'] = 1\n",
"\n",
"# Mask with all 1's\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
" sort_df=True)\n",
"mask_average = dataset.temporal[:, -1].mean()\n",
"np.testing.assert_almost_equal(mask_average, 1.0000)\n",
"\n",
"# Add 0's to available mask\n",
"temporal_df_w_mask.loc[temporal_df_w_mask.ds > '2001-05-11', 'available_mask'] = 0\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
" sort_df=True)\n",
"mask_average = dataset.temporal[:, -1].mean()\n",
"np.testing.assert_almost_equal(mask_average, 0.7000)\n",
"\n",
"# Available mask not in last column\n",
"temporal_df_w_mask = temporal_df_w_mask[['unique_id','ds','y','available_mask', 'temporal_0','temporal_1']]\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n",
" sort_df=True)\n",
"mask_average = dataset.temporal[:, 1].mean()\n",
"np.testing.assert_almost_equal(mask_average, 0.7000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d23f1a",
"metadata": {},
"outputs": [],
"source": [
"# To test correct future_df wrangling of the `update_df` method\n",
"# We are checking that we are able to recover the AirPassengers dataset\n",
"# using the dataframe or splitting it into parts and initializing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39f999c2",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# FULL DATASET\n",
"dataset_full, indices_full, dates_full, ds_full = TimeSeriesDataset.from_df(df=temporal_full_df,\n",
" sort_df=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30f927e2",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# SPLIT_1 DATASET\n",
"dataset_1, indices_1, dates_1, ds_1 = TimeSeriesDataset.from_df(df=split1_df,\n",
" sort_df=False)\n",
"dataset_1 = dataset_1.update_dataset(dataset_1, split2_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "468a6879",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"np.testing.assert_almost_equal(dataset_full.temporal.numpy(), dataset_1.temporal.numpy())\n",
"test_eq(dataset_full.max_size, dataset_1.max_size)\n",
"test_eq(dataset_full.indptr, dataset_1.indptr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "556f852c",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"\n",
"# Testing trim_dataset functionality\n",
"n_static_features = 0\n",
"n_temporal_features = 2\n",
"temporal_df = generate_series(n_series=100,\n",
" min_length=50,\n",
" max_length=100,\n",
" n_static_features=n_static_features,\n",
" n_temporal_features=n_temporal_features, \n",
" equal_ends=False)\n",
"dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n",
" static_df=static_df,\n",
" sort_df=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db7b1a51",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"left_trim = 10\n",
"right_trim = 20\n",
"dataset_trimmed = dataset.trim_dataset(dataset, left_trim=left_trim, right_trim=right_trim)\n",
"\n",
"np.testing.assert_almost_equal(dataset.temporal[dataset.indptr[50]+left_trim:dataset.indptr[51]-right_trim].numpy(),\n",
" dataset_trimmed.temporal[dataset_trimmed.indptr[50]:dataset_trimmed.indptr[51]].numpy())"
]
},
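{
 "cell_type": "markdown",
 "id": "d1f3a5b7",
 "metadata": {},
 "source": [
  "Over-trimming is rejected: `trim_dataset` raises when `left_trim + right_trim` would consume the shortest series, since every series must keep at least one observation. A small sketch using `fastcore`'s `test_fail`:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "e3a5b7c9",
 "metadata": {},
 "outputs": [],
 "source": [
  "from fastcore.test import test_fail\n",
  "\n",
  "# Trimming the full length of the shortest series must raise\n",
  "test_fail(lambda: TimeSeriesDataset.trim_dataset(dataset, left_trim=dataset.min_size),\n",
  "          contains='must be lower')"
 ]
},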
{
"cell_type": "code",
"execution_count": null,
"id": "624a3fbb-cb78-4440-a645-54699fd82660",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| polars\n",
"import polars"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1bdd479-b4c7-4a40-93eb-2b7c9b969a80",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| polars\n",
"temporal_df2 = temporal_df.copy()\n",
"for col in ('unique_id', 'temporal_0', 'temporal_1'):\n",
" temporal_df2[col] = temporal_df2[col].cat.codes\n",
"temporal_pl = polars.from_pandas(temporal_df2).sample(fraction=1.0)\n",
"static_pl = polars.from_pandas(static_df.assign(unique_id=lambda df: df['unique_id'].astype('int64')))\n",
"dataset_pl, indices_pl, dates_pl, ds_pl = TimeSeriesDataset.from_df(df=temporal_pl, static_df=static_df, sort_df=True)\n",
"for attr in ('static_cols', 'temporal_cols', 'min_size', 'max_size', 'n_groups'):\n",
" test_eq(getattr(dataset, attr), getattr(dataset_pl, attr))\n",
"torch.testing.assert_allclose(dataset.temporal, dataset_pl.temporal)\n",
"torch.testing.assert_allclose(dataset.static, dataset_pl.static)\n",
"pd.testing.assert_series_equal(indices.astype('int64'), indices_pl.to_pandas().astype('int64'))\n",
"pd.testing.assert_index_equal(dates, pd.Index(dates_pl, name='ds'))\n",
"np.testing.assert_array_equal(ds, ds_pl)\n",
"np.testing.assert_array_equal(dataset.indptr, dataset_pl.indptr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "959ea63c",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class _DistributedTimeSeriesDataModule(TimeSeriesDataModule):\n",
" def __init__(\n",
" self,\n",
" dataset: _FilesDataset,\n",
" batch_size=32,\n",
" valid_batch_size=1024,\n",
" num_workers=0,\n",
" drop_last=False,\n",
" shuffle_train=True,\n",
" ):\n",
" super(TimeSeriesDataModule, self).__init__()\n",
" self.files_ds = dataset\n",
" self.batch_size = batch_size\n",
" self.valid_batch_size = valid_batch_size\n",
" self.num_workers = num_workers\n",
" self.drop_last = drop_last\n",
" self.shuffle_train = shuffle_train\n",
"\n",
" def setup(self, stage):\n",
" import torch.distributed as dist\n",
"\n",
" df = pd.read_parquet(self.files_ds.files[dist.get_rank()])\n",
" if self.files_ds.static_cols is not None:\n",
" static_df = (\n",
" df[[self.files_ds.id_col] + self.files_ds.static_cols.tolist()]\n",
" .groupby(self.files_ds.id_col, observed=True)\n",
" .head(1)\n",
" )\n",
" df = df.drop(columns=self.files_ds.static_cols)\n",
" else:\n",
" static_df = None\n",
" self.dataset, *_ = TimeSeriesDataset.from_df(\n",
" df=df,\n",
" static_df=static_df,\n",
" sort_df=True,\n",
" id_col=self.files_ds.id_col,\n",
" time_col=self.files_ds.time_col,\n",
" target_col=self.files_ds.target_col,\n",
" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}