Unverified Commit 2ac5b932 authored by Gift Sinthong, committed by GitHub

[time series] Add PatchTST (#25927)



* Initial commit of PatchTST model classes
Co-authored-by: Phanwadee Sinthong <phsinthong@gmail.com>
Co-authored-by: Nam Nguyen <namctin@gmail.com>
Co-authored-by: Vijay Ekambaram <vijaykr.e@gmail.com>
Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com>
Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com>

* Add PatchTSTForPretraining

* update to include classification
Co-authored-by: Phanwadee Sinthong <phsinthong@gmail.com>
Co-authored-by: Nam Nguyen <namctin@gmail.com>
Co-authored-by: Vijay Ekambaram <vijaykr.e@gmail.com>
Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com>
Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com>

* clean up auto files

* Add PatchTSTForPrediction

* Fix relative import

* Replace original PatchTSTEncoder with ChannelAttentionPatchTSTEncoder

* temporarily add absolute path + add PatchTSTForForecasting class

* Update base PatchTSTModel + Unittest

* Update ForecastHead to use the config class

* edit cv_random_masking, add mask to model output

* Update configuration_patchtst.py

* add masked_loss to the pretraining

* add PatchEmbeddings

* Update configuration_patchtst.py

* edit loss which considers mask in the pretraining

* remove patch_last option

* Add commits from internal repo

* Update ForecastHead

* Add model weight initialization + unittest

* Update PatchTST unittest to use local import

* PatchTST integration tests for pretraining and prediction

* Added PatchTSTForRegression + update unittest to include label generation

* Revert unrelated model test file

* Combine similar output classes

* update PredictionHead

* Update configuration_patchtst.py

* Add Revin

* small edit to PatchTSTModelOutputWithNoAttention

* Update modeling_patchtst.py

* Updating integration test for forecasting

* Fix unittest after class structure changed

* docstring updates

* change input_size to num_input_channels

* more formatting

* Remove some unused params

* Add a comment for pretrained models

* add channel_attention option

add channel_attention option and remove unused positional encoders.

* Update PatchTST models to use HF's MultiHeadAttention module

* Update paper + github urls

* Fix hidden_state return value

* Update integration test to use PatchTSTForForecasting

* Adding dataclass decorator for model output classes

* Run fixup script

* Rename model repos for integration test

* edit argument explanation

* change individual option to shared_projection

* style

* Rename integration test + import cleanup

* Fix output_hidden_states return value

* removed unused mode

* added std, mean and nops scaler

* add initial distributional loss for prediction

* fix typo in docs

* add generate function

* formatting

* add num_parallel_samples

* Fix a typo

* copy weighted_average function, edit PredictionHead

* edit PredictionHead

* add distribution head to forecasting

* formatting

* Add generate function for forecasting

* Add generate function to prediction task

* formatting

* use argsort

* add past_observed_mask ordering

* fix arguments

* docs

* add back test_model_outputs_equivalence test

* formatting

* cleanup

* formatting

* use ACT2CLS

* formatting

* fix add_start_docstrings decorator

* add distribution head and generate function to regression task

add distribution head and generate function to regression task. Also added PatchTSTForForecastingOutput and PatchTSTForRegressionOutput.

* add distribution head and generate function to regression task

add distribution head and generate function to regression task. Also added PatchTSTForForecastingOutput and PatchTSTForRegressionOutput.

* fix typos

* add forecast_masking

* fixed tests

* use set_seed

* fix doc test

* formatting

* Update docs/source/en/model_doc/patchtst.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* better var names

* rename PatchTSTTranspose

* fix argument names and docstring

* remove compute_num_patches and unused class

* remove assert

* renamed to PatchTSTMasking

* use num_labels for classification

* use num_labels

* use default num_labels from super class

* move model_type after docstring

* renamed PatchTSTForMaskPretraining

* bs -> batch_size

* more review fixes

* use hidden_state

* rename encoder layer and block class

* remove commented seed_number

* edit docstring

* Add docstring

* formatting

* use past_observed_mask

* doc suggestion

* make fix-copies

* use Args:

* add docstring

* add docstring

* change some variable names and add PatchTST before some class names

* formatting

* fix argument types

* fix tests

* change x variable to patch_input

* format

* formatting

* fix-copies

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* move loss to forward

* Update src/transformers/models/patchtst/modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/patchtst/modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/patchtst/modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/patchtst/modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/patchtst/modeling_patchtst.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* formatting

* fix a bug when pre_norm is set to True

* output_hidden_states is set to False as default

* set pre_norm=True as default

* format docstring

* format

* output_hidden_states is None by default

* add missing docs

* better var names

* docstring: remove default to False in output_hidden_states

* change labels name to target_values in regression task

* format

* fix tests

* change to forecast_mask_ratios and random_mask_ratio

* change mask names

* change future_values to target_values param in the prediction class

* remove nn.Sequential and make PatchTSTBatchNorm class

* black

* fix argument name for prediction

* add output_attentions option

* add output_attentions to PatchTSTEncoder

* formatting

* Add attention output option to all classes

* Remove PatchTSTEncoderBlock

* create PatchTSTEmbedding class

* use config in PatchTSTPatchify

* Use config in PatchTSTMasking class

* add channel_attn_weights

* Add PatchTSTScaler class

* add output_attentions arg to test function

* format

* Update doc with image patchtst.md

* fix-copies

* rename Forecast <-> Prediction

* change name of a few parameters to match with PatchTSMixer.

* Remove *ForForecasting class to match with other time series models.

* make style

* Remove PatchTSTForForecasting in the test

* remove PatchTSTForForecastingOutput class

* change test_forecast_head to test_prediction_head

* style

* fix docs

* fix tests

* change num_labels to num_targets

* Remove PatchTSTTranspose

* remove arguments in PatchTSTMeanScaler

* remove arguments in PatchTSTStdScaler

* add config as an argument to all the scaler classes

* reformat

* Add norm_eps for batchnorm and layernorm

* reformat.

* reformat

* edit docstring

* update docstring

* change variable name pooling to pooling_type

* fix output_hidden_states as tuple

* fix bug when calling PatchTSTBatchNorm

* change stride to patch_stride

* create PatchTSTPositionalEncoding class and restructure the PatchTSTEncoder

* formatting

* initialize scalers with configs

* edit output_hidden_states

* style

* fix forecast_mask_patches doc string

---------
Co-authored-by: Gift Sinthong <gift.sinthong@ibm.com>
Co-authored-by: Nam Nguyen <namctin@gmail.com>
Co-authored-by: Vijay Ekambaram <vijaykr.e@gmail.com>
Co-authored-by: Ngoc Diep Do <55230119+diepi@users.noreply.github.com>
Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com>
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: nnguyen <nnguyen@us.ibm.com>
Co-authored-by: Ngoc Diep Do <diiepy@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8017a590
@@ -83,67 +83,66 @@ class TimeSeriesFeatureEmbedder(nn.Module):
 class TimeSeriesStdScaler(nn.Module):
     """
-    Standardize features by calculating the mean and scaling along some given dimension `dim`, and then normalizes it
-    by subtracting from the mean and dividing by the standard deviation.
-
-    Args:
-        dim (`int`):
-            Dimension along which to calculate the mean and standard deviation.
-        keepdim (`bool`, *optional*, defaults to `False`):
-            Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
-        minimum_scale (`float`, *optional*, defaults to 1e-5):
-            Default scale that is used for elements that are constantly zero along dimension `dim`.
+    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
+    subtracting from the mean and dividing by the standard deviation.
     """

-    def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-5):
+    def __init__(self, config: TimeSeriesTransformerConfig):
         super().__init__()
-        if not dim > 0:
-            raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0")
-        self.dim = dim
-        self.keepdim = keepdim
-        self.minimum_scale = minimum_scale
+        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10

-    @torch.no_grad()
-    def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        denominator = weights.sum(self.dim, keepdim=self.keepdim)
+    def forward(
+        self, data: torch.Tensor, observed_indicator: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Parameters:
+            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                input for Batch norm calculation
+            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                Calculating the scale on the observed indicator.
+        Returns:
+            tuple of `torch.Tensor` of shapes
+                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+                `(batch_size, 1, num_input_channels)`)
+        """
+        denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
         denominator = denominator.clamp_min(1.0)
-        loc = (data * weights).sum(self.dim, keepdim=self.keepdim) / denominator
+        loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator

-        variance = (((data - loc) * weights) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
+        variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
         scale = torch.sqrt(variance + self.minimum_scale)
         return (data - loc) / scale, loc, scale


 class TimeSeriesMeanScaler(nn.Module):
     """
-    Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data
-    accordingly.
-
-    Args:
-        dim (`int`):
-            Dimension along which to compute the scale.
-        keepdim (`bool`, *optional*, defaults to `False`):
-            Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
-        default_scale (`float`, *optional*, defaults to `None`):
-            Default scale that is used for elements that are constantly zero. If `None`, we use the scale of the batch.
-        minimum_scale (`float`, *optional*, defaults to 1e-10):
-            Default minimum possible scale that is used for any item.
+    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
+    accordingly.
     """

-    def __init__(
-        self, dim: int = -1, keepdim: bool = True, default_scale: Optional[float] = None, minimum_scale: float = 1e-10
-    ):
+    def __init__(self, config: TimeSeriesTransformerConfig):
         super().__init__()
-        self.dim = dim
-        self.keepdim = keepdim
-        self.minimum_scale = minimum_scale
-        self.default_scale = default_scale
+        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
+        self.default_scale = config.default_scale if hasattr(config, "default_scale") else None

-    @torch.no_grad()
     def forward(
         self, data: torch.Tensor, observed_indicator: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # shape: (N, [C], T=1)
+        """
+        Parameters:
+            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                input for Batch norm calculation
+            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                Calculating the scale on the observed indicator.
+        Returns:
+            tuple of `torch.Tensor` of shapes
+                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+                `(batch_size, 1, num_input_channels)`)
+        """
         ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
         num_observed = observed_indicator.sum(self.dim, keepdim=True)
@@ -173,23 +172,26 @@ class TimeSeriesMeanScaler(nn.Module):
 class TimeSeriesNOPScaler(nn.Module):
     """
-    Assigns a scaling factor equal to 1 along dimension `dim`, and therefore applies no scaling to the input data.
-
-    Args:
-        dim (`int`):
-            Dimension along which to compute the scale.
-        keepdim (`bool`, *optional*, defaults to `False`):
-            Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
+    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
     """

-    def __init__(self, dim: int, keepdim: bool = False):
+    def __init__(self, config: TimeSeriesTransformerConfig):
         super().__init__()
-        self.dim = dim
-        self.keepdim = keepdim
+        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True

     def forward(
-        self, data: torch.Tensor, observed_indicator: torch.Tensor
+        self, data: torch.Tensor, observed_indicator: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Parameters:
+            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                input for Batch norm calculation
+        Returns:
+            tuple of `torch.Tensor` of shapes
+                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+                `(batch_size, 1, num_input_channels)`)
+        """
         scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
         loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
         return data, loc, scale
@@ -1180,11 +1182,11 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
         super().__init__(config)

         if config.scaling == "mean" or config.scaling is True:
-            self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True)
+            self.scaler = TimeSeriesMeanScaler(config)
         elif config.scaling == "std":
-            self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True)
+            self.scaler = TimeSeriesStdScaler(config)
         else:
-            self.scaler = TimeSeriesNOPScaler(dim=1, keepdim=True)
+            self.scaler = TimeSeriesNOPScaler(config)

         if config.num_static_categorical_features > 0:
             self.embedder = TimeSeriesFeatureEmbedder(
...
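For context on the refactor above, where the scalers now read their settings from the model config instead of constructor arguments: below is a small illustrative sketch, not part of the diff, that mirrors the standard-scaling math of TimeSeriesStdScaler.forward using the fallback values dim=1, keepdim=True, minimum_scale=1e-10 on dummy tensors.

import torch

# dummy inputs: (batch_size, sequence_length, num_input_channels)
data = torch.randn(4, 20, 3)
observed_indicator = torch.ones_like(data)  # 1.0 wherever the series was observed

# values the refactored scaler falls back to when the config does not define them
dim, keepdim, minimum_scale = 1, True, 1e-10

denominator = observed_indicator.sum(dim, keepdim=keepdim).clamp_min(1.0)
loc = (data * observed_indicator).sum(dim, keepdim=keepdim) / denominator
variance = (((data - loc) * observed_indicator) ** 2).sum(dim, keepdim=keepdim) / denominator
scale = torch.sqrt(variance + minimum_scale)
scaled = (data - loc) / scale

print(scaled.shape, loc.shape, scale.shape)  # (4, 20, 3), (4, 1, 3), (4, 1, 3)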
@@ -627,6 +627,12 @@ MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = None
 MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = None

+MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = None
+
+MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = None
+
 MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
@@ -6019,6 +6025,51 @@ class OwlViTVisionModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class PatchTSTForClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PatchTSTForPrediction(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PatchTSTForPretraining(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PatchTSTForRegression(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PatchTSTModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class PatchTSTPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class PegasusForCausalLM(metaclass=DummyObject):
     _backends = ["torch"]
...
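Illustration only, not part of the diff: the dummy classes above exist so that importing the PatchTST names from transformers still succeeds without PyTorch and fails with an actionable message only when a class is instantiated. A minimal sketch of that behavior, using the public requires_backends helper with a hypothetical stand-in class rather than the generated dummies:

from transformers.utils import requires_backends

class _DummyPatchTSTModel:  # stand-in for the generated dummy, for illustration only
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

try:
    _DummyPatchTSTModel()
except ImportError as err:  # raised only when torch is not installed
    print(err)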
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch PatchTST model. """
import inspect
import random
import tempfile
import unittest
from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
TOLERANCE = 1e-4
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
PatchTSTConfig,
PatchTSTForClassification,
PatchTSTForPrediction,
PatchTSTForPretraining,
PatchTSTForRegression,
PatchTSTModel,
)
@require_torch
class PatchTSTModelTester:
def __init__(
self,
parent,
batch_size=13,
prediction_length=7,
context_length=14,
patch_length=5,
patch_stride=5,
num_input_channels=1,
num_time_features=1,
is_training=True,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
lags_sequence=[1, 2, 3, 4, 5],
distil=False,
seed_number=42,
num_targets=2,
num_output_channels=2,
):
self.parent = parent
self.batch_size = batch_size
self.prediction_length = prediction_length
self.context_length = context_length
self.patch_length = patch_length
self.patch_stride = patch_stride
self.num_input_channels = num_input_channels
self.num_time_features = num_time_features
self.lags_sequence = lags_sequence
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.seed_number = seed_number
self.num_targets = num_targets
self.num_output_channels = num_output_channels
self.distil = distil
self.num_patches = (max(self.context_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
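# e.g. with the defaults above: (max(14, 5) - 5) // 5 + 1 = 2 patches per channel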
def get_config(self):
return PatchTSTConfig(
prediction_length=self.prediction_length,
patch_length=self.patch_length,
patch_stride=self.patch_stride,
num_input_channels=self.num_input_channels,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
context_length=self.context_length,
activation_function=self.hidden_act,
seed_number=self.seed_number,
num_targets=self.num_targets,
num_output_channels=self.num_output_channels,
)
def prepare_patchtst_inputs_dict(self, config):
_past_length = config.context_length
# bs, num_input_channels, num_patch, patch_len
# [bs x seq_len x num_input_channels]
past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])
future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels])
inputs_dict = {
"past_values": past_values,
"future_values": future_values,
}
return inputs_dict
def prepare_config_and_inputs(self):
config = self.get_config()
inputs_dict = self.prepare_patchtst_inputs_dict(config)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
@require_torch
class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
PatchTSTModel,
PatchTSTForPrediction,
PatchTSTForPretraining,
PatchTSTForClassification,
PatchTSTForRegression,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (
(PatchTSTForPrediction, PatchTSTForRegression, PatchTSTForPretraining) if is_torch_available() else ()
)
pipeline_model_mapping = {"feature-extraction": PatchTSTModel} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torchscript = False
test_inputs_embeds = False
test_model_common_attributes = False
test_resize_embeddings = True
test_resize_position_embeddings = False
test_mismatched_shapes = True
test_model_parallel = False
has_attentions = False
def setUp(self):
self.model_tester = PatchTSTModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=PatchTSTConfig,
has_text_modality=False,
prediction_length=self.model_tester.prediction_length,
)
def test_config(self):
self.config_tester.run_common_tests()
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
# if PatchTSTForPretraining
if model_class == PatchTSTForPretraining:
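# masked pretraining reconstructs patches of past_values, so no future_values or labels are needed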
inputs_dict.pop("future_values")
# else if classification model:
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
inputs_dict["target_values"] = labels
inputs_dict.pop("future_values")
elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
target_values = floats_tensor(
[self.model_tester.batch_size, self.model_tester.num_output_channels], rng=rng
)
inputs_dict["target_values"] = target_values
inputs_dict.pop("future_values")
return inputs_dict
def test_save_load_strict(self):
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers
)
self.assertEqual(len(hidden_states), expected_num_layers)
num_patch = self.model_tester.num_patches
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patch, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
print("model_class: ", model_class)
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip(reason="we have no tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
def test_model_main_input_name(self):
model_signature = inspect.signature(getattr(PatchTSTModel, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(PatchTSTModel.main_input_name, observed_main_input_name)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = [
"past_values",
"past_observed_mask",
"future_values",
]
if model_class == PatchTSTForPretraining:
expected_arg_names.remove("future_values")
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
):
expected_arg_names.remove("future_values")
expected_arg_names.remove("past_observed_mask")
expected_arg_names.append("target_values") if model_class in get_values(
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING
) else expected_arg_names.append("target_values")
expected_arg_names.append("past_observed_mask")
expected_arg_names.extend(
[
"output_hidden_states",
"output_attentions",
"return_dict",
]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
# Note: Publishing of this dataset is under internal review. The dataset is not yet downloadable.
def prepare_batch(repo_id="ibm/etth1-forecast-test", file="train-batch.pt"):
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
batch = torch.load(file, map_location=torch_device)
return batch
# Note: Publishing of pretrained weights is under internal review. Pretrained model is not yet downloadable.
@require_torch
@slow
class PatchTSTModelIntegrationTests(unittest.TestCase):
# Publishing of pretrained weights is under internal review. Pretrained model is not yet downloadable.
def test_pretrain_head(self):
model = PatchTSTForPretraining.from_pretrained("ibm/patchtst-etth1-pretrain").to(torch_device)
batch = prepare_batch()
torch.manual_seed(0)
with torch.no_grad():
output = model(past_values=batch["past_values"].to(torch_device)).prediction_output
num_patch = (
max(model.config.context_length, model.config.patch_length) - model.config.patch_length
) // model.config.patch_stride + 1
expected_shape = torch.Size([64, model.config.num_input_channels, num_patch, model.config.patch_length])
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[-0.5409]], [[0.3093]], [[-0.3759]], [[0.5068]], [[-0.8387]], [[0.0937]], [[0.2809]]],
device=torch_device,
)
self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE))
# Publishing of pretrained weights is under internal review. Pretrained model is not yet downloadable.
def test_prediction_head(self):
model = PatchTSTForPrediction.from_pretrained("ibm/patchtst-etth1-forecast").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
torch.manual_seed(0)
with torch.no_grad():
output = model(
past_values=batch["past_values"].to(torch_device),
future_values=batch["future_values"].to(torch_device),
).prediction_outputs
expected_shape = torch.Size([64, model.config.prediction_length, model.config.num_input_channels])
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]],
device=torch_device,
)
self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE))
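For orientation, not part of the diff: a minimal sketch of the prediction path that test_prediction_head exercises, but with a randomly initialized model and random tensors instead of the not-yet-published weights and dataset. The config argument names mirror PatchTSTModelTester.get_config above; treat them as assumptions, since names and defaults may differ in the final API.

import torch
from transformers import PatchTSTConfig, PatchTSTForPrediction

config = PatchTSTConfig(
    num_input_channels=1,
    context_length=14,
    patch_length=5,
    patch_stride=5,
    prediction_length=7,
    d_model=16,
    encoder_layers=2,
    encoder_attention_heads=4,
    encoder_ffn_dim=4,
)
model = PatchTSTForPrediction(config).eval()

past_values = torch.randn(2, config.context_length, config.num_input_channels)
future_values = torch.randn(2, config.prediction_length, config.num_input_channels)

with torch.no_grad():
    outputs = model(past_values=past_values, future_values=future_values)

# (batch_size, prediction_length, num_input_channels), as asserted in test_prediction_head
print(outputs.prediction_outputs.shape)  # torch.Size([2, 7, 1])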
@@ -185,6 +185,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "TimeSeriesTransformerForPrediction",
     "InformerForPrediction",
     "AutoformerForPrediction",
+    "PatchTSTForPretraining",
+    "PatchTSTForPrediction",
     "JukeboxVQVAE",
     "JukeboxPrior",
     "SamModel",
...