{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5c6594-e5e8-4966-8cb8-a3e6a9ed7d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fce0c950-2e03-4be1-95d4-a02409d8dba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7c2ba5-19ee-421e-9252-7224b03f5201",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import inspect\n",
    "import random\n",
    "import warnings\n",
    "from contextlib import contextmanager\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import fsspec\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "\n",
    "from neuralforecast.tsdataset import (\n",
    "    TimeSeriesDataModule,\n",
    "    TimeSeriesDataset,\n",
    "    _DistributedTimeSeriesDataModule,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d4c4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dataclass\n",
    "class DistributedConfig:\n",
    "    \"\"\"Settings read by `BaseModel._fit` when the dataset is not a local `TimeSeriesDataset`.\"\"\"\n",
    "    # not read anywhere in this module\n",
    "    partitions_path: str\n",
    "    # a value of 1 makes `_fit` run the distributor in local mode\n",
    "    num_nodes: int\n",
    "    # processes per node; in multi-node mode one task is launched per (node, device) pair\n",
    "    devices: int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5197e340-11f1-4c8c-96d1-ed396ac2b710",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "@contextmanager\n",
    "def _disable_torch_init():\n",
    "    \"\"\"Context manager used to disable pytorch's weight initialization.\n",
    "\n",
    "    This is especially useful when loading saved models, since when initializing\n",
    "    a model the weights are also initialized following some method\n",
    "    (e.g. kaiming uniform), and that time is wasted since we'll override them with\n",
    "    the saved weights.\"\"\"\n",
    "    def noop(*args, **kwargs):\n",
    "        return\n",
    "\n",
    "    kaiming_uniform = nn.init.kaiming_uniform_\n",
    "    kaiming_normal = nn.init.kaiming_normal_\n",
    "    xavier_uniform = nn.init.xavier_uniform_\n",
    "    xavier_normal = nn.init.xavier_normal_\n",
    "\n",
    "    nn.init.kaiming_uniform_ = noop\n",
    "    nn.init.kaiming_normal_ = noop\n",
    "    nn.init.xavier_uniform_ = noop\n",
    "    nn.init.xavier_normal_ = noop\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        # always put the real initializers back, even if the body raised\n",
    "        nn.init.kaiming_uniform_ = kaiming_uniform\n",
    "        nn.init.kaiming_normal_ = kaiming_normal\n",
    "        nn.init.xavier_uniform_ = xavier_uniform\n",
    "        nn.init.xavier_normal_ = xavier_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c40a64-8381-46a2-8cbb-70ec70ed7914",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BaseModel(pl.LightningModule):\n",
    "    \"\"\"Lightning module holding the trainer arguments, optimizer setup, fitting\n",
    "    logic and save/load helpers.\n",
    "\n",
    "    Several attributes read by the methods below (`alias`, `learning_rate`,\n",
    "    `lr_decay_steps`, `max_steps`, `val_check_steps`, `num_workers_loader`,\n",
    "    `drop_last_loader`, `validation_step_outputs`) are not assigned in this\n",
    "    class; they are expected to be set elsewhere (presumably by subclasses).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        random_seed,\n",
    "        loss,\n",
    "        valid_loss,\n",
    "        optimizer,\n",
    "        optimizer_kwargs,\n",
    "        futr_exog_list,\n",
    "        hist_exog_list,\n",
    "        stat_exog_list,\n",
    "        max_steps,\n",
    "        early_stop_patience_steps,\n",
    "        **trainer_kwargs,\n",
    "    ):\n",
    "        \"\"\"Store losses, optimizer settings and exogenous lists, and assemble `self.trainer_kwargs`.\n",
    "\n",
    "        `valid_loss` falls back to `loss` when it is None. Any extra keyword\n",
    "        argument is kept in `self.trainer_kwargs` and later unpacked into `pl.Trainer`.\n",
    "\n",
    "        Raises `TypeError` if `optimizer` is given and is not a subclass of\n",
    "        `torch.optim.Optimizer`, and `Exception` if `max_epochs` is passed.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        with warnings.catch_warnings(record=False):\n",
    "            warnings.filterwarnings('ignore')\n",
    "            # the following line issues a warning about the loss attribute being saved\n",
    "            # but we do want to save it\n",
    "            self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
    "        self.random_seed = random_seed\n",
    "        pl.seed_everything(self.random_seed, workers=True)\n",
    "\n",
    "        # Loss\n",
    "        self.loss = loss\n",
    "        if valid_loss is None:\n",
    "            self.valid_loss = loss\n",
    "        else:\n",
    "            self.valid_loss = valid_loss\n",
    "        self.train_trajectories = []\n",
    "        self.valid_trajectories = []\n",
    "\n",
    "        # Optimization\n",
    "        if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
    "            raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
    "        self.optimizer = optimizer\n",
    "        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
    "\n",
    "        # Variables\n",
    "        self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
    "        self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []\n",
    "        self.stat_exog_list = list(stat_exog_list) if stat_exog_list is not None else []\n",
    "\n",
    "        ## Trainer arguments ##\n",
    "        # Max steps, validation steps and check_val_every_n_epoch\n",
    "        trainer_kwargs = {**trainer_kwargs, 'max_steps': max_steps}\n",
    "\n",
    "        if 'max_epochs' in trainer_kwargs.keys():\n",
    "            raise Exception('max_epochs is deprecated, use max_steps instead.')\n",
    "\n",
    "        # Callbacks\n",
    "        if early_stop_patience_steps > 0:\n",
    "            callbacks = trainer_kwargs.get('callbacks')\n",
    "            if callbacks is None:\n",
    "                callbacks = []\n",
    "            elif not isinstance(callbacks, (list, tuple)):\n",
    "                # a single callback object was passed\n",
    "                callbacks = [callbacks]\n",
    "            # the dict above is only a shallow copy, so appending in place would\n",
    "            # mutate the list object supplied by the caller; build a new list instead\n",
    "            trainer_kwargs['callbacks'] = [\n",
    "                *callbacks,\n",
    "                EarlyStopping(\n",
    "                    monitor='ptl/val_loss', patience=early_stop_patience_steps\n",
    "                ),\n",
    "            ]\n",
    "\n",
    "        # Add GPU accelerator if available\n",
    "        if trainer_kwargs.get('accelerator', None) is None:\n",
    "            if torch.cuda.is_available():\n",
    "                trainer_kwargs['accelerator'] = \"gpu\"\n",
    "        if trainer_kwargs.get('devices', None) is None:\n",
    "            if torch.cuda.is_available():\n",
    "                trainer_kwargs['devices'] = -1\n",
    "\n",
    "        # Avoid saturating local memory, disabled fit model checkpoints\n",
    "        if trainer_kwargs.get('enable_checkpointing', None) is None:\n",
    "            trainer_kwargs['enable_checkpointing'] = False\n",
    "\n",
    "        self.trainer_kwargs = trainer_kwargs\n",
    "\n",
    "    def __repr__(self):\n",
    "        # `alias` is not assigned in this class; it must already exist on the instance\n",
    "        return type(self).__name__ if self.alias is None else self.alias\n",
    "\n",
    "    def _check_exog(self, dataset):\n",
    "        \"\"\"Raise `Exception` if a configured exogenous variable is missing from `dataset`.\n",
    "\n",
    "        Historical and future variables are looked up in `dataset.temporal_cols`,\n",
    "        static ones in `dataset.static_cols` (which may be None).\n",
    "        \"\"\"\n",
    "        temporal_cols = set(dataset.temporal_cols.tolist())\n",
    "        static_cols = set(dataset.static_cols.tolist() if dataset.static_cols is not None else [])\n",
    "\n",
    "        missing_hist = set(self.hist_exog_list) - temporal_cols\n",
    "        missing_futr = set(self.futr_exog_list) - temporal_cols\n",
    "        missing_stat = set(self.stat_exog_list) - static_cols\n",
    "        if missing_hist:\n",
    "            raise Exception(f'{missing_hist} historical exogenous variables not found in input dataset')\n",
    "        if missing_futr:\n",
    "            raise Exception(f'{missing_futr} future exogenous variables not found in input dataset')\n",
    "        if missing_stat:\n",
    "            raise Exception(f'{missing_stat} static exogenous variables not found in input dataset')\n",
    "\n",
    "    def _restart_seed(self, random_seed):\n",
    "        \"\"\"Seed torch with `random_seed`, or with `self.random_seed` when it is None.\"\"\"\n",
    "        if random_seed is None:\n",
    "            random_seed = self.random_seed\n",
    "        torch.manual_seed(random_seed)\n",
    "\n",
    "    def _get_temporal_exogenous_cols(self, temporal_cols):\n",
    "        \"\"\"Return the entries of `temporal_cols` that are historical or future exogenous variables.\n",
    "\n",
    "        The result comes from a set intersection, so its order is unspecified.\n",
    "        \"\"\"\n",
    "        return list(\n",
    "            set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list)\n",
    "        )\n",
    "\n",
    "    def _fit(\n",
    "        self,\n",
    "        dataset,\n",
    "        batch_size,\n",
    "        valid_batch_size=1024,\n",
    "        val_size=0,\n",
    "        test_size=0,\n",
    "        random_seed=None,\n",
    "        shuffle_train=True,\n",
    "        distributed_config=None,\n",
    "    ):\n",
    "        \"\"\"Train on `dataset` and return the fitted model.\n",
    "\n",
    "        When `dataset` is a `TimeSeriesDataset` the model is trained in-process\n",
    "        with a `pl.Trainer` and `self` is returned. Otherwise training is handed\n",
    "        to pyspark's `TorchDistributor` (which requires `distributed_config`) and\n",
    "        the value returned by the distributor's `run` call is returned.\n",
    "\n",
    "        Side effects: checks the exogenous columns, re-seeds torch, sets\n",
    "        `self.val_size` / `self.test_size` and writes `val_check_interval` and\n",
    "        `check_val_every_n_epoch` into `self.trainer_kwargs`.\n",
    "        \"\"\"\n",
    "        self._check_exog(dataset)\n",
    "        self._restart_seed(random_seed)\n",
    "\n",
    "        self.val_size = val_size\n",
    "        self.test_size = test_size\n",
    "        is_local = isinstance(dataset, TimeSeriesDataset)\n",
    "        if is_local:\n",
    "            datamodule_constructor = TimeSeriesDataModule\n",
    "        else:\n",
    "            datamodule_constructor = _DistributedTimeSeriesDataModule\n",
    "        datamodule = datamodule_constructor(\n",
    "            dataset=dataset,\n",
    "            batch_size=batch_size,\n",
    "            valid_batch_size=valid_batch_size,\n",
    "            num_workers=self.num_workers_loader,\n",
    "            drop_last=self.drop_last_loader,\n",
    "            shuffle_train=shuffle_train,\n",
    "        )\n",
    "\n",
    "        if self.val_check_steps > self.max_steps:\n",
    "            warnings.warn(\n",
    "                'val_check_steps is greater than max_steps, '\n",
    "                'setting val_check_steps to max_steps.'\n",
    "            )\n",
    "        val_check_interval = min(self.val_check_steps, self.max_steps)\n",
    "        self.trainer_kwargs['val_check_interval'] = int(val_check_interval)\n",
    "        self.trainer_kwargs['check_val_every_n_epoch'] = None\n",
    "\n",
    "        if is_local:\n",
    "            model = self\n",
    "            trainer = pl.Trainer(**model.trainer_kwargs)\n",
    "            trainer.fit(model, datamodule=datamodule)\n",
    "            model.metrics = trainer.callback_metrics\n",
    "            # drop the reference to the trainer that was attached during fit\n",
    "            model.__dict__.pop('_trainer', None)\n",
    "        else:\n",
    "            assert distributed_config is not None\n",
    "            from pyspark.ml.torch.distributor import TorchDistributor\n",
    "\n",
    "            def train_fn(\n",
    "                model_cls,\n",
    "                model_params,\n",
    "                datamodule,\n",
    "                trainer_kwargs,\n",
    "                num_tasks,\n",
    "                num_proc_per_task,\n",
    "                val_size,\n",
    "                test_size,\n",
    "            ):\n",
    "                import pytorch_lightning as pl\n",
    "\n",
    "                # we instantiate here to avoid pickling large tensors (weights)\n",
    "                model = model_cls(**model_params)\n",
    "                model.val_size = val_size\n",
    "                model.test_size = test_size\n",
    "                # these two are passed explicitly to the Trainer below\n",
    "                for arg in ('devices', 'num_nodes'):\n",
    "                    trainer_kwargs.pop(arg, None)\n",
    "                trainer = pl.Trainer(\n",
    "                    strategy=\"ddp\",\n",
    "                    use_distributed_sampler=False,  # to ensure our dataloaders are used as-is\n",
    "                    num_nodes=num_tasks,\n",
    "                    devices=num_proc_per_task,\n",
    "                    **trainer_kwargs,\n",
    "                )\n",
    "                trainer.fit(model=model, datamodule=datamodule)\n",
    "                model.metrics = trainer.callback_metrics\n",
    "                model.__dict__.pop('_trainer', None)\n",
    "                return model\n",
    "\n",
    "            def is_gpu_accelerator(accelerator):\n",
    "                from pytorch_lightning.accelerators.cuda import CUDAAccelerator\n",
    "\n",
    "                return (\n",
    "                    accelerator == \"gpu\"\n",
    "                    or isinstance(accelerator, CUDAAccelerator)\n",
    "                    or (accelerator == \"auto\" and CUDAAccelerator.is_available())\n",
    "                )\n",
    "\n",
    "            local_mode = distributed_config.num_nodes == 1\n",
    "            if local_mode:\n",
    "                num_tasks = 1\n",
    "                num_proc_per_task = distributed_config.devices\n",
    "            else:\n",
    "                # one single-process task for every device on every node\n",
    "                num_tasks = distributed_config.num_nodes * distributed_config.devices\n",
    "                num_proc_per_task = 1  # number of GPUs per task\n",
    "            num_proc = num_tasks * num_proc_per_task\n",
    "            # `__init__` only sets 'accelerator' when CUDA is available, so the key may be\n",
    "            # absent; fall back to \"auto\", which is what the Trainer uses when it is not given\n",
    "            use_gpu = is_gpu_accelerator(self.trainer_kwargs.get(\"accelerator\", \"auto\"))\n",
    "            model = TorchDistributor(\n",
    "                num_processes=num_proc,\n",
    "                local_mode=local_mode,\n",
    "                use_gpu=use_gpu,\n",
    "            ).run(\n",
    "                train_fn,\n",
    "                model_cls=type(self),\n",
    "                model_params=self.hparams,\n",
    "                datamodule=datamodule,\n",
    "                trainer_kwargs=self.trainer_kwargs,\n",
    "                num_tasks=num_tasks,\n",
    "                num_proc_per_task=num_proc_per_task,\n",
    "                val_size=val_size,\n",
    "                test_size=test_size,\n",
    "            )\n",
    "        return model\n",
    "\n",
    "    def on_fit_start(self):\n",
    "        \"\"\"Seed torch, numpy and `random` with `self.random_seed`.\"\"\"\n",
    "        torch.manual_seed(self.random_seed)\n",
    "        np.random.seed(self.random_seed)\n",
    "        random.seed(self.random_seed)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Build the optimizer and its step-wise learning rate scheduler.\n",
    "\n",
    "        Uses `self.optimizer` with `self.optimizer_kwargs` when an optimizer class\n",
    "        was provided, otherwise Adam. If the optimizer accepts `lr`, the model's\n",
    "        `learning_rate` always wins over an `lr` given in `optimizer_kwargs`.\n",
    "        The scheduler is a `StepLR` that halves the rate every `lr_decay_steps` steps.\n",
    "        \"\"\"\n",
    "        if self.optimizer:\n",
    "            optimizer_signature = inspect.signature(self.optimizer)\n",
    "            # copy so the stored kwargs are not modified when 'lr' is overridden\n",
    "            optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
    "            if 'lr' in optimizer_signature.parameters:\n",
    "                if 'lr' in optimizer_kwargs:\n",
    "                    warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
    "                optimizer_kwargs['lr'] = self.learning_rate\n",
    "            optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
    "        else:\n",
    "            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
    "        scheduler = {\n",
    "            'scheduler': torch.optim.lr_scheduler.StepLR(\n",
    "                optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
    "            ),\n",
    "            'frequency': 1,\n",
    "            'interval': 'step',\n",
    "        }\n",
    "        return {'optimizer': optimizer, 'lr_scheduler': scheduler}\n",
    "\n",
    "    def get_test_size(self):\n",
    "        \"\"\"Return `self.test_size`.\"\"\"\n",
    "        return self.test_size\n",
    "\n",
    "    def set_test_size(self, test_size):\n",
    "        \"\"\"Set `self.test_size`.\"\"\"\n",
    "        self.test_size = test_size\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        \"\"\"Average the collected validation losses, log the mean as `ptl/val_loss`\n",
    "        and record it in `self.valid_trajectories`. Does nothing when `val_size` is 0.\n",
    "        \"\"\"\n",
    "        if self.val_size == 0:\n",
    "            return\n",
    "        losses = torch.stack(self.validation_step_outputs)\n",
    "        avg_loss = losses.mean().item()\n",
    "        self.log(\n",
    "            \"ptl/val_loss\",\n",
    "            avg_loss,\n",
    "            batch_size=losses.size(0),\n",
    "            sync_dist=True,\n",
    "        )\n",
    "        self.valid_trajectories.append((self.global_step, avg_loss))\n",
    "        self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch)\n",
    "\n",
    "    def save(self, path):\n",
    "        \"\"\"Write the hyperparameters and the state dict to `path` (opened through fsspec).\"\"\"\n",
    "        with fsspec.open(path, 'wb') as f:\n",
    "            torch.save(\n",
    "                {'hyper_parameters': self.hparams, 'state_dict': self.state_dict()},\n",
    "                f,\n",
    "            )\n",
    "\n",
    "    @classmethod\n",
    "    def load(cls, path, **kwargs):\n",
    "        \"\"\"Rebuild a model from a file written by `save`.\n",
    "\n",
    "        `kwargs` are forwarded to `torch.load`. The model is constructed with\n",
    "        weight initialization disabled, since the saved weights replace them.\n",
    "        \"\"\"\n",
    "        with fsspec.open(path, 'rb') as f:\n",
    "            # NOTE(review): torch.load unpickles the file content, so only load files\n",
    "            # from trusted sources\n",
    "            content = torch.load(f, **kwargs)\n",
    "        with _disable_torch_init():\n",
    "            model = cls(**content['hyper_parameters'])\n",
    "        model.load_state_dict(content['state_dict'], strict=True, assign=True)\n",
    "        return model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}