Commit efb99f11: Initial commit
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forked from https://github.com/Nixtla/nixtla/blob/main/experiments/amazon-chronos/src/utils.py."""
from functools import partial
from itertools import repeat
import multiprocessing
import os
from pathlib import Path
from typing import List
from gluonts.dataset import Dataset
from gluonts.dataset.repository.datasets import (
dataset_names as gluonts_datasets,
get_dataset,
)
from gluonts.time_feature.seasonality import get_seasonality
import numpy as np
import pandas as pd
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import mae, mase, smape
def parallel_transform(inp):
ts, last_n = inp[0], inp[1]
return ExperimentHandler._transform_gluonts_instance_to_df(ts, last_n=last_n)
def quantile_loss(
df: pd.DataFrame,
models: list,
q: float = 0.5,
id_col: str = "unique_id",
target_col: str = "y",
) -> pd.DataFrame:
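    # Pinball (quantile) loss at level q, averaged per series (grouped by id_col).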
delta_y = df[models].sub(df[target_col], axis=0)
res = (
np.maximum(q * delta_y, (q - 1) * delta_y)
.groupby(df[id_col], observed=True)
.mean()
)
res.index.name = id_col
res = res.reset_index()
return res
class ExperimentHandler:
def __init__(
self,
dataset: str,
quantiles: List[float] = list(np.arange(1, 10) / 10.0),
results_dir: str = "./results",
models_dir: str = "./models",
):
if dataset not in gluonts_datasets:
raise Exception(
f"dataset {dataset} not found in gluonts "
f"available datasets: {', '.join(gluonts_datasets)}"
)
self.dataset = dataset
self.quantiles = quantiles
self.level = self._transform_quantiles_to_levels(quantiles)
self.results_dir = results_dir
self.models_dir = models_dir
# defining datasets
self._maybe_download_m3_or_m5_file(self.dataset)
gluonts_dataset = get_dataset(self.dataset)
self.horizon = gluonts_dataset.metadata.prediction_length
if self.horizon is None:
raise Exception(
f"horizon not found for dataset {self.dataset} "
"experiment cannot be run"
)
self.freq = gluonts_dataset.metadata.freq
# get_seasonality() returns 1 for freq='D', override this to 7. This significantly improves the accuracy of
# statistical models on datasets like m5/nn5_daily. The models like AutoARIMA/AutoETS can still set
# seasonality=1 internally on datasets like weather by choosing non-seasonal models during model selection.
if self.freq == "D":
self.seasonality = 7
else:
self.seasonality = get_seasonality(self.freq)
self.gluonts_train_dataset = gluonts_dataset.train
self.gluonts_test_dataset = gluonts_dataset.test
self._create_dir_if_not_exists(self.results_dir)
try:
multiprocessing.set_start_method("spawn")
except RuntimeError:
print("Multiprocessing context has already been set.")
@staticmethod
def _maybe_download_m3_or_m5_file(dataset: str):
if dataset[:2] == "m3":
m3_file = Path.home() / ".gluonts" / "datasets" / "M3C.xls"
if not m3_file.exists():
from datasetsforecast.m3 import M3
from datasetsforecast.utils import download_file
download_file(m3_file.parent, M3.source_url)
elif dataset == "m5":
m5_raw_dir = Path.home() / ".gluonts" / "m5"
if not m5_raw_dir.exists():
import zipfile
from datasetsforecast.m5 import M5
from datasetsforecast.utils import download_file
download_file(m5_raw_dir, M5.source_url)
with zipfile.ZipFile(m5_raw_dir / "m5.zip", "r") as zip_ref:
zip_ref.extractall(m5_raw_dir)
@staticmethod
def _transform_quantiles_to_levels(quantiles: List[float]) -> List[int]:
level = [
int(100 - 200 * q) for q in quantiles if q < 0.5
        ]  # in this case mean=median
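        # e.g. quantiles [0.1, ..., 0.9] produce levels [20, 40, 60, 80]; a level
        # L maps back to the (50 - L/2)% and (50 + L/2)% quantiles, and q = 0.5 is
        # taken from the point forecast itself (see fcst_from_level_to_quantiles).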
level = sorted(list(set(level)))
return level
@staticmethod
def _create_dir_if_not_exists(directory: str):
Path(directory).mkdir(parents=True, exist_ok=True)
@staticmethod
def _transform_gluonts_instance_to_df(
ts: dict,
last_n: int | None = None,
) -> pd.DataFrame:
start_period = ts["start"]
start_ds, freq = start_period.to_timestamp(), start_period.freq
target = ts["target"]
ds = pd.date_range(start=start_ds, freq=freq, periods=len(target))
if last_n is not None:
target = target[-last_n:]
ds = ds[-last_n:]
ts_df = pd.DataFrame({"unique_id": ts["item_id"], "ds": ds, "y": target})
return ts_df
@staticmethod
def _transform_gluonts_dataset_to_df(
gluonts_dataset: Dataset,
last_n: int | None = None,
) -> pd.DataFrame:
with multiprocessing.Pool(os.cpu_count()) as pool: # Create a process pool
results = pool.map(
parallel_transform, zip(gluonts_dataset, repeat(last_n))
)
df = pd.concat(results)
df = df.reset_index(drop=True)
return df
@property
def train_df(self) -> pd.DataFrame:
train_df = self._transform_gluonts_dataset_to_df(self.gluonts_train_dataset)
return train_df
@property
def test_df(self) -> pd.DataFrame:
test_df = self._transform_gluonts_dataset_to_df(
self.gluonts_test_dataset,
last_n=self.horizon,
)
# Make sure that only the first backtest window is used for evaluation on `traffic` / `exchange_rate` datasets
return test_df.groupby("unique_id", sort=False).head(self.horizon)
def save_dataframe(self, df: pd.DataFrame, file_name: str):
df.to_csv(f"{self.results_dir}/{file_name}", index=False)
def save_results(
self, fcst_df: pd.DataFrame, total_time: float, model_name: str
):
self.save_dataframe(
fcst_df,
f"{model_name}-{self.dataset}-fcst.csv",
)
time_df = pd.DataFrame({"time": [total_time], "model": model_name})
self.save_dataframe(
time_df,
f"{model_name}-{self.dataset}-time.csv",
)
def fcst_from_level_to_quantiles(
self,
fcst_df: pd.DataFrame,
model_name: str,
) -> pd.DataFrame:
fcst_df = fcst_df.copy()
cols = ["unique_id", "ds", model_name]
for q in self.quantiles:
if q == 0.5:
col = f"{model_name}"
else:
lv = int(100 - 200 * q)
hi_or_lo = "lo" if lv > 0 else "hi"
lv = abs(lv)
col = f"{model_name}-{hi_or_lo}-{lv}"
q_col = f"{model_name}-q-{q}"
fcst_df[q_col] = fcst_df[col].values
cols.append(q_col)
return fcst_df[cols]
def evaluate_models(self, models: List[str]) -> pd.DataFrame:
fcsts_df = []
times_df = []
for model in models:
fcst_method_df = pd.read_csv(
f"{self.results_dir}/{model}-{self.dataset}-fcst.csv"
).set_index(["unique_id", "ds"])
fcsts_df.append(fcst_method_df)
time_method_df = pd.read_csv(
f"{self.results_dir}/{model}-{self.dataset}-time.csv"
)
times_df.append(time_method_df)
fcsts_df = pd.concat(fcsts_df, axis=1).reset_index()
fcsts_df["ds"] = pd.to_datetime(fcsts_df["ds"])
times_df = pd.concat(times_df)
return self.evaluate_from_predictions(
models=models, fcsts_df=fcsts_df, times_df=times_df
)
def evaluate_from_predictions(
self, models: List[str], fcsts_df: pd.DataFrame, times_df: pd.DataFrame
) -> pd.DataFrame:
test_df = self.test_df
train_df = self.train_df
test_df = test_df.merge(fcsts_df, how="left")
assert test_df.isna().sum().sum() == 0, "merge contains nas"
# point evaluation
point_fcsts_cols = ["unique_id", "ds", "y"] + models
test_df["unique_id"] = test_df["unique_id"].astype(str)
train_df["unique_id"] = train_df["unique_id"].astype(str)
mase_seas = partial(mase, seasonality=self.seasonality)
eval_df = evaluate(
test_df[point_fcsts_cols],
train_df=train_df,
metrics=[smape, mase_seas, mae],
)
# probabilistic evaluation
eval_prob_df = []
for q in self.quantiles:
prob_cols = [f"{model}-q-{q}" for model in models]
eval_q_df = quantile_loss(test_df, models=prob_cols, q=q)
eval_q_df[prob_cols] = eval_q_df[prob_cols] * self.horizon
eval_q_df = eval_q_df.rename(columns=dict(zip(prob_cols, models)))
eval_q_df["metric"] = f"quantile-loss-{q}"
eval_prob_df.append(eval_q_df)
eval_prob_df = pd.concat(eval_prob_df)
eval_prob_df = eval_prob_df.groupby("metric").sum().reset_index()
total_y = test_df["y"].sum()
eval_prob_df[models] = eval_prob_df[models] / total_y
eval_prob_df["metric"] = "scaled_crps"
eval_df = pd.concat([eval_df, eval_prob_df]).reset_index(drop=True)
eval_df = eval_df.groupby("metric").mean(numeric_only=True).reset_index()
eval_df = eval_df.melt(
id_vars="metric", value_name="value", var_name="model"
)
times_df.insert(0, "metric", "time")
times_df = times_df.rename(columns={"time": "value"})
eval_df = pd.concat([eval_df, times_df])
eval_df.insert(0, "dataset", self.dataset)
eval_df = eval_df.sort_values(["dataset", "metric", "model"])
eval_df = eval_df.reset_index(drop=True)
return eval_df
if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
# Extended Benchmarks
We benchmark on the original test set of the ETT datasets, following long-horizon benchmark papers (see [here](https://openreview.net/forum?id=pCbC3aQB5W) for an example). The original benchmark considers a rolling-validation task over all test windows with a stride of 1. While our method can easily be run on this task, the baselines can take a very long time to run. We therefore present results on a modified task in which the stride between windows is set to the horizon length, i.e. only the disjoint horizons in the test period are evaluated.
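As a rough sketch of the modified protocol (illustrative only; the function and variable names below are hypothetical and not part of the evaluation code), the evaluation windows are obtained by striding through the test split in steps of the horizon length:
```
import numpy as np

def disjoint_windows(test_values: np.ndarray, horizon: int):
    """Yield non-overlapping evaluation windows (stride = horizon)."""
    for start in range(0, len(test_values) - horizon + 1, horizon):
        yield test_values[start : start + horizon]

# A test split of length 5 * 96 yields 5 disjoint 96-step horizons.
windows = list(disjoint_windows(np.arange(5 * 96), horizon=96))
assert len(windows) == 5
```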
All experiments were performed on a [g2-standard-32](https://cloud.google.com/compute/docs/gpus) machine. We compare TimesFM with [Amazon-Chronos](https://github.com/amazon-science/chronos-forecasting).
## Running TimesFM on the benchmark
Install the environment and the package as detailed in the main README, then run the following from the base directory.
```
conda activate tfm_env
TF_CPP_MIN_LOG_LEVEL=2 XLA_PYTHON_CLIENT_PREALLOCATE=false python3 -m experiments.long_horizon_benchmarks.run_eval \
--model_path=<model_path> --backend="gpu" \
--pred_len=96 --context_len=512 --dataset=etth1
```
In the above, `<model_path>` should point to the checkpoint directory that can be downloaded from HuggingFace.
To run Chronos on the same benchmark, use the command:
```
TF_CPP_MIN_LOG_LEVEL=2 XLA_PYTHON_CLIENT_PREALLOCATE=false python3 -m experiments.long_horizon_benchmarks.run_eval \
--model_path=amazon/chronos-t5-mini --backend="gpu" \
--pred_len=96 --context_len=512 --dataset=etth1
```
You can change the model size from "mini" to "large" as required. The datasets we benchmark on are etth1, etth2, ettm1 and ettm2.
## Benchmark Results
![Benchmark Results Table](./tfm_long_horizon.png)
We compare performance on horizon lengths of 96, 192, and 336, with the context length held fixed at 512.
TimesFM performs best in terms of both wape and smape. More importantly, it is much faster than the other methods; in particular, it is more than 1000x faster than Chronos (Large).
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF dataloaders for general timeseries datasets.
The expected input format is a csv file with a datetime index.
"""
from absl import logging
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
import sys
import os
from . import time_features  # TODO: import path
tf.config.experimental.set_visible_devices([], "GPU")  # TODO: OOM error (GPUs hidden from TensorFlow as a workaround)
# import time_features
class TimeSeriesdata(object):
"""Data loader class."""
def __init__(
self,
data_path,
datetime_col,
num_cov_cols,
cat_cov_cols,
ts_cols,
train_range,
val_range,
test_range,
hist_len,
pred_len,
batch_size,
freq='H',
normalize=True,
epoch_len=None,
holiday=False,
permute=True,
):
"""Initialize objects.
Args:
data_path: path to csv file
datetime_col: column name for datetime col
num_cov_cols: list of numerical global covariates
cat_cov_cols: list of categorical global covariates
ts_cols: columns corresponding to ts
train_range: tuple of train ranges
val_range: tuple of validation ranges
test_range: tuple of test ranges
hist_len: historical context
pred_len: prediction length
batch_size: batch size (number of ts in a batch)
freq: freq of original data
normalize: std. normalize data or not
epoch_len: num iters in an epoch
holiday: use holiday features or not
permute: permute ts in train batches or not
Returns:
None
"""
self.data_df = pd.read_csv(open(data_path, 'r'))
if not num_cov_cols:
self.data_df['ncol'] = np.zeros(self.data_df.shape[0])
num_cov_cols = ['ncol']
if not cat_cov_cols:
self.data_df['ccol'] = np.zeros(self.data_df.shape[0])
cat_cov_cols = ['ccol']
self.data_df.fillna(0, inplace=True)
self.data_df.set_index(
pd.DatetimeIndex(self.data_df[datetime_col]), inplace=True
)
self.num_cov_cols = num_cov_cols
self.cat_cov_cols = cat_cov_cols
self.ts_cols = ts_cols
self.train_range = train_range
self.val_range = val_range
self.test_range = test_range
data_df_idx = self.data_df.index
date_index = data_df_idx.union(
pd.date_range(
data_df_idx[-1] + pd.Timedelta(1, freq=freq),
periods=pred_len + 1,
freq=freq,
)
)
self.time_df = time_features.TimeCovariates(
date_index, holiday=holiday
).get_covariates()
self.hist_len = hist_len
self.pred_len = pred_len
self.batch_size = batch_size
self.freq = freq
self.normalize = normalize
self.data_mat = self.data_df[self.ts_cols].to_numpy().transpose()
self.data_mat = self.data_mat[:, 0 : self.test_range[1]]
self.time_mat = self.time_df.to_numpy().transpose()
self.num_feat_mat = self.data_df[num_cov_cols].to_numpy().transpose()
self.cat_feat_mat, self.cat_sizes = self._get_cat_cols(cat_cov_cols)
self.normalize = normalize
if normalize:
self._normalize_data()
logging.info(
'Data Shapes: %s, %s, %s, %s',
self.data_mat.shape,
self.time_mat.shape,
self.num_feat_mat.shape,
self.cat_feat_mat.shape,
)
self.epoch_len = epoch_len
self.permute = permute
def _get_cat_cols(self, cat_cov_cols):
"""Get categorical columns."""
cat_vars = []
cat_sizes = []
for col in cat_cov_cols:
dct = {x: i for i, x in enumerate(self.data_df[col].unique())}
cat_sizes.append(len(dct))
mapped = self.data_df[col].map(lambda x: dct[x]).to_numpy().transpose() # pylint: disable=cell-var-from-loop
cat_vars.append(mapped)
return np.vstack(cat_vars), cat_sizes
def _normalize_data(self):
self.scaler = StandardScaler()
train_mat = self.data_mat[:, self.train_range[0] : self.train_range[1]]
self.scaler = self.scaler.fit(train_mat.transpose())
self.data_mat = self.scaler.transform(self.data_mat.transpose()).transpose()
def train_gen(self):
"""Generator for training data."""
num_ts = len(self.ts_cols)
perm = np.arange(
self.train_range[0] + self.hist_len,
self.train_range[1] - self.pred_len,
)
perm = np.random.permutation(perm)
hist_len = self.hist_len
logging.info('Hist len: %s', hist_len)
if not self.epoch_len:
epoch_len = len(perm)
else:
epoch_len = self.epoch_len
for idx in perm[0:epoch_len]:
for _ in range(num_ts // self.batch_size + 1):
if self.permute:
tsidx = np.random.choice(num_ts, size=self.batch_size, replace=False)
else:
tsidx = np.arange(num_ts)
dtimes = np.arange(idx - hist_len, idx + self.pred_len)
(
bts_train,
bts_pred,
bfeats_train,
bfeats_pred,
bcf_train,
bcf_pred,
) = self._get_features_and_ts(dtimes, tsidx, hist_len)
all_data = [
bts_train,
bfeats_train,
bcf_train,
bts_pred,
bfeats_pred,
bcf_pred,
tsidx,
]
yield tuple(all_data)
def test_val_gen(self, mode='val', shift=1):
"""Generator for validation/test data."""
if mode == 'val':
start = self.val_range[0]
end = self.val_range[1] - self.pred_len + 1
elif mode == 'test':
start = self.test_range[0]
end = self.test_range[1] - self.pred_len + 1
else:
raise NotImplementedError('Eval mode not implemented')
num_ts = len(self.ts_cols)
hist_len = self.hist_len
logging.info('Hist len: %s', hist_len)
perm = np.arange(start, end)
if self.epoch_len:
epoch_len = self.epoch_len
else:
epoch_len = len(perm)
for i in range(0, epoch_len, shift):
idx = perm[i]
for batch_idx in range(0, num_ts, self.batch_size):
tsidx = np.arange(batch_idx, min(batch_idx + self.batch_size, num_ts))
dtimes = np.arange(idx - hist_len, idx + self.pred_len)
(
bts_train,
bts_pred,
bfeats_train,
bfeats_pred,
bcf_train,
bcf_pred,
) = self._get_features_and_ts(dtimes, tsidx, hist_len)
all_data = [
bts_train,
bfeats_train,
bcf_train,
bts_pred,
bfeats_pred,
bcf_pred,
tsidx,
]
yield tuple(all_data)
def _get_features_and_ts(self, dtimes, tsidx, hist_len=None):
"""Get features and ts in specified windows."""
if hist_len is None:
hist_len = self.hist_len
data_times = dtimes[dtimes < self.data_mat.shape[1]]
bdata = self.data_mat[:, data_times]
bts = bdata[tsidx, :]
bnf = self.num_feat_mat[:, data_times]
bcf = self.cat_feat_mat[:, data_times]
btf = self.time_mat[:, dtimes]
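    # Time features extend pred_len steps past the observed data, so when the
    # window runs past the end of the data matrix, the last observed numerical
    # and categorical covariates are repeated below to match that length.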
if bnf.shape[1] < btf.shape[1]:
rem_len = btf.shape[1] - bnf.shape[1]
rem_rep = np.repeat(bnf[:, [-1]], repeats=rem_len)
rem_rep_cat = np.repeat(bcf[:, [-1]], repeats=rem_len)
bnf = np.hstack([bnf, rem_rep.reshape(bnf.shape[0], -1)])
bcf = np.hstack([bcf, rem_rep_cat.reshape(bcf.shape[0], -1)])
bfeats = np.vstack([btf, bnf])
bts_train = bts[:, 0:hist_len]
bts_pred = bts[:, hist_len:]
bfeats_train = bfeats[:, 0:hist_len]
bfeats_pred = bfeats[:, hist_len:]
bcf_train = bcf[:, 0:hist_len]
bcf_pred = bcf[:, hist_len:]
return bts_train, bts_pred, bfeats_train, bfeats_pred, bcf_train, bcf_pred
def tf_dataset(self, mode='train', shift=1):
"""Tensorflow Dataset."""
if mode == 'train':
gen_fn = self.train_gen
else:
gen_fn = lambda: self.test_val_gen(mode, shift)
output_types = tuple(
[tf.float32] * 2 + [tf.int32] + [tf.float32] * 2 + [tf.int32] * 2
)
dataset = tf.data.Dataset.from_generator(gen_fn, output_types)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Eval pipeline."""
import json
import os
import sys
import time
from absl import flags
import chronos
import numpy as np
import pandas as pd
from paxml import checkpoints
sys.path.append(os.getcwd())
from src import timesfm  # TODO: this import errors out
import torch
import tqdm
from . import data_loader  # TODO: this import errors out
# import data_loader
from jax.lib import xla_bridge
FLAGS = flags.FLAGS
_BATCH_SIZE = flags.DEFINE_integer(
"batch_size", 64, "Batch size for the randomly sampled batch"
)
_DATASET = flags.DEFINE_string("dataset", "etth1", "The name of the dataset.")
_MODEL_PATH = flags.DEFINE_string(  # TODO: model location
"model_path", "model/checkpoints", "Path to model"
)
_DATETIME_COL = flags.DEFINE_string(
"datetime_col", "date", "Column having datetime."
)
_NUM_COV_COLS = flags.DEFINE_list(
"num_cov_cols", None, "Column having numerical features."
)
_CAT_COV_COLS = flags.DEFINE_list(
"cat_cov_cols", None, "Column having categorical features."
)
_TS_COLS = flags.DEFINE_list("ts_cols", None, "Columns of time-series features")
_NORMALIZE = flags.DEFINE_bool(
"normalize", True, "normalize data for eval or not"
)
_CONTEXT_LEN = flags.DEFINE_integer(
"context_len", 512, "Length of the context window"
)
_PRED_LEN = flags.DEFINE_integer("pred_len", 96, "prediction length.")
_BACKEND = flags.DEFINE_string("backend", "gpu", "backend to use")
_RESULTS_DIR = flags.DEFINE_string(
"results_dir", "./results/long_horizon", "results directory"
)
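# For each dataset, `boundaries` lists the [train_end, val_end, test_end] row
# indices of the csv; eval() below uses them as the train/val/test ranges.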
DATA_DICT = {
"ettm2": {
"boundaries": [34560, 46080, 57600],
"data_path": "./datasets/ETT-small/ETTm2.csv",
"freq": "15min",
},
"ettm1": {
"boundaries": [34560, 46080, 57600],
"data_path": "./datasets/ETT-small/ETTm1.csv",
"freq": "15min",
},
"etth2": {
"boundaries": [8640, 11520, 14400],
"data_path": "./datasets/ETT-small/ETTh2.csv",
"freq": "H",
},
"etth1": {
"boundaries": [8640, 11520, 14400],
"data_path": "./datasets/ETT-small/ETTh1.csv",
"freq": "H",
},
"elec": {
"boundaries": [18413, 21044, 26304],
"data_path": "./datasets/electricity/electricity.csv",
"freq": "H",
},
"traffic": {
"boundaries": [12280, 14036, 17544],
"data_path": "./datasets/traffic/traffic.csv",
"freq": "H",
},
"weather": {
"boundaries": [36887, 42157, 52696],
"data_path": "./datasets/weather/weather.csv",
"freq": "10min",
},
}
QUANTILES = list(np.arange(1, 10) / 10.0)
EPS = 1e-7
def get_forecasts(model_path, model, past, freq, pred_len):
"""Get forecasts."""
if model_path.startswith("amazon"):
out = model.predict(
torch.tensor(past),
prediction_length=pred_len,
limit_prediction_length=False,
)
out = out.numpy()
out = np.median(out, axis=1)
else:
lfreq = [freq] * past.shape[0]
_, out = model.forecast(list(past), lfreq)
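    # The second output of model.forecast stacks the mean at index 0 followed by
    # the quantile forecasts, so index 5 selects the 0.5 quantile (median).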
out = out[:, :, 5]
return out
def _mse(y_pred, y_true):
"""mse loss."""
return np.square(y_pred - y_true)
def _mae(y_pred, y_true):
"""mae loss."""
return np.abs(y_pred - y_true)
def _smape(y_pred, y_true):
"""_smape loss."""
abs_diff = np.abs(y_pred - y_true)
abs_val = (np.abs(y_true) + np.abs(y_pred)) / 2
abs_val = np.where(abs_val > EPS, abs_val, 1.0)
abs_diff = np.where(abs_val > EPS, abs_diff, 0.0)
return abs_diff / abs_val
def eval():
"""Eval pipeline."""
dataset = _DATASET.value
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]
data_df = pd.read_csv(open(data_path, "r"))
if _TS_COLS.value is not None:
ts_cols = DATA_DICT[dataset]["ts_cols"]
num_cov_cols = DATA_DICT[dataset]["num_cov_cols"]
cat_cov_cols = DATA_DICT[dataset]["cat_cov_cols"]
else:
ts_cols = [col for col in data_df.columns if col != _DATETIME_COL.value]
num_cov_cols = None
cat_cov_cols = None
batch_size = min(_BATCH_SIZE.value, len(ts_cols))
dtl = data_loader.TimeSeriesdata(
data_path=data_path,
datetime_col=_DATETIME_COL.value,
num_cov_cols=num_cov_cols,
cat_cov_cols=cat_cov_cols,
ts_cols=np.array(ts_cols),
train_range=[0, boundaries[0]],
val_range=[boundaries[0], boundaries[1]],
test_range=[boundaries[1], boundaries[2]],
hist_len=_CONTEXT_LEN.value,
pred_len=_PRED_LEN.value,
batch_size=batch_size,
freq=freq,
normalize=_NORMALIZE.value,
epoch_len=None,
holiday=False,
permute=False,
)
eval_itr = dtl.tf_dataset(
mode="test", shift=_PRED_LEN.value
).as_numpy_iterator()
model_path = _MODEL_PATH.value
if model_path.startswith("amazon"):
model = chronos.ChronosPipeline.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
)
else:
model = timesfm.TimesFm(
context_len=_CONTEXT_LEN.value,
horizon_len=_PRED_LEN.value,
input_patch_len=32,
output_patch_len=128,
num_layers=20,
model_dims=1280,
backend=_BACKEND.value,
per_core_batch_size=batch_size,
quantiles=QUANTILES,
)
model.load_from_checkpoint(
model_path,
checkpoint_type=checkpoints.CheckpointType.FLAX,
)
smape_run_losses = []
mse_run_losses = []
mae_run_losses = []
num_elements = 0
abs_sum = 0
start_time = time.time()
for batch in tqdm.tqdm(eval_itr):
past = batch[0]
actuals = batch[3]
forecasts = get_forecasts(
model_path, model, past, int_freq, _PRED_LEN.value
)
forecasts = forecasts[:, 0 : actuals.shape[1]]
mae_run_losses.append(_mae(forecasts, actuals).sum())
mse_run_losses.append(_mse(forecasts, actuals).sum())
smape_run_losses.append(_smape(forecasts, actuals).sum())
num_elements += actuals.shape[0] * actuals.shape[1]
abs_sum += np.abs(actuals).sum()
mse_val = np.sum(mse_run_losses) / num_elements
result_dict = {
"mse": mse_val,
"smape": np.sum(smape_run_losses) / num_elements,
"mae": np.sum(mae_run_losses) / num_elements,
"wape": np.sum(mae_run_losses) / abs_sum,
"nrmse": np.sqrt(mse_val) / (abs_sum / num_elements),
"num_elements": num_elements,
"abs_sum": abs_sum,
"total_time": time.time() - start_time,
"model_path": model_path,
"dataset": dataset,
"freq": freq,
"pred_len": _PRED_LEN.value,
"context_len": _CONTEXT_LEN.value,
}
run_id = np.random.randint(10000)
save_path = os.path.join(_RESULTS_DIR.value, str(run_id))
print(f"Saving results to {save_path}", flush=True)
os.makedirs(save_path, exist_ok=True)
with open(os.path.join(save_path, "results.json"), "w") as f:
json.dump(result_dict, f)
print(result_dict, flush=True)
if __name__ == "__main__":
  # debug1: test that torch-gpu / jax-gpu / TensorFlow-gpu are available
jax_test=xla_bridge.get_backend().platform
print(jax_test)
if not (jax_test=='gpu'):
exit()
FLAGS = flags.FLAGS
FLAGS(sys.argv)
eval()
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Directory to extract time covariates.
Extract time covariates from datetime.
"""
import numpy as np
import pandas as pd
from pandas.tseries.holiday import EasterMonday
from pandas.tseries.holiday import GoodFriday
from pandas.tseries.holiday import Holiday
from pandas.tseries.holiday import SU
from pandas.tseries.holiday import TH
from pandas.tseries.holiday import USColumbusDay
from pandas.tseries.holiday import USLaborDay
from pandas.tseries.holiday import USMartinLutherKingJr
from pandas.tseries.holiday import USMemorialDay
from pandas.tseries.holiday import USPresidentsDay
from pandas.tseries.holiday import USThanksgivingDay
from pandas.tseries.offsets import DateOffset
from pandas.tseries.offsets import Day
from pandas.tseries.offsets import Easter
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
# This is 183 to cover half a year (in both directions), also for leap years,
# + 17 as Easter can fall between March 22 and April 25.
MAX_WINDOW = 183 + 17
def _distance_to_holiday(holiday):
"""Return distance to given holiday."""
def _distance_to_day(index):
holiday_date = holiday.dates(
index - pd.Timedelta(days=MAX_WINDOW),
index + pd.Timedelta(days=MAX_WINDOW),
)
assert (
len(holiday_date) != 0 # pylint: disable=g-explicit-length-test
), f"No closest holiday for the date index {index} found."
# It sometimes returns two dates if it is exactly half a year after the
# holiday. In this case, the smaller distance (182 days) is returned.
return (index - holiday_date[0]).days
return _distance_to_day
EasterSunday = Holiday(
"Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)]
)
NewYearsDay = Holiday("New Years Day", month=1, day=1)
SuperBowl = Holiday(
"Superbowl", month=2, day=1, offset=DateOffset(weekday=SU(1))
)
MothersDay = Holiday(
"Mothers Day", month=5, day=1, offset=DateOffset(weekday=SU(2))
)
IndependenceDay = Holiday("Independence Day", month=7, day=4)
ChristmasEve = Holiday("Christmas", month=12, day=24)
ChristmasDay = Holiday("Christmas", month=12, day=25)
NewYearsEve = Holiday("New Years Eve", month=12, day=31)
BlackFriday = Holiday(
"Black Friday",
month=11,
day=1,
offset=[pd.DateOffset(weekday=TH(4)), Day(1)],
)
CyberMonday = Holiday(
"Cyber Monday",
month=11,
day=1,
offset=[pd.DateOffset(weekday=TH(4)), Day(4)],
)
HOLIDAYS = [
EasterMonday,
GoodFriday,
USColumbusDay,
USLaborDay,
USMartinLutherKingJr,
USMemorialDay,
USPresidentsDay,
USThanksgivingDay,
EasterSunday,
NewYearsDay,
SuperBowl,
MothersDay,
IndependenceDay,
ChristmasEve,
ChristmasDay,
NewYearsEve,
BlackFriday,
CyberMonday,
]
class TimeCovariates(object):
"""Extract all time covariates except for holidays."""
def __init__(
self,
datetimes,
normalized=True,
holiday=False,
):
"""Init function.
Args:
datetimes: pandas DatetimeIndex (lowest granularity supported is min)
normalized: whether to normalize features or not
holiday: fetch holiday features or not
Returns:
None
"""
self.normalized = normalized
self.dti = datetimes
self.holiday = holiday
def _minute_of_hour(self):
minutes = np.array(self.dti.minute, dtype=np.float32)
if self.normalized:
minutes = minutes / 59.0 - 0.5
return minutes
def _hour_of_day(self):
hours = np.array(self.dti.hour, dtype=np.float32)
if self.normalized:
hours = hours / 23.0 - 0.5
return hours
def _day_of_week(self):
day_week = np.array(self.dti.dayofweek, dtype=np.float32)
if self.normalized:
day_week = day_week / 6.0 - 0.5
return day_week
def _day_of_month(self):
day_month = np.array(self.dti.day, dtype=np.float32)
if self.normalized:
day_month = day_month / 30.0 - 0.5
return day_month
def _day_of_year(self):
day_year = np.array(self.dti.dayofyear, dtype=np.float32)
if self.normalized:
day_year = day_year / 364.0 - 0.5
return day_year
def _month_of_year(self):
month_year = np.array(self.dti.month, dtype=np.float32)
if self.normalized:
month_year = month_year / 11.0 - 0.5
return month_year
def _week_of_year(self):
week_year = np.array(self.dti.strftime("%U").astype(int), dtype=np.float32)
if self.normalized:
week_year = week_year / 51.0 - 0.5
return week_year
def _get_holidays(self):
dti_series = self.dti.to_series()
hol_variates = np.vstack([
dti_series.apply(_distance_to_holiday(h)).values for h in tqdm(HOLIDAYS)
])
# hol_variates is (num_holiday, num_time_steps), the normalization should be
# performed in the num_time_steps dimension.
return StandardScaler().fit_transform(hol_variates.T).T
def get_covariates(self):
"""Get all time covariates."""
moh = self._minute_of_hour().reshape(1, -1)
hod = self._hour_of_day().reshape(1, -1)
dom = self._day_of_month().reshape(1, -1)
dow = self._day_of_week().reshape(1, -1)
doy = self._day_of_year().reshape(1, -1)
moy = self._month_of_year().reshape(1, -1)
woy = self._week_of_year().reshape(1, -1)
all_covs = [
moh,
hod,
dom,
dow,
doy,
moy,
woy,
]
columns = ["moh", "hod", "dom", "dow", "doy", "moy", "woy"]
if self.holiday:
hol_covs = self._get_holidays()
all_covs.append(hol_covs)
columns += [f"hol_{i}" for i in range(len(HOLIDAYS))]
return pd.DataFrame(
data=np.vstack(all_covs).transpose(),
columns=columns,
index=self.dti,
)
# This project can be installed with `python3 -m pip install -e .` from the main directory.
[project]
name = "timesfm"
description = "Open weights time-series foundation model from Google Research."
version = "0.0.1"
dependencies = [
"einshape>=1.0.0",
"paxml>=1.4.0",
"praxis>=1.4.0",
"jax>=0.4.26",
"numpy>=1.26.4",
"pandas>=2.1.4",
]
authors = [
{name = "Rajat Sen", email = "senrajat@google.com"},
{name = "Yichen Zhou", email = "yichenzhou@google.com"},
{name = "Abhimanyu Das", email = "abhidas@google.com"},
{name = "Petros Mol", email = "pmol@google.com"},
]
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
# python=3.10
# tensorflow==2.13.1+git429d21b.abi1.dtk2404:https://cancon.hpccube.com:65024/directlink/4/tensorflow/DAS1.0/tensorflow-2.13.1+das1.0+git429d21b.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
# torch==2.1.0+git00661e0.abi0.dtk2404:https://cancon.hpccube.com:65024/directlink/4/pytorch/DAS1.0/torch-2.1.0+das1.0+git00661e0.abi0.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
# jaxlib==0.4.23+git97306ab.abi1.dtk2404:https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
huggingface_hub[cli]==0.23.3
utilsforecast==0.1.11
praxis==1.2.0
paxml==1.2.0
einshape==1.0
# versions adapted to the AI core packages
orbax-checkpoint==0.4.1
flax==0.8.5
protobuf==3.20.3
pandas==2.0.0
scipy==1.11.1
lingvo==0.13.1
seqio-nightly==0.0.14.dev20230301
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pax ML model for patched time-series decoder.
The file implements Residual MLPs, Patched Decoder layers and PAX ML models.
"""
import dataclasses
from typing import Optional, Tuple
import einshape as es
from jax import lax
import jax.numpy as jnp
from praxis import base_layer
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations
from praxis.layers import embedding_softmax
from praxis.layers import linears
from praxis.layers import normalizations
from praxis.layers import stochastics
from praxis.layers import transformers
# PAX shortcuts
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
template_field = base_layer.template_field
PAD_VAL = 1123581321.0
DEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# NestedMap keys
_INPUT_TS = "input_ts"
_INPUT_PADDING = "input_padding"
_OUTPUT_TS = "output_ts"
_FREQ = "freq"
_OUTPUT_TOKENS = "output_tokens"
_STATS = "stats"
# Small numerical value.
_TOLERANCE = 1e-7
def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
"""Shifts rows of seq based on the first 0 in each row of the mask."""
num = seq.shape[1]
# Find the index of the first 0 in each row of the mask
first_zero_idx = jnp.argmin(mask, axis=1)
# Create a range array for indexing
idx_range = jnp.arange(num)
def shift_row(carry, x):
seq_row, shift = x
shifted_idx = (idx_range - shift) % num
shifted_row = seq_row[shifted_idx]
return carry, shifted_row
# Use lax.scan to shift each row of seq based on the corresponding
# first_zero_idx.
_, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))
return shifted_seq
class ResidualBlock(base_layer.BaseLayer):
"""Simple feedforward block with residual connection.
Attributes:
input_dims: input dimension.
hidden_dims: hidden dimension.
output_dims: output dimension.
dropout_prob: dropout probability.
layer_norm: whether to use layer norm or not.
dropout_tpl: config for dropout.
ln_tpl: config for layer norm.
act_tpl: config for activation in hidden layer.
"""
input_dims: int = 0
hidden_dims: int = 0
output_dims: int = 0
dropout_prob: float = 0.0
layer_norm: bool = False
dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
act_tpl: LayerTpl = template_field(activations.Swish)
def setup(self):
lnorm_tpl = self.ln_tpl.clone()
lnorm_tpl.dim = self.output_dims
self.create_child("ln_layer", lnorm_tpl)
dropout_tpl = self.dropout_tpl.clone()
dropout_tpl.keep_prob = 1.0 - self.dropout_prob
self.create_child("dropout", dropout_tpl)
self.create_child(
"hidden_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.input_dims,
output_dims=self.hidden_dims,
activation_tpl=self.act_tpl.clone(),
),
)
self.create_child(
"output_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.hidden_dims,
output_dims=self.output_dims,
activation_tpl=pax_fiddle.Config(activations.Identity),
),
)
self.create_child(
"residual_layer",
pax_fiddle.Config(
linears.FeedForward,
input_dims=self.input_dims,
output_dims=self.output_dims,
activation_tpl=pax_fiddle.Config(activations.Identity),
),
)
def __call__(self, inputs: JTensor) -> JTensor:
hidden = self.hidden_layer(inputs)
output = self.output_layer(hidden)
output = self.dropout(output)
residual = self.residual_layer(inputs)
if self.layer_norm:
return self.ln_layer(output + residual)
else:
return output + residual
def _masked_mean_std(
inputs: JTensor, padding: JTensor
) -> Tuple[JTensor, JTensor]:
"""Calculates mean and standard deviation of arr across axis 1.
It should exclude values where pad is 1.
Args:
inputs: A JAX array of shape [b, n, p].
padding: A JAX array of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation of arr. We return the
statistics of the first patch with more than three non-padded values.
"""
  # Select the first patch with more than three unpadded values.
pad_sum = jnp.sum(1 - padding, axis=2)
def _get_patch_index(arr: JTensor):
indices = jnp.argmax(arr >= 3, axis=1)
row_sum = (arr >= 3).sum(axis=1)
return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = jnp.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where P is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = jnp.sum(mask, axis=1)
num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)
# Calculate the masked sum and squared sum of M
masked_sum = jnp.sum(arr * mask, axis=1)
masked_squared_sum = jnp.sum((arr * mask) ** 2, axis=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)
masked_std = jnp.sqrt(masked_var)
return masked_mean, masked_std
def _create_quantiles() -> list[float]:
"""Returns the quantiles for forecasting."""
return DEFAULT_QUANTILES
class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
"""Patch decoder layer for time-series foundation model.
Attributes:
patch_len: length of input patches.
horizon_len: length of output patches. Referred to as `output_patch_len`
during inference.
model_dims: model dimension of stacked transformer layer.
hidden_dims: hidden dimensions in fully connected layers.
quantiles: list of quantiles for non prob model.
residual_block_tpl: config for residual block.
stacked_transformer_params_tpl: config for stacked transformer.
use_freq: whether to use frequency encoding.
In all of what followed, except specified otherwise, B is batch size, T is
sequence length of time-series. N is the number of input patches that can be
obtained from T. P is the input patch length and H is the horizon length. Q is
number of output logits. D is model dimension.
"""
patch_len: int = 0
horizon_len: int = 0
model_dims: int = 0
hidden_dims: int = 0
quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)
residual_block_tpl: LayerTpl = template_field(ResidualBlock)
stacked_transformer_params_tpl: LayerTpl = template_field(
transformers.StackedTransformer
)
use_freq: bool = True
def setup(self) -> None:
"""Construct the model."""
num_outputs = len(self.quantiles) + 1
stl = self.stacked_transformer_params_tpl.clone()
stl.model_dims = self.model_dims
stl.hidden_dims = self.hidden_dims
stl.mask_self_attention = True
self.create_child("stacked_transformer_layer", stl)
input_resl = self.residual_block_tpl.clone()
ff_in_dims = 2 * self.patch_len
input_resl.input_dims = ff_in_dims
input_resl.hidden_dims = self.hidden_dims
input_resl.output_dims = self.model_dims
self.create_child(
"input_ff_layer",
input_resl,
)
horizon_resl = self.residual_block_tpl.clone()
horizon_resl.input_dims = self.model_dims
horizon_resl.hidden_dims = self.hidden_dims
horizon_resl.output_dims = self.horizon_len * num_outputs
self.create_child(
"horizon_ff_layer",
horizon_resl,
)
self.create_child(
"position_emb",
pax_fiddle.Config(
layers.PositionalEmbedding, embedding_dims=self.model_dims
),
)
if self.use_freq:
self.create_child(
"freq_emb",
pax_fiddle.Config(
embedding_softmax.Embedding,
num_classes=3,
input_dims=self.model_dims,
),
)
def transform_decode_state(
self, transform_fn: base_layer.DecodeStateTransformFn
) -> None:
"""Transforms all decode state variables based on transform_fn."""
self.stacked_transformer_layer.transform_decode_state(transform_fn)
def _forward_transform(
self, inputs: JTensor, patched_pads: JTensor
) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:
"""Input is of shape [B, N, P]."""
mu, sigma = _masked_mean_std(inputs, patched_pads)
sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)
# Normalize each patch.
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
outputs = jnp.where(
jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs
)
return outputs, (mu, sigma)
def _reverse_transform(
self, outputs: JTensor, stats: Tuple[JTensor, JTensor]
) -> JTensor:
"""Output is of shape [B, N, P, Q]."""
mu, sigma = stats
return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
def _preprocess_input(
self,
input_ts: JTensor,
input_padding: JTensor,
pos_emb: Optional[JTensor] = None,
) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:
"""Preprocess input for stacked transformer."""
# Reshape into patches.
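    # E.g. (illustrative shapes) input_ts of shape [B, 512] with patch_len=32
    # becomes patched_inputs of shape [B, 16, 32].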
patched_inputs = es.jax_einshape("b(np)->bnp", input_ts, p=self.patch_len)
input_padding = jnp.where(
jnp.abs(input_ts - PAD_VAL) < _TOLERANCE, 1, input_padding
)
patched_pads = es.jax_einshape(
"b(np)->bnp", input_padding, p=self.patch_len
)
patched_inputs, stats = self._forward_transform(
patched_inputs, patched_pads
)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)
model_input = self.input_ff_layer(concat_inputs)
    # A patch is treated as not padded if it contains at least one unpadded (zero) entry.
patched_padding = jnp.min(patched_pads, axis=-1)
if pos_emb is None:
position_emb = self.position_emb(seq_length=model_input.shape[1])
else:
position_emb = pos_emb
if self.do_eval:
if position_emb.shape[0] != model_input.shape[0]:
position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
position_emb = _shift_padded_seq(patched_padding, position_emb)
model_input += position_emb
return model_input, patched_padding, stats, patched_inputs
def _postprocess_output(
self,
model_output: JTensor,
num_outputs: int,
stats: Tuple[JTensor, JTensor],
) -> JTensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
output_ts = es.jax_einshape(
"bn(hq)->bnhq", output_ts, q=num_outputs, h=self.horizon_len
)
return self._reverse_transform(output_ts, stats)
def __call__(self, inputs: NestedMap) -> NestedMap:
"""PatchTST call.
Args:
      inputs: A NestedMap containing (1) input_ts: input sequence of shape [B,
        T] where T must be a multiple of patch_len; (2) input_padding: the
        corresponding padding map.
    Returns:
      A nested map with three keys:
(1) 'output_tokens' of shape [B, N, D].
(2) 'output_ts' of shape [B, N, H, Q]
(3) 'stats' a Tuple of statistics for renormalization.
"""
input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]
num_outputs = len(self.quantiles) + 1
model_input, patched_padding, stats, _ = self._preprocess_input(
input_ts=input_ts,
input_padding=input_padding,
)
if self.use_freq:
freq = inputs[_FREQ].astype(jnp.int32)
f_emb = self.freq_emb(freq) # B x 1 x D
f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)
model_input += f_emb
model_output = self.stacked_transformer_layer(model_input, patched_padding)
output_ts = self._postprocess_output(model_output, num_outputs, stats)
return NestedMap(
{_OUTPUT_TOKENS: model_output, _OUTPUT_TS: output_ts, _STATS: stats}
)
def decode(
self,
inputs: NestedMap,
horizon_len: int,
output_patch_len: Optional[int] = None,
max_len: int = 512,
) -> tuple[JTensor, JTensor]:
"""Auto-regressive decoding without caching.
Args:
      inputs: input time-series and paddings. Time-series shape B x C, padding
        shape B x (C + H) where H is the prediction length.
horizon_len: prediction length.
output_patch_len: output length to be fetched from one step of
auto-regressive decoding.
max_len: maximum training context length.
Returns:
Tuple of two forecasting results:
- Point (mean) output predictions as a tensor with shape B x H.
- Full predictions (mean and quantiles) as a tensor with shape
B x H x (1 + # quantiles).
"""
final_out = inputs[_INPUT_TS]
inp_time_len = final_out.shape[1]
paddings = inputs[_INPUT_PADDING]
if self.use_freq:
freq = inputs[_FREQ].astype(jnp.int32)
else:
freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)
full_outputs = []
if paddings.shape[1] != final_out.shape[1] + horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}"
)
if output_patch_len is None:
output_patch_len = self.horizon_len
num_decode_patches = (
horizon_len + output_patch_len - 1
) // output_patch_len
for _ in range(num_decode_patches):
current_padding = paddings[:, 0 : final_out.shape[1]]
input_ts = final_out[:, -max_len:]
input_padding = current_padding[:, -max_len:]
model_input = NestedMap(
input_ts=input_ts,
input_padding=input_padding,
freq=freq,
)
fprop_outputs = self(model_input)[_OUTPUT_TS]
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(fprop_outputs[:, -1, :output_patch_len, :])
final_out = jnp.concatenate([final_out, new_ts], axis=-1)
return (
final_out[:, inp_time_len : inp_time_len + horizon_len],
jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :],
)
export HIP_VISIBLE_DEVICES=5
export USE_MIOPEN_BATCHNORM=1
export TF_CPP_MIN_LOG_LEVEL=2
export XLA_PYTHON_CLIENT_PREALLOCATE=false
python3 -m experiments.extended_benchmarks.run_timesfm \
--model_path="model/checkpoints" \
--backend="gpu"
# python -c "print('finish run_timesfm!!!!')"
# for dataset in etth1 ettm1
# do
# for pred_len in 96 192 336
# do
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=$pred_len \
# --context_len=512 \
# --dataset=$dataset
# done
# done
# python -c "print('finish run_eval!!!!')"
export CUDA_VISIBLE_DEVICES=7
export USE_MIOPEN_BATCHNORM=1
export TF_CPP_MIN_LOG_LEVEL=2
export XLA_PYTHON_CLIENT_PREALLOCATE=false
python3 -m experiments.extended_benchmarks.run_timesfm \
--model_path="model/checkpoints" \
--backend="gpu"
python -c "print('finish run_timesfm!!!!')"
for dataset in etth1 ettm1
do
for pred_len in 96 192 336
do
python3 -m experiments.long_horizon_benchmarks.run_eval \
--model_path="model/checkpoints" \
--backend="gpu" \
--pred_len=$pred_len \
--context_len=512 \
--dataset=$dataset
done
done
python -c "print('finish run_eval!!!!')"
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=96 \
# --context_len=512 \
# --dataset=etth1
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=96 \
# --context_len=512 \
# --dataset=ettm1
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=192 \
# --context_len=512 \
# --dataset=etth1
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=192 \
# --context_len=512 \
# --dataset=ettm1
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=336 \
# --context_len=512 \
# --dataset=etth1
# python3 -m experiments.long_horizon_benchmarks.run_eval \
# --model_path="model/checkpoints" \
# --backend="gpu" \
# --pred_len=336 \
# --context_len=512 \
# --dataset=ettm1