Commit e3f7f7b3 authored by chenzk

v1.0
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# Note: environment variables set by `source` in a RUN step only apply to that build layer;
# re-source /opt/dtk-23.10/env.sh at container runtime if the DTK environment is needed.
RUN source /opt/dtk-23.10/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt
coreforecast>=0.0.6
fsspec
gitpython
hyperopt
jupyterlab
matplotlib
numba
numpy>=1.21.6
optuna
pandas>=1.3.5
pyarrow
# pytorch>=2.0.0
# pytorch-cuda>=11.8
pytorch-lightning>=2.0.0
s3fs
nbdev
black
polars
ray[tune]>=2.2.0
utilsforecast>=0.0.24
datasetsforecast
docker run -it --shm-size=32G \
  -v $PWD/neuralforecast:/home/neuralforecast \
  -v /opt/hyhal:/opt/hyhal:ro \
  --privileged=true \
  --device=/dev/kfd --device=/dev/dri/ \
  --group-add video \
  --name neuralforecast ffa1f63239fc bash
# python -m torch.utils.collect_env
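Inside the container, the commented `collect_env` command above (or the short check below) confirms that the DTK build of PyTorch is active. This is a minimal sketch and assumes the DCU build exposes the standard `torch.cuda` API, as ROCm-based builds typically do:

```python
# Minimal sanity check for the container environment (assumption: the DCU/DTK
# PyTorch build reports accelerators through the torch.cuda API).
import torch

print(torch.__version__)                  # expected 2.1.x for this base image
print(torch.cuda.is_available())          # True if the DCU devices are visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first accelerator
```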
name: neuralforecast
channels:
  - pytorch
  - conda-forge
dependencies:
  - coreforecast>=0.0.6
  - cpuonly
  - fsspec
  - gitpython
  - hyperopt
  - jupyterlab
  - matplotlib
  - numba
  - numpy>=1.21.6
  - optuna
  - pandas>=1.3.5
  - pyarrow
  - pytorch>=2.0.0
  - pytorch-lightning>=2.0.0
  - pip
  - s3fs
  - snappy<1.2.0
  - pip:
      - nbdev
      - black
      - polars
      - "ray[tune]>=2.2.0"
      - utilsforecast>=0.0.25
name: neuralforecast
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - coreforecast>=0.0.6
  - fsspec
  - gitpython
  - hyperopt
  - jupyterlab
  - matplotlib
  - numba
  - numpy>=1.21.6
  - optuna
  - pandas>=1.3.5
  - pyarrow
  - pytorch>=2.0.0
  - pytorch-cuda>=11.8
  - pytorch-lightning>=2.0.0
  - pip
  - s3fs
  - pip:
      - nbdev
      - black
      - polars
      - "ray[tune]>=2.2.0"
      - utilsforecast>=0.0.24
# Long Horizon Forecasting Experiments with NHITS
In these experiments we use `NHITS` on the [ETTh1, ETTh2, ETTm1, ETTm2](https://github.com/zhouhaoyi/ETDataset) benchmark datasets. The table below reports test MSE and MAE for NHITS, with TiDE results included for comparison.
| Dataset | Horizon | NHITS-MSE | NHITS-MAE | TiDE-MSE | TiDE-MAE |
|---------|---------|-----------|-----------|----------|----------|
| ETTh1   | 96      | 0.378     | 0.393     | 0.375    | 0.398    |
| ETTh1   | 192     | 0.427     | 0.436     | 0.412    | 0.422    |
| ETTh1   | 336     | 0.458     | 0.484     | 0.435    | 0.433    |
| ETTh1   | 720     | 0.561     | 0.501     | 0.454    | 0.465    |
| ETTh2   | 96      | 0.274     | 0.345     | 0.270    | 0.336    |
| ETTh2   | 192     | 0.353     | 0.401     | 0.332    | 0.380    |
| ETTh2   | 336     | 0.382     | 0.425     | 0.360    | 0.407    |
| ETTh2   | 720     | 0.625     | 0.557     | 0.419    | 0.451    |
| ETTm1   | 96      | 0.302     | 0.350     | 0.306    | 0.349    |
| ETTm1   | 192     | 0.347     | 0.383     | 0.335    | 0.366    |
| ETTm1   | 336     | 0.369     | 0.402     | 0.364    | 0.384    |
| ETTm1   | 720     | 0.431     | 0.441     | 0.413    | 0.413    |
| ETTm2   | 96      | 0.176     | 0.255     | 0.161    | 0.251    |
| ETTm2   | 192     | 0.245     | 0.305     | 0.215    | 0.289    |
| ETTm2   | 336     | 0.295     | 0.346     | 0.267    | 0.326    |
| ETTm2   | 720     | 0.401     | 0.413     | 0.352    | 0.383    |
<br>
## Reproducibility
1. Create a conda environment `long_horizon` using the `environment.yml` file.
```shell
conda env create -f environment.yml
```
2. Activate the conda environment:
```shell
conda activate long_horizon
```
Alternatively, installing `neuralforecast` and `datasetsforecast` directly with pip may suffice:
```
pip install git+https://github.com/Nixtla/datasetsforecast.git
pip install git+https://github.com/Nixtla/neuralforecast.git
```
3. Run the experiments for each dataset and horizon with
- `--horizon` parameter in `[96, 192, 336, 720]`
- `--dataset` parameter in `['ETTh1', 'ETTh2', 'ETTm1', 'ETTm2']`
<br>
```shell
python run_nhits.py --dataset 'ETTh1' --horizon 96 --num_samples 20
```
You can access the final forecasts from the `./data/{dataset}/{horizon}_forecasts.csv` file. Example: `./data/ETTh1/96_forecasts.csv`.
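To inspect a finished run without re-executing it, the saved forecasts can be scored directly. Below is a minimal sketch; the column names `y` and `AutoNHITS` follow the `cross_validation` output written by `run_nhits.py`, and the path is the ETTh1/96 example above:

```python
import pandas as pd
from neuralforecast.losses.numpy import mae, mse

# Load the forecasts written by run_nhits.py for one dataset/horizon pair.
fcst = pd.read_csv('./data/ETTh1/96_forecasts.csv')

# 'y' holds the ground truth and 'AutoNHITS' the model forecasts.
print('MSE:', mse(fcst['y'].values, fcst['AutoNHITS'].values))
print('MAE:', mae(fcst['y'].values, fcst['AutoNHITS'].values))
```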
<br><br>
## References
- [Cristian Challu, Kin G. Olivares, Boris N. Oreshkin, Federico Garza, Max Mergenthaler-Canseco, Artur Dubrawski (2023). "NHITS: Neural Hierarchical Interpolation for Time Series Forecasting". Accepted at the Thirty-Seventh AAAI Conference on Artificial Intelligence.](https://arxiv.org/abs/2201.12886)
name: long_horizon
channels:
  - conda-forge
dependencies:
  - numpy<1.24
  - pip
  - pip:
      - "git+https://github.com/Nixtla/datasetsforecast.git"
      - "git+https://github.com/Nixtla/neuralforecast.git"
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import argparse
import pandas as pd
from ray import tune
from neuralforecast.auto import AutoNHITS
from neuralforecast.core import NeuralForecast
from neuralforecast.losses.pytorch import MAE, HuberLoss
from neuralforecast.losses.numpy import mae, mse
#from datasetsforecast.long_horizon import LongHorizon, LongHorizonInfo
from datasetsforecast.long_horizon2 import LongHorizon2, LongHorizon2Info
import logging
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
if __name__ == '__main__':
    # Parse execution parameters
    verbose = True
    parser = argparse.ArgumentParser()
    parser.add_argument("-horizon", "--horizon", type=int)
    parser.add_argument("-dataset", "--dataset", type=str)
    parser.add_argument("-num_samples", "--num_samples", default=5, type=int)
    args = parser.parse_args()
    horizon = args.horizon
    dataset = args.dataset
    num_samples = args.num_samples
    assert horizon in [96, 192, 336, 720]

    # Load dataset
    # Y_df, _, _ = LongHorizon.load(directory='./data/', group=dataset)
    # Y_df['ds'] = pd.to_datetime(Y_df['ds'])
    Y_df = LongHorizon2.load(directory='./data/', group=dataset)
    freq = LongHorizon2Info[dataset].freq
    n_time = len(Y_df.ds.unique())
    # val_size = int(.2 * n_time)
    # test_size = int(.2 * n_time)
    val_size = LongHorizon2Info[dataset].val_size
    test_size = LongHorizon2Info[dataset].test_size

    # Adapt input_size to available data
    input_size = tune.choice([7 * horizon])
    if dataset == 'ETTm1' and horizon == 720:
        input_size = tune.choice([2 * horizon])

    nhits_config = {
        # "learning_rate": tune.choice([1e-3]),                            # Initial learning rate
        "learning_rate": tune.loguniform(1e-5, 5e-3),
        "max_steps": tune.choice([200, 1000]),                             # Number of SGD steps
        "input_size": input_size,                                          # input_size = multiplier * horizon
        "batch_size": tune.choice([7]),                                    # Number of series in windows
        "windows_batch_size": tune.choice([256]),                          # Number of windows in batch
        "n_pool_kernel_size": tune.choice([[2, 2, 2], [16, 8, 1]]),        # MaxPool's kernel size
        "n_freq_downsample": tune.choice([[(96 * 7) // 2, 96 // 2, 1],
                                          [(24 * 7) // 2, 24 // 2, 1],
                                          [1, 1, 1]]),                     # Interpolation expressivity ratios
        "dropout_prob_theta": tune.choice([0.5]),                          # Dropout regularization
        "activation": tune.choice(['ReLU']),                               # Type of non-linear activation
        "n_blocks": tune.choice([[1, 1, 1]]),                              # Blocks per each of the 3 stacks
        "mlp_units": tune.choice([[[512, 512], [512, 512], [512, 512]]]),  # Two 512-unit layers per block, per stack
        "interpolation_mode": tune.choice(['linear']),                     # Type of multi-step interpolation
        "val_check_steps": tune.choice([100]),                             # Compute validation every 100 steps
        "random_seed": tune.randint(1, 10),
    }

    models = [AutoNHITS(h=horizon,
                        loss=HuberLoss(delta=0.5),
                        valid_loss=MAE(),
                        config=nhits_config,
                        num_samples=num_samples,
                        refit_with_val=True)]
    nf = NeuralForecast(models=models, freq=freq)

    Y_hat_df = nf.cross_validation(df=Y_df, val_size=val_size,
                                   test_size=test_size, n_windows=None)

    y_true = Y_hat_df.y.values
    y_hat = Y_hat_df['AutoNHITS'].values

    n_series = len(Y_df.unique_id.unique())
    y_true = y_true.reshape(n_series, -1, horizon)
    y_hat = y_hat.reshape(n_series, -1, horizon)

    print('\n' * 4)
    print('Parsed results')
    print(f'NHITS {dataset} h={horizon}')
    print('test_size', test_size)
    print('y_true.shape (n_series, n_windows, n_time_out):\t', y_true.shape)
    print('y_hat.shape (n_series, n_windows, n_time_out):\t', y_hat.shape)
    print('MSE: ', mse(y_hat, y_true))
    print('MAE: ', mae(y_hat, y_true))

    # Save outputs
    if not os.path.exists(f'./data/{dataset}'):
        os.makedirs(f'./data/{dataset}')
    yhat_file = f'./data/{dataset}/{horizon}_forecasts.csv'
    Y_hat_df.to_csv(yhat_file, index=False)
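The inference snippet that follows restores a fitted `NeuralForecast` object with `NeuralForecast.load(path='./checkpoints/test_run/')`. For completeness, a checkpoint like that can be produced with `save`; the sketch below is illustrative only (the model, data, and horizon are assumptions, and the path simply mirrors the loader):

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.utils import AirPassengersDF

# Fit a small model and persist it so NeuralForecast.load can restore it later.
nf = NeuralForecast(models=[NHITS(h=12, input_size=24, max_steps=100)], freq='M')
nf.fit(df=AirPassengersDF)
nf.save(path='./checkpoints/test_run/', overwrite=True)
```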
import pandas as pd
import numpy as np
from datasetsforecast.long_horizon import LongHorizon
from neuralforecast.core import NeuralForecast
def load_data(name):
    if name == "ettm2":
        Y_df, X_df, S_df = LongHorizon.load(directory='./ETT-small/', group='ETTm2')
        Y_df = Y_df[Y_df['unique_id'] == 'OT']
        Y_df['ds'] = pd.to_datetime(Y_df['ds'])
        val_size = 11520
        test_size = 11520
        freq = '15T'
        return Y_df, val_size, test_size, freq
# infer
Y_df, val_size, test_size, freq = load_data('ettm2')
nf = NeuralForecast.load(path='./checkpoints/test_run/')
Y_hat_df = nf.predict(df=Y_df).reset_index()  # _predict(df: pd.DataFrame, static_cols, futr_exog_cols, models, freq, id_col, time_col, target_col)
print("Y_hat_df: ", Y_hat_df)
'''
futr_df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/EPF_FR_BE_futr.csv')
futr_df['ds'] = pd.to_datetime(futr_df['ds'])
Y_hat_df = nf.predict(futr_df=futr_df)
Y_hat_df.head()
'''
'''
from neuralforecast.utils import AirPassengersDF
Y_df = AirPassengersDF # Defined in neuralforecast.utils
Y_df.head()
'''
from neuralforecast import NeuralForecast
from neuralforecast.models import iTransformer
from neuralforecast.utils import AirPassengersDF
horizon = 8
nf = NeuralForecast(
    # models = [iTransformer(h=12, input_size=24, n_series=1, max_steps=100)],
    models=[iTransformer(h=horizon, input_size=2 * horizon, n_series=1,  # AirPassengersDF contains a single series
                         max_steps=1000, early_stop_patience_steps=3)],
    freq='M',
)
nf.fit(df=AirPassengersDF, val_size=20)
print(nf.predict())
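Continuing from the script above, the fitted model can also be backtested with `cross_validation` to get an error estimate instead of only future forecasts. A minimal sketch; the window count and step size are illustrative and not part of the original script:

```python
from neuralforecast.losses.numpy import mae

# Backtest over the last two windows of AirPassengersDF, reusing the fitted nf
# and horizon defined above (illustrative settings).
cv_df = nf.cross_validation(df=AirPassengersDF, val_size=20, n_windows=2, step_size=horizon)
print('MAE:', mae(cv_df['y'].values, cv_df['iTransformer'].values))
```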
# Model code
modelCode=613
# Model name
modelName=neuralforecast-itransformer_pytorch
# Model description
modelDescription=The iTransformer algorithm in the neuralforecast time series forecasting library makes efficient use of long-range temporal features.
# Application scenarios
appScenario=inference,training,finance,operations,e-commerce,manufacturing,energy,healthcare
# Framework type
frameType=pytorch
/.quarto/
/lightning_logs/
project:
  type: website

format:
  html:
    theme: cosmo
    fontsize: 1em
    linestretch: 1.7
    css: styles.css
    toc: true

website:
  twitter-card:
    image: "https://farm6.staticflickr.com/5510/14338202952_93595258ff_z.jpg"
    site: "@Nixtlainc"
  open-graph:
    image: "https://github.com/Nixtla/styles/blob/2abf51612584169874c90cd7c4d347e3917eaf73/images/Banner%20Github.png"
  google-analytics: "G-NXJNCVR18L"
  repo-actions: [issue]
  favicon: favicon_png.png
  navbar:
    background: primary
    search: true
    collapse-below: lg
    left:
      - text: "Get Started"
        href: examples/Getting_Started.ipynb
      - text: "NixtlaVerse"
        menu:
          - text: "MLForecast 🤖"
            href: https://github.com/nixtla/mlforecast
          - text: "StatsForecast ⚡️"
            href: https://github.com/nixtla/statsforecast
          - text: "HierarchicalForecast 👑"
            href: "https://github.com/nixtla/hierarchicalforecast"
      - text: "Help"
        menu:
          - text: "Report an Issue"
            icon: bug
            href: https://github.com/nixtla/neuralforecast/issues/new/choose
          - text: "Join our Slack"
            icon: chat-right-text
            href: https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A
    right:
      - icon: github
        href: "https://github.com/nixtla/neuralforecast"
      - icon: twitter
        href: https://twitter.com/nixtlainc
        aria-label: Nixtla Twitter
  sidebar:
    style: floating
  body-footer: |
    Give us a ⭐ on [Github](https://github.com/nixtla/neuralforecast)

metadata-files: [nbdev.yml, sidebar.yml]
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "8e5c6594-e5e8-4966-8cb8-a3e6a9ed7d89",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp common._base_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fce0c950-2e03-4be1-95d4-a02409d8dba3",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c7c2ba5-19ee-421e-9252-7224b03f5201",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"import inspect\n",
"import random\n",
"import warnings\n",
"from contextlib import contextmanager\n",
"from copy import deepcopy\n",
"from dataclasses import dataclass\n",
"\n",
"import fsspec\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.tsdataset import (\n",
" TimeSeriesDataModule,\n",
" TimeSeriesDataset,\n",
" _DistributedTimeSeriesDataModule,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6d4c4fd",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"@dataclass\n",
"class DistributedConfig:\n",
" partitions_path: str\n",
" num_nodes: int\n",
" devices: int"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5197e340-11f1-4c8c-96d1-ed396ac2b710",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"@contextmanager\n",
"def _disable_torch_init():\n",
" \"\"\"Context manager used to disable pytorch's weight initialization.\n",
"\n",
" This is especially useful when loading saved models, since when initializing\n",
" a model the weights are also initialized following some method\n",
" (e.g. kaiming uniform), and that time is wasted since we'll override them with\n",
" the saved weights.\"\"\"\n",
" def noop(*args, **kwargs):\n",
" return\n",
" \n",
" kaiming_uniform = nn.init.kaiming_uniform_\n",
" kaiming_normal = nn.init.kaiming_normal_\n",
" xavier_uniform = nn.init.xavier_uniform_\n",
" xavier_normal = nn.init.xavier_normal_\n",
" \n",
" nn.init.kaiming_uniform_ = noop\n",
" nn.init.kaiming_normal_ = noop\n",
" nn.init.xavier_uniform_ = noop\n",
" nn.init.xavier_normal_ = noop\n",
" try:\n",
" yield\n",
" finally:\n",
" nn.init.kaiming_uniform_ = kaiming_uniform\n",
" nn.init.kaiming_normal_ = kaiming_normal\n",
" nn.init.xavier_uniform_ = xavier_uniform\n",
" nn.init.xavier_normal_ = xavier_normal"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60c40a64-8381-46a2-8cbb-70ec70ed7914",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class BaseModel(pl.LightningModule):\n",
" def __init__(\n",
" self,\n",
" random_seed,\n",
" loss,\n",
" valid_loss,\n",
" optimizer,\n",
" optimizer_kwargs,\n",
" futr_exog_list,\n",
" hist_exog_list,\n",
" stat_exog_list,\n",
" max_steps,\n",
" early_stop_patience_steps,\n",
" **trainer_kwargs,\n",
" ):\n",
" super().__init__()\n",
" with warnings.catch_warnings(record=False):\n",
" warnings.filterwarnings('ignore')\n",
" # the following line issues a warning about the loss attribute being saved\n",
" # but we do want to save it\n",
" self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
" self.random_seed = random_seed\n",
" pl.seed_everything(self.random_seed, workers=True)\n",
"\n",
" # Loss\n",
" self.loss = loss\n",
" if valid_loss is None:\n",
" self.valid_loss = loss\n",
" else:\n",
" self.valid_loss = valid_loss\n",
" self.train_trajectories = []\n",
" self.valid_trajectories = []\n",
"\n",
" # Optimization\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
" self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []\n",
" self.stat_exog_list = list(stat_exog_list) if stat_exog_list is not None else []\n",
"\n",
" ## Trainer arguments ##\n",
" # Max steps, validation steps and check_val_every_n_epoch\n",
" trainer_kwargs = {**trainer_kwargs, 'max_steps': max_steps}\n",
"\n",
" if 'max_epochs' in trainer_kwargs.keys():\n",
" raise Exception('max_epochs is deprecated, use max_steps instead.')\n",
"\n",
" # Callbacks\n",
" if early_stop_patience_steps > 0:\n",
" if 'callbacks' not in trainer_kwargs:\n",
" trainer_kwargs['callbacks'] = []\n",
" trainer_kwargs['callbacks'].append(\n",
" EarlyStopping(\n",
" monitor='ptl/val_loss', patience=early_stop_patience_steps\n",
" )\n",
" )\n",
"\n",
" # Add GPU accelerator if available\n",
" if trainer_kwargs.get('accelerator', None) is None:\n",
" if torch.cuda.is_available():\n",
" trainer_kwargs['accelerator'] = \"gpu\"\n",
" if trainer_kwargs.get('devices', None) is None:\n",
" if torch.cuda.is_available():\n",
" trainer_kwargs['devices'] = -1\n",
"\n",
" # Avoid saturating local memory, disabled fit model checkpoints\n",
" if trainer_kwargs.get('enable_checkpointing', None) is None:\n",
" trainer_kwargs['enable_checkpointing'] = False\n",
"\n",
" self.trainer_kwargs = trainer_kwargs\n",
"\n",
" def __repr__(self):\n",
" return type(self).__name__ if self.alias is None else self.alias\n",
"\n",
" def _check_exog(self, dataset):\n",
" temporal_cols = set(dataset.temporal_cols.tolist())\n",
" static_cols = set(dataset.static_cols.tolist() if dataset.static_cols is not None else [])\n",
"\n",
" missing_hist = set(self.hist_exog_list) - temporal_cols\n",
" missing_futr = set(self.futr_exog_list) - temporal_cols\n",
" missing_stat = set(self.stat_exog_list) - static_cols\n",
" if missing_hist:\n",
" raise Exception(f'{missing_hist} historical exogenous variables not found in input dataset')\n",
" if missing_futr:\n",
" raise Exception(f'{missing_futr} future exogenous variables not found in input dataset')\n",
" if missing_stat:\n",
" raise Exception(f'{missing_stat} static exogenous variables not found in input dataset')\n",
"\n",
" def _restart_seed(self, random_seed):\n",
" if random_seed is None:\n",
" random_seed = self.random_seed\n",
" torch.manual_seed(random_seed)\n",
"\n",
" def _get_temporal_exogenous_cols(self, temporal_cols):\n",
" return list(\n",
" set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list)\n",
" )\n",
"\n",
" def _fit(\n",
" self,\n",
" dataset,\n",
" batch_size,\n",
" valid_batch_size=1024,\n",
" val_size=0,\n",
" test_size=0,\n",
" random_seed=None,\n",
" shuffle_train=True,\n",
" distributed_config=None,\n",
" ):\n",
" self._check_exog(dataset)\n",
" self._restart_seed(random_seed)\n",
"\n",
" self.val_size = val_size\n",
" self.test_size = test_size\n",
" is_local = isinstance(dataset, TimeSeriesDataset)\n",
" if is_local:\n",
" datamodule_constructor = TimeSeriesDataModule\n",
" else:\n",
" datamodule_constructor = _DistributedTimeSeriesDataModule\n",
" datamodule = datamodule_constructor(\n",
" dataset=dataset, \n",
" batch_size=batch_size,\n",
" valid_batch_size=valid_batch_size,\n",
" num_workers=self.num_workers_loader,\n",
" drop_last=self.drop_last_loader,\n",
" shuffle_train=shuffle_train,\n",
" )\n",
"\n",
" if self.val_check_steps > self.max_steps:\n",
" warnings.warn(\n",
" 'val_check_steps is greater than max_steps, '\n",
" 'setting val_check_steps to max_steps.'\n",
" )\n",
" val_check_interval = min(self.val_check_steps, self.max_steps)\n",
" self.trainer_kwargs['val_check_interval'] = int(val_check_interval)\n",
" self.trainer_kwargs['check_val_every_n_epoch'] = None\n",
"\n",
" if is_local:\n",
" model = self\n",
" trainer = pl.Trainer(**model.trainer_kwargs)\n",
" trainer.fit(model, datamodule=datamodule)\n",
" model.metrics = trainer.callback_metrics\n",
" model.__dict__.pop('_trainer', None)\n",
" else:\n",
" assert distributed_config is not None\n",
" from pyspark.ml.torch.distributor import TorchDistributor\n",
"\n",
" def train_fn(\n",
" model_cls,\n",
" model_params,\n",
" datamodule,\n",
" trainer_kwargs,\n",
" num_tasks,\n",
" num_proc_per_task,\n",
" val_size,\n",
" test_size,\n",
" ):\n",
" import pytorch_lightning as pl\n",
"\n",
" # we instantiate here to avoid pickling large tensors (weights)\n",
" model = model_cls(**model_params)\n",
" model.val_size = val_size\n",
" model.test_size = test_size\n",
" for arg in ('devices', 'num_nodes'):\n",
" trainer_kwargs.pop(arg, None)\n",
" trainer = pl.Trainer(\n",
" strategy=\"ddp\",\n",
" use_distributed_sampler=False, # to ensure our dataloaders are used as-is\n",
" num_nodes=num_tasks,\n",
" devices=num_proc_per_task,\n",
" **trainer_kwargs,\n",
" )\n",
" trainer.fit(model=model, datamodule=datamodule)\n",
" model.metrics = trainer.callback_metrics\n",
" model.__dict__.pop('_trainer', None)\n",
" return model\n",
"\n",
" def is_gpu_accelerator(accelerator):\n",
" from pytorch_lightning.accelerators.cuda import CUDAAccelerator\n",
"\n",
" return (\n",
" accelerator == \"gpu\"\n",
" or isinstance(accelerator, CUDAAccelerator)\n",
" or (accelerator == \"auto\" and CUDAAccelerator.is_available())\n",
" )\n",
"\n",
" local_mode = distributed_config.num_nodes == 1\n",
" if local_mode:\n",
" num_tasks = 1\n",
" num_proc_per_task = distributed_config.devices\n",
" else:\n",
" num_tasks = distributed_config.devices * distributed_config.devices\n",
" num_proc_per_task = 1 # number of GPUs per task\n",
" num_proc = num_tasks * num_proc_per_task\n",
" use_gpu = is_gpu_accelerator(self.trainer_kwargs[\"accelerator\"])\n",
" model = TorchDistributor(\n",
" num_processes=num_proc,\n",
" local_mode=local_mode,\n",
" use_gpu=use_gpu,\n",
" ).run(\n",
" train_fn,\n",
" model_cls=type(self),\n",
" model_params=self.hparams,\n",
" datamodule=datamodule,\n",
" trainer_kwargs=self.trainer_kwargs,\n",
" num_tasks=num_tasks,\n",
" num_proc_per_task=num_proc_per_task,\n",
" val_size=val_size,\n",
" test_size=test_size,\n",
" )\n",
" return model\n",
"\n",
" def on_fit_start(self):\n",
" torch.manual_seed(self.random_seed)\n",
" np.random.seed(self.random_seed)\n",
" random.seed(self.random_seed)\n",
"\n",
" def configure_optimizers(self):\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" 'scheduler': torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" ),\n",
" 'frequency': 1,\n",
" 'interval': 'step',\n",
" }\n",
" return {'optimizer': optimizer, 'lr_scheduler': scheduler}\n",
"\n",
" def get_test_size(self):\n",
" return self.test_size\n",
"\n",
" def set_test_size(self, test_size):\n",
" self.test_size = test_size\n",
"\n",
" def on_validation_epoch_end(self):\n",
" if self.val_size == 0:\n",
" return\n",
" losses = torch.stack(self.validation_step_outputs)\n",
" avg_loss = losses.mean().item()\n",
" self.log(\n",
" \"ptl/val_loss\",\n",
" avg_loss,\n",
" batch_size=losses.size(0),\n",
" sync_dist=True,\n",
" )\n",
" self.valid_trajectories.append((self.global_step, avg_loss))\n",
" self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch)\n",
"\n",
" def save(self, path):\n",
" with fsspec.open(path, 'wb') as f:\n",
" torch.save(\n",
" {'hyper_parameters': self.hparams, 'state_dict': self.state_dict()},\n",
" f,\n",
" )\n",
"\n",
" @classmethod\n",
" def load(cls, path, **kwargs):\n",
" with fsspec.open(path, 'rb') as f:\n",
" content = torch.load(f, **kwargs)\n",
" with _disable_torch_init():\n",
" model = cls(**content['hyper_parameters']) \n",
" model.load_state_dict(content['state_dict'], strict=True, assign=True)\n",
" return model"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}