{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "524620c1", "metadata": {}, "outputs": [], "source": [ "#| default_exp tsdataset" ] }, { "cell_type": "code", "execution_count": null, "id": "15392f6f", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "id": "12fa25a4", "metadata": {}, "source": [ "# PyTorch Dataset/Loader\n", "> Torch Dataset for Time Series\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2508f7a9-1433-4ad8-8f2f-0078c6ed6c3c", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "from fastcore.test import test_eq\n", "from nbdev.showdoc import show_doc\n", "from neuralforecast.utils import generate_series" ] }, { "cell_type": "code", "execution_count": null, "id": "44065066-e72a-431f-938f-1528adef9fe8", "metadata": {}, "outputs": [], "source": [ "#| export\n", "import warnings\n", "from collections.abc import Mapping\n", "from typing import List, Optional, Union\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import pytorch_lightning as pl\n", "import torch\n", "import utilsforecast.processing as ufp\n", "from torch.utils.data import Dataset, DataLoader\n", "from utilsforecast.compat import DataFrame, pl_Series" ] }, { "cell_type": "code", "execution_count": null, "id": "323a7a6e-38c3-496d-8f1e-cad05f643d41", "metadata": {}, "outputs": [], "source": [ "#| export\n", "class TimeSeriesLoader(DataLoader):\n", "    \"\"\"TimeSeriesLoader DataLoader.\n", "    [Source code](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/tsdataset.py).\n", "\n", "    A small change to PyTorch's DataLoader.\n", "    Combines a dataset and a sampler, and provides an iterable over the given dataset.\n", "\n", "    The class `~torch.utils.data.DataLoader` supports both map-style and\n", "    iterable-style datasets with single- or multi-process loading, customizing\n", "    loading order and optional automatic batching (collation) and memory pinning.\n", "\n", "    **Parameters:**<br>\n",
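"    `dataset`: (Dataset): dataset from which to load the data.<br>\n",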
"    `batch_size`: (int, optional): how many samples per batch to load (default: 1).<br>\n",
"    `shuffle`: (bool, optional): set to `True` to have the data reshuffled at every epoch (default: `False`).<br>\n",
"    `sampler`: (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset.<br>\n",
"                Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified.\n",
"    \"\"\"\n", "    def __init__(self, dataset, **kwargs):\n", "        if 'collate_fn' in kwargs:\n", "            kwargs.pop('collate_fn')\n", "        kwargs_ = {**kwargs, **dict(collate_fn=self._collate_fn)}\n", "        DataLoader.__init__(self, dataset=dataset, **kwargs_)\n", "\n", "    def _collate_fn(self, batch):\n", "        elem = batch[0]\n", "        elem_type = type(elem)\n", "\n", "        if isinstance(elem, torch.Tensor):\n", "            out = None\n", "            if torch.utils.data.get_worker_info() is not None:\n", "                # If we're in a background process, concatenate directly into a\n", "                # shared memory tensor to avoid an extra copy\n", "                numel = sum(x.numel() for x in batch)\n", "                storage = elem.storage()._new_shared(numel, device=elem.device)\n", "                out = elem.new(storage).resize_(len(batch), *list(elem.size()))\n", "            return torch.stack(batch, 0, out=out)\n", "\n", "        elif isinstance(elem, Mapping):\n", "            if elem['static'] is None:\n", "                return dict(temporal=self.collate_fn([d['temporal'] for d in batch]),\n", "                            temporal_cols=elem['temporal_cols'],\n", "                            y_idx=elem['y_idx'])\n", "\n", "            return dict(static=self.collate_fn([d['static'] for d in batch]),\n", "                        static_cols=elem['static_cols'],\n", "                        temporal=self.collate_fn([d['temporal'] for d in batch]),\n", "                        temporal_cols=elem['temporal_cols'],\n", "                        y_idx=elem['y_idx'])\n", "\n", "        raise TypeError(f'Unknown batch element type: {elem_type}')" ] }, { "cell_type": "code", "execution_count": null, "id": "93e94050-0290-43ad-9a73-c4626bba9541", "metadata": {}, "outputs": [], "source": [ "show_doc(TimeSeriesLoader)" ] },
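{ "cell_type": "markdown", "id": "f3a9b2c1", "metadata": {}, "source": [ "A minimal usage sketch (not exported): `TimeSeriesLoader` accepts any map-style dataset whose items are dicts with `temporal`, `temporal_cols`, `static`, `static_cols` and `y_idx` keys, like the ones returned by `TimeSeriesDataset.__getitem__` (defined below). The item contents here are toy values chosen only for illustration.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f3a9b2c2", "metadata": {}, "outputs": [], "source": [ "# Four identical toy items: 2 temporal columns x 3 time steps each\n", "items = [\n", "    dict(\n", "        temporal=torch.arange(6, dtype=torch.float32).reshape(2, 3),\n", "        temporal_cols=pd.Index(['y', 'available_mask']),\n", "        static=None,\n", "        static_cols=None,\n", "        y_idx=0,\n", "    )\n", "    for _ in range(4)\n", "]\n", "loader = TimeSeriesLoader(items, batch_size=2, shuffle=False)\n", "batch = next(iter(loader))\n", "# _collate_fn stacks the temporal tensors and passes the metadata through\n", "test_eq(batch['temporal'].shape, (2, 2, 3))\n", "test_eq(batch['y_idx'], 0)" ] },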
{ "cell_type": "code", "execution_count": null, "id": "05687429-c139-44c0-adb9-097c616908cc", "metadata": {}, "outputs": [], "source": [ "#| export\n", "class TimeSeriesDataset(Dataset):\n", "\n", "    def __init__(self,\n", "                 temporal,\n", "                 temporal_cols,\n", "                 indptr,\n", "                 max_size: int,\n", "                 min_size: int,\n", "                 y_idx: int,\n", "                 static=None,\n", "                 static_cols=None,\n", "                 sorted=False,\n", "                 ):\n", "        super().__init__()\n", "        self.temporal = self._as_torch_copy(temporal)\n", "        self.temporal_cols = pd.Index(list(temporal_cols))\n", "\n", "        if static is not None:\n", "            self.static = self._as_torch_copy(static)\n", "            self.static_cols = static_cols\n", "        else:\n", "            self.static = static\n", "            self.static_cols = static_cols\n", "\n", "        self.indptr = indptr\n", "        self.n_groups = self.indptr.size - 1\n", "        self.max_size = max_size\n", "        self.min_size = min_size\n", "        self.y_idx = y_idx\n", "\n", "        # Updated flag. To protect consistency, the dataset can only be updated once\n", "        self.updated = False\n", "        self.sorted = sorted\n", "\n", "    def __getitem__(self, idx):\n", "        if isinstance(idx, int):\n", "            # Parse temporal data and left-pad it to max_size\n", "            temporal = torch.zeros(size=(len(self.temporal_cols), self.max_size),\n", "                                   dtype=torch.float32)\n", "            ts = self.temporal[self.indptr[idx] : self.indptr[idx + 1], :]\n", "            temporal[:len(self.temporal_cols), -len(ts):] = ts.permute(1, 0)\n", "\n", "            # Add static data if available\n", "            static = None if self.static is None else self.static[idx, :]\n", "\n", "            item = dict(temporal=temporal, temporal_cols=self.temporal_cols,\n", "                        static=static, static_cols=self.static_cols,\n", "                        y_idx=self.y_idx)\n", "\n", "            return item\n", "        raise ValueError(f'idx must be int, got {type(idx)}')\n", "\n", "    def __len__(self):\n", "        return self.n_groups\n", "\n", "    def __repr__(self):\n", "        return f'TimeSeriesDataset(n_data={self.temporal.shape[0]:,}, n_groups={self.n_groups:,})'\n", "\n", "    def __eq__(self, other):\n", "        if not hasattr(other, 'temporal') or not hasattr(other, 'indptr'):\n", "            return False\n", "        return np.allclose(self.temporal, other.temporal) and np.array_equal(self.indptr, other.indptr)\n", "\n", "    def _as_torch_copy(\n", "        self,\n", "        x: Union[np.ndarray, torch.Tensor],\n", "        dtype: torch.dtype = torch.float32,\n", "    ) -> torch.Tensor:\n", "        if isinstance(x, np.ndarray):\n", "            x = torch.from_numpy(x)\n", "        return x.to(dtype, copy=False).clone()\n", "\n", "    def align(self, df: DataFrame, id_col: str, time_col: str, target_col: str) -> 'TimeSeriesDataset':\n", "        # Protect consistency\n", "        df = ufp.copy_if_pandas(df, deep=False)\n", "\n", "        # Add missing temporal columns as NaN (available_mask defaults to 1.0)\n", "        temporal_cols = self.temporal_cols.copy()\n", "        for col in temporal_cols:\n", "            if col not in df.columns:\n", "                df = ufp.assign_columns(df, col, np.nan)\n", "                if col == 'available_mask':\n", "                    df = ufp.assign_columns(df, col, 1.0)\n", "\n", "        # Reorder columns to match self.temporal_cols\n", "        df = df[[id_col, time_col] + temporal_cols.tolist()]\n", "\n", "        # Process future_df\n", "        dataset, *_ = TimeSeriesDataset.from_df(\n", "            df=df,\n", "            sort_df=self.sorted,\n", "            id_col=id_col,\n", "            time_col=time_col,\n", "            target_col=target_col,\n", "        )\n", "        return dataset\n", "\n", "    def append(self, futr_dataset: 'TimeSeriesDataset') -> 'TimeSeriesDataset':\n", "        \"\"\"Add future observations to the dataset. Returns a copy.\"\"\"\n", "        if self.indptr.size != futr_dataset.indptr.size:\n", "            raise ValueError('Cannot append `futr_dataset` with different number of groups.')\n", "        # Define and fill new temporal with updated information\n", "        len_temporal, col_temporal = self.temporal.shape\n", "        len_futr = futr_dataset.temporal.shape[0]\n", "        new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))\n", "        new_sizes = np.diff(self.indptr) + np.diff(futr_dataset.indptr)\n", "        new_indptr = np.append(0, new_sizes.cumsum()).astype(np.int32)\n", "        new_max_size = np.max(new_sizes)\n", "\n",
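"        # Each buffer is laid out CSR-style: series i occupies rows indptr[i]:indptr[i+1],\n", "        # so the new block for series i is its old rows followed by its future rows.\n",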
"        for i in range(self.n_groups):\n", "            curr_slice = slice(self.indptr[i], self.indptr[i + 1])\n", "            curr_size = curr_slice.stop - curr_slice.start\n", "            futr_slice = slice(futr_dataset.indptr[i], futr_dataset.indptr[i + 1])\n", "            new_temporal[new_indptr[i] : new_indptr[i] + curr_size] = self.temporal[curr_slice]\n", "            new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]\n", "\n", "        # Define new dataset\n", "        updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n", "                                            temporal_cols=self.temporal_cols.copy(),\n", "                                            indptr=new_indptr,\n", "                                            max_size=new_max_size,\n", "                                            min_size=self.min_size,\n", "                                            static=self.static,\n", "                                            y_idx=self.y_idx,\n", "                                            static_cols=self.static_cols,\n", "                                            sorted=self.sorted)\n", "\n", "        return updated_dataset\n", "\n", "    @staticmethod\n", "    def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):\n", "        futr_dataset = dataset.align(\n", "            futr_df, id_col=id_col, time_col=time_col, target_col=target_col\n", "        )\n", "        return dataset.append(futr_dataset)\n", "\n", "    @staticmethod\n", "    def trim_dataset(dataset, left_trim: int = 0, right_trim: int = 0):\n", "        \"\"\"\n", "        Trim temporal information from a dataset.\n", "        Keeps rows [left_trim : length - right_trim] of every series.\n", "        \"\"\"\n", "        if dataset.min_size <= left_trim + right_trim:\n", "            raise ValueError(f'left_trim + right_trim ({left_trim} + {right_trim}) '\n", "                             f'must be lower than the shortest series length ({dataset.min_size})')\n", "\n", "        # Define and fill new temporal with trimmed information\n", "        len_temporal, col_temporal = dataset.temporal.shape\n", "        total_trim = (left_trim + right_trim) * dataset.n_groups\n", "        new_temporal = torch.zeros(size=(len_temporal - total_trim, col_temporal))\n", "        new_indptr = [0]\n", "\n", "        acum = 0\n", "        for i in range(dataset.n_groups):\n", "            series_length = dataset.indptr[i + 1] - dataset.indptr[i]\n", "            new_length = series_length - left_trim - right_trim\n", "            new_temporal[acum:(acum + new_length), :] = dataset.temporal[dataset.indptr[i] + left_trim : dataset.indptr[i + 1] - right_trim, :]\n", "            acum += new_length\n", "            new_indptr.append(acum)\n", "\n", "        new_max_size = dataset.max_size - left_trim - right_trim\n", "        new_min_size = dataset.min_size - left_trim - right_trim\n", "\n", "        # Define new dataset\n", "        updated_dataset = TimeSeriesDataset(temporal=new_temporal,\n", "                                            temporal_cols=dataset.temporal_cols.copy(),\n", "                                            indptr=np.array(new_indptr, dtype=np.int32),\n", "                                            max_size=new_max_size,\n", "                                            min_size=new_min_size,\n", "                                            y_idx=dataset.y_idx,\n", "                                            static=dataset.static,\n", "                                            static_cols=dataset.static_cols,\n", "                                            sorted=dataset.sorted)\n", "\n", "        return updated_dataset\n", "\n", "    @staticmethod\n", "    def from_df(df, static_df=None, sort_df=False, id_col='unique_id', time_col='ds', target_col='y'):\n", "        # TODO: protect on equality of static_df + df indexes\n", "        if isinstance(df, pd.DataFrame) and df.index.name == id_col:\n", "            warnings.warn(\n", "                \"Passing the id as index is deprecated, please provide it as a column instead.\",\n", "                FutureWarning,\n", "            )\n", "            df = df.reset_index(id_col)\n", "        # Validate and sort static_df if provided\n", "        if static_df is not None:\n", "            if isinstance(static_df, pd.DataFrame) and static_df.index.name == id_col:\n", "                warnings.warn(\n", "                    \"Passing the id as index is deprecated, please provide it as a column instead.\",\n", "                    FutureWarning,\n", "                )\n", "                static_df = static_df.reset_index(id_col)\n", "            if sort_df:\n", "                static_df = ufp.sort(static_df, by=id_col)\n", "\n", "        ids, times, data, indptr, sort_idxs = ufp.process_df(df, id_col, time_col, target_col)\n",
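"        # `indptr` is CSR-style: series i occupies rows indptr[i]:indptr[i+1] of `data`\n",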
"        # processor sets y as the first column\n", "        temporal_cols = pd.Index(\n", "            [target_col] + [c for c in df.columns if c not in (id_col, time_col, target_col)]\n", "        )\n", "        temporal = data.astype(np.float32, copy=False)\n", "        indices = ids\n", "        if isinstance(df, pd.DataFrame):\n", "            dates = pd.Index(times, name=time_col)\n", "        else:\n", "            dates = pl_Series(time_col, times)\n", "        sizes = np.diff(indptr)\n", "        max_size = max(sizes)\n", "        min_size = min(sizes)\n", "\n", "        # Add available mask efficiently (without adding a column to df)\n", "        if 'available_mask' not in df.columns:\n", "            available_mask = np.ones((len(temporal), 1), dtype=np.float32)\n", "            temporal = np.append(temporal, available_mask, axis=1)\n", "            temporal_cols = temporal_cols.append(pd.Index(['available_mask']))\n", "\n", "        # Static features\n", "        if static_df is not None:\n", "            static_cols = [col for col in static_df.columns if col != id_col]\n", "            static = ufp.to_numpy(static_df[static_cols])\n", "            static_cols = pd.Index(static_cols)\n", "        else:\n", "            static = None\n", "            static_cols = None\n", "\n", "        dataset = TimeSeriesDataset(\n", "            temporal=temporal,\n", "            temporal_cols=temporal_cols,\n", "            static=static,\n", "            static_cols=static_cols,\n", "            indptr=indptr,\n", "            max_size=max_size,\n", "            min_size=min_size,\n", "            sorted=sort_df,\n", "            y_idx=0,\n", "        )\n", "        ds = df[time_col].to_numpy()\n", "        if sort_idxs is not None:\n", "            ds = ds[sort_idxs]\n", "        return dataset, indices, dates, ds" ] }, { "cell_type": "code", "execution_count": null, "id": "61a818bf-28d2-4561-8036-475f6fe78d0a", "metadata": {}, "outputs": [], "source": [ "show_doc(TimeSeriesDataset)" ] },
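{ "cell_type": "markdown", "id": "e7d1a0b3", "metadata": {}, "source": [ "A minimal sketch of `from_df` (not exported): it flattens a long-format dataframe into a single `temporal` buffer indexed by `indptr`, and `__getitem__` left-pads each series to `max_size`. The variable names below are illustrative only.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e7d1a0b4", "metadata": {}, "outputs": [], "source": [ "example_df = generate_series(n_series=3, n_temporal_features=0, equal_ends=False)\n", "example_ds, example_indices, example_dates, example_ts = TimeSeriesDataset.from_df(df=example_df, sort_df=True)\n", "\n", "# CSR layout: series i occupies rows indptr[i]:indptr[i+1] of the flat buffer\n", "test_eq(example_ds.n_groups, 3)\n", "test_eq(example_ds.indptr[-1], example_ds.temporal.shape[0])\n", "\n", "# __getitem__ left-pads every series to max_size\n", "item = example_ds[0]\n", "test_eq(item['temporal'].shape, (len(example_ds.temporal_cols), example_ds.max_size))" ] },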
{ "cell_type": "code", "execution_count": null, "id": "52c07552-b6fa-4d10-8792-71743dcdfd1d", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# Testing sort_df=True functionality\n", "temporal_df = generate_series(n_series=1000,\n", "                              n_temporal_features=0, equal_ends=False)\n", "sorted_temporal_df = temporal_df.sort_values(['unique_id', 'ds'])\n", "unsorted_temporal_df = sorted_temporal_df.sample(frac=1.0)\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=unsorted_temporal_df,\n", "                                                        sort_df=True)\n", "\n", "np.testing.assert_allclose(dataset.temporal[:, :-1],\n", "                           sorted_temporal_df.drop(columns=['unique_id', 'ds']).values)\n", "test_eq(indices, pd.Series(sorted_temporal_df['unique_id'].unique()))\n", "test_eq(dates, temporal_df.groupby('unique_id')['ds'].max().values)" ] }, { "cell_type": "code", "execution_count": null, "id": "24e51cf3", "metadata": {}, "outputs": [], "source": [ "#| export\n", "class _FilesDataset:\n", "    def __init__(\n", "        self,\n", "        files: List[str],\n", "        temporal_cols: List[str],\n", "        static_cols: Optional[List[str]],\n", "        id_col: str,\n", "        time_col: str,\n", "        target_col: str,\n", "        min_size: int,\n", "    ):\n", "        self.files = files\n", "        self.temporal_cols = pd.Index(temporal_cols)\n", "        self.static_cols = pd.Index(static_cols) if static_cols is not None else None\n", "        self.id_col = id_col\n", "        self.time_col = time_col\n", "        self.target_col = target_col\n", "        self.min_size = min_size" ] }, { "cell_type": "code", "execution_count": null, "id": "c4dae43c-4d11-4bbc-a431-ac33b004859a", "metadata": {}, "outputs": [], "source": [ "#| export\n", "class TimeSeriesDataModule(pl.LightningDataModule):\n", "\n", "    def __init__(\n", "        self,\n", "        dataset: TimeSeriesDataset,\n", "        batch_size=32,\n", "        valid_batch_size=1024,\n", "        num_workers=0,\n", "        drop_last=False,\n", "        shuffle_train=True,\n", "    ):\n", "        super().__init__()\n", "        self.dataset = dataset\n", "        self.batch_size = batch_size\n", "        self.valid_batch_size = valid_batch_size\n", "        self.num_workers = num_workers\n", "        self.drop_last = drop_last\n", "        self.shuffle_train = shuffle_train\n", "\n", "    def train_dataloader(self):\n", "        loader = TimeSeriesLoader(\n", "            self.dataset,\n", "            batch_size=self.batch_size,\n", "            num_workers=self.num_workers,\n", "            shuffle=self.shuffle_train,\n", "            drop_last=self.drop_last\n", "        )\n", "        return loader\n", "\n", "    def val_dataloader(self):\n", "        loader = TimeSeriesLoader(\n", "            self.dataset,\n", "            batch_size=self.valid_batch_size,\n", "            num_workers=self.num_workers,\n", "            shuffle=False,\n", "            drop_last=self.drop_last\n", "        )\n", "        return loader\n", "\n", "    def predict_dataloader(self):\n", "        loader = TimeSeriesLoader(\n", "            self.dataset,\n", "            batch_size=self.valid_batch_size,\n", "            num_workers=self.num_workers,\n", "            shuffle=False\n", "        )\n", "        return loader" ] }, { "cell_type": "code", "execution_count": null, "id": "8535a15f-b5cf-4ca1-bfa2-e53a9e8c3bc0", "metadata": {}, "outputs": [], "source": [ "show_doc(TimeSeriesDataModule)" ] },
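{ "cell_type": "markdown", "id": "c5b8d2e1", "metadata": {}, "source": [ "A minimal sketch (not exported): the datamodule simply wraps one `TimeSeriesDataset` in `TimeSeriesLoader`s for the train/val/predict stages. It reuses the small `example_ds` built above; with 3 series, `batch_size=2` and `drop_last=True`, the train loader yields a single batch of 2 padded series.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c5b8d2e2", "metadata": {}, "outputs": [], "source": [ "example_dm = TimeSeriesDataModule(dataset=example_ds, batch_size=2, drop_last=True)\n", "example_batch = next(iter(example_dm.train_dataloader()))\n", "test_eq(example_batch['temporal'].shape, (2, len(example_ds.temporal_cols), example_ds.max_size))\n", "test_eq(list(example_batch['temporal_cols']), ['y', 'available_mask'])" ] },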
{ "cell_type": "code", "execution_count": null, "id": "b534d29d-eecc-43ba-8468-c23305fa24a2", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "batch_size = 128\n", "data = TimeSeriesDataModule(dataset=dataset,\n", "                            batch_size=batch_size, drop_last=True)\n", "for batch in data.train_dataloader():\n", "    test_eq(batch['temporal'].shape, (batch_size, 2, 500))\n", "    test_eq(batch['temporal_cols'], ['y', 'available_mask'])" ] }, { "cell_type": "code", "execution_count": null, "id": "4481272a-ea3a-4b63-8f14-9445d8f41338", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "batch_size = 128\n", "n_static_features = 2\n", "n_temporal_features = 4\n", "temporal_df, static_df = generate_series(n_series=1000,\n", "                                         n_static_features=n_static_features,\n", "                                         n_temporal_features=n_temporal_features,\n", "                                         equal_ends=False)\n", "\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n", "                                                        static_df=static_df,\n", "                                                        sort_df=True)\n", "data = TimeSeriesDataModule(dataset=dataset,\n", "                            batch_size=batch_size, drop_last=True)\n", "\n", "for batch in data.train_dataloader():\n", "    test_eq(batch['temporal'].shape, (batch_size, n_temporal_features + 2, 500))\n", "    test_eq(batch['temporal_cols'],\n", "            ['y'] + [f'temporal_{i}' for i in range(n_temporal_features)] + ['available_mask'])\n", "\n", "    test_eq(batch['static'].shape, (batch_size, n_static_features))\n", "    test_eq(batch['static_cols'], [f'static_{i}' for i in range(n_static_features)])" ] }, { "cell_type": "code", "execution_count": null, "id": "252b59f6", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# Build a full dataframe and two splits to test dataset updates\n", "temporal_df = generate_series(n_series=2,\n", "                              n_temporal_features=2, equal_ends=True)\n", "temporal_df = temporal_df.groupby('unique_id').tail(10)\n", "temporal_df = temporal_df.reset_index()\n", "temporal_full_df = temporal_df.sort_values(['unique_id', 'ds']).reset_index(drop=True)\n", "temporal_full_df.loc[temporal_full_df.ds > '2001-05-11', ['y', 'temporal_0']] = None\n", "\n", "split1_df = temporal_full_df.loc[temporal_full_df.ds <= '2001-05-11']\n", "split2_df = temporal_full_df.loc[temporal_full_df.ds > '2001-05-11']" ] }, { "cell_type": "code", "execution_count": null, "id": "6eab7367", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# Testing available mask\n", "temporal_df_w_mask = temporal_df.copy()\n", "temporal_df_w_mask['available_mask'] = 1\n", "\n", "# Mask with all 1's\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n", "                                                        sort_df=True)\n", "mask_average = dataset.temporal[:, -1].mean()\n", "np.testing.assert_almost_equal(mask_average, 1.0000)\n", "\n", "# Add 0's to available mask\n", "temporal_df_w_mask.loc[temporal_df_w_mask.ds > '2001-05-11', 'available_mask'] = 0\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n", "                                                        sort_df=True)\n", "mask_average = dataset.temporal[:, -1].mean()\n", "np.testing.assert_almost_equal(mask_average, 0.7000)\n", "\n", "# Available mask not in last column\n", "temporal_df_w_mask = temporal_df_w_mask[['unique_id', 'ds', 'y', 'available_mask', 'temporal_0', 'temporal_1']]\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,\n", "                                                        sort_df=True)\n", "mask_average = dataset.temporal[:, 1].mean()\n", "np.testing.assert_almost_equal(mask_average, 0.7000)" ] }, { "cell_type": "code", "execution_count": null, "id": "a0d23f1a", "metadata": {}, "outputs": [], "source": [ "# To test correct future_df wrangling of the `update_dataset` method,\n", "# we check that we can recover the full dataset either from the complete\n", "# dataframe or by splitting it into parts and updating." ] },
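{ "cell_type": "markdown", "id": "9a4f7c3d", "metadata": {}, "source": [ "A minimal visible sketch of that flow (the hidden cells below are the actual tests): `update_dataset` first `align`s the future dataframe to the dataset's temporal columns, then `append`s it group by group, so the row count of the result matches the full dataframe.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9a4f7c3e", "metadata": {}, "outputs": [], "source": [ "sketch_ds, *_ = TimeSeriesDataset.from_df(df=split1_df, sort_df=False)\n", "sketch_updated = TimeSeriesDataset.update_dataset(sketch_ds, split2_df)\n", "\n", "# Same groups as before, one row per row of the full dataframe\n", "test_eq(sketch_updated.n_groups, sketch_ds.n_groups)\n", "test_eq(sketch_updated.temporal.shape[0], len(temporal_full_df))" ] },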
{ "cell_type": "code", "execution_count": null, "id": "39f999c2", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# FULL DATASET\n", "dataset_full, indices_full, dates_full, ds_full = TimeSeriesDataset.from_df(df=temporal_full_df,\n", "                                                                            sort_df=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "30f927e2", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# SPLIT_1 DATASET\n", "dataset_1, indices_1, dates_1, ds_1 = TimeSeriesDataset.from_df(df=split1_df,\n", "                                                                sort_df=False)\n", "dataset_1 = dataset_1.update_dataset(dataset_1, split2_df)" ] }, { "cell_type": "code", "execution_count": null, "id": "468a6879", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "np.testing.assert_almost_equal(dataset_full.temporal.numpy(), dataset_1.temporal.numpy())\n", "test_eq(dataset_full.max_size, dataset_1.max_size)\n", "test_eq(dataset_full.indptr, dataset_1.indptr)" ] }, { "cell_type": "code", "execution_count": null, "id": "556f852c", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "\n", "# Testing trim_dataset functionality\n", "n_static_features = 0\n", "n_temporal_features = 2\n", "temporal_df = generate_series(n_series=100,\n", "                              min_length=50,\n", "                              max_length=100,\n", "                              n_static_features=n_static_features,\n", "                              n_temporal_features=n_temporal_features,\n", "                              equal_ends=False)\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,\n", "                                                        static_df=static_df,\n", "                                                        sort_df=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "db7b1a51", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "left_trim = 10\n", "right_trim = 20\n", "dataset_trimmed = dataset.trim_dataset(dataset, left_trim=left_trim, right_trim=right_trim)\n", "\n", "np.testing.assert_almost_equal(dataset.temporal[dataset.indptr[50] + left_trim:dataset.indptr[51] - right_trim].numpy(),\n", "                               dataset_trimmed.temporal[dataset_trimmed.indptr[50]:dataset_trimmed.indptr[51]].numpy())" ] }, { "cell_type": "code", "execution_count": null, "id": "624a3fbb-cb78-4440-a645-54699fd82660", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "#| polars\n", "import polars" ] }, { "cell_type": "code", "execution_count": null, "id": "a1bdd479-b4c7-4a40-93eb-2b7c9b969a80", "metadata": {}, "outputs": [], "source": [ "#| hide\n", "#| polars\n", "temporal_df2 = temporal_df.copy()\n", "for col in ('unique_id', 'temporal_0', 'temporal_1'):\n", "    temporal_df2[col] = temporal_df2[col].cat.codes\n", "temporal_pl = polars.from_pandas(temporal_df2).sample(fraction=1.0)\n", "static_pl = polars.from_pandas(static_df.assign(unique_id=lambda df: df['unique_id'].astype('int64')))\n", "dataset_pl, indices_pl, dates_pl, ds_pl = TimeSeriesDataset.from_df(df=temporal_pl, static_df=static_pl, sort_df=True)\n", "for attr in ('static_cols', 'temporal_cols', 'min_size', 'max_size', 'n_groups'):\n", "    test_eq(getattr(dataset, attr), getattr(dataset_pl, attr))\n", "torch.testing.assert_close(dataset.temporal, dataset_pl.temporal)\n", "torch.testing.assert_close(dataset.static, dataset_pl.static)\n", "pd.testing.assert_series_equal(indices.astype('int64'), indices_pl.to_pandas().astype('int64'))\n", "pd.testing.assert_index_equal(dates, pd.Index(dates_pl, name='ds'))\n", "np.testing.assert_array_equal(ds, ds_pl)\n", "np.testing.assert_array_equal(dataset.indptr, dataset_pl.indptr)" ] },
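{ "cell_type": "markdown", "id": "b6e2f8a1", "metadata": {}, "source": [ "For distributed training the series are pre-partitioned into one file per worker, described by the `_FilesDataset` container defined above. A minimal sketch (the file names are hypothetical; nothing is read at construction time):\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b6e2f8a2", "metadata": {}, "outputs": [], "source": [ "files_sketch = _FilesDataset(\n", "    files=['part_0.parquet', 'part_1.parquet'],  # hypothetical paths, one per rank\n", "    temporal_cols=['y'],\n", "    static_cols=None,\n", "    id_col='unique_id',\n", "    time_col='ds',\n", "    target_col='y',\n", "    min_size=10,\n", ")\n", "test_eq(files_sketch.temporal_cols, pd.Index(['y']))\n", "# The datamodule below reads files[dist.get_rank()] in setup(), so each rank\n", "# builds its own TimeSeriesDataset from its own partition." ] },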
{ "cell_type": "code", "execution_count": null, "id": "959ea63c", "metadata": {}, "outputs": [], "source": [ "#| export\n", "class _DistributedTimeSeriesDataModule(TimeSeriesDataModule):\n", "    def __init__(\n", "        self,\n", "        dataset: _FilesDataset,\n", "        batch_size=32,\n", "        valid_batch_size=1024,\n", "        num_workers=0,\n", "        drop_last=False,\n", "        shuffle_train=True,\n", "    ):\n", "        # Skip TimeSeriesDataModule.__init__: self.dataset is built per rank in setup()\n", "        super(TimeSeriesDataModule, self).__init__()\n", "        self.files_ds = dataset\n", "        self.batch_size = batch_size\n", "        self.valid_batch_size = valid_batch_size\n", "        self.num_workers = num_workers\n", "        self.drop_last = drop_last\n", "        self.shuffle_train = shuffle_train\n", "\n", "    def setup(self, stage):\n", "        import torch.distributed as dist\n", "\n", "        df = pd.read_parquet(self.files_ds.files[dist.get_rank()])\n", "        if self.files_ds.static_cols is not None:\n", "            static_df = (\n", "                df[[self.files_ds.id_col] + self.files_ds.static_cols.tolist()]\n", "                .groupby(self.files_ds.id_col, observed=True)\n", "                .head(1)\n", "            )\n", "            df = df.drop(columns=self.files_ds.static_cols)\n", "        else:\n", "            static_df = None\n", "        self.dataset, *_ = TimeSeriesDataset.from_df(\n", "            df=df,\n", "            static_df=static_df,\n", "            sort_df=True,\n", "            id_col=self.files_ds.id_col,\n", "            time_col=self.files_ds.time_col,\n", "            target_col=self.files_ds.target_col,\n", "        )" ] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }