Unverified Commit 5cd16f01 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

time series forecasting model (#17965)



* initial files

* initial model via cli

* typos

* make a start on the model config

* ready with configuation

* remove tokenizer ref.

* init the transformer

* added initial model forward to return dec_output

* require gluonts

* update dep. ver table and add as extra

* fixed typo

* add type for prediction_length

* use num_time_features

* use config

* more config

* typos

* opps another typo

* freq can be none

* default via transformation is 1

* initial transformations

* fix imports

* added transform_start_field

* add helper to create pytorch dataloader

* added inital val and test data loader

* added initial distr head and loss

* training working

* remove TimeSeriesTransformerTokenizer
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/__init__.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fixed copyright

* removed docs

* remove time series tokenizer

* fixed docs

* fix text

* fix second

* fix default

* fix order

* use config directly

* undo change

* fix comment

* fix year

* fix import

* add additional arguments for training vs. test

* initial greedy inference loop

* fix inference

* comment out token inputs to enc dec

* Use HF encoder/decoder

* fix inference

* Use Seq2SeqTSModelOutput output

* return Seq2SeqTSPredictionOutput

* added default arguments

* fix return_dict true

* scale is a tensor

* output static_features for inference

* clean up some unused bits

* fixed typo

* set return_dict if none

* call model once for both train/predict

* use cache if future_target is none

* initial generate func

* generate arguments

* future_time_feat is required

* return SampleTSPredictionOutput

* removed unneeded classes

* fix when params is none

* fix return dict

* fix num_attention_heads

* fix arguments

* remove unused shift_tokens_right

* add different dropout configs

* implement FeatureEmbedder, Scaler and weighted_average

* remove gluonts dependency

* fix class names

* avoid _variable names

* remove gluonts dependency

* fix imports

* remove gluonts from configuration

* fix docs

* fixed typo

* move utils to examples

* add example requirements

* config has no freq

* initial run_ts_no_trainer

* remove from ignore

* fix output_attentions and removed unsued getters/setters

* removed unsed tests

* add dec seq len

* add test_attention_outputs

* set has_text_modality=False

* add config attribute_map

* make style

* make fix-copies

* add encoder_outputs to TimeSeriesTransformerForPrediction forward

* Improve docs, add model to README

* added test_forward_signature

* More improvements

* Add more copied from

* Fix README

* Fix remaining quality issues

* updated encoder and decoder

* fix generate

* output_hidden_states and use_cache are optional

* past key_values returned too

* initialize weights of distribution_output module

* fixed more tests

* update test_forward_signature

* fix return_dict outputs

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* removed commented out tests

* added neg. bin and normal output

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* move to one line

* Add docstrings

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* add try except for assert and raise

* try and raise exception

* fix the documentation formatting

* fix assert call

* fix docstring formatting

* removed input_ids from DOCSTRING

* Update input docstring

* Improve variable names

* Update order of inputs

* Improve configuration

* Improve variable names

* Improve docs

* Remove key_length from tests

* Add extra docs

* initial unittests

* added test_inference_no_head test

* added test_inference_head

* add test_seq_to_seq_generation

* make style

* one line

* assert mean prediction

* removed comments

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fix order of args

* make past_observed_mask optional as well

* added Amazon license header

* updated utils with new fieldnames

* make style

* cleanup

* undo position of past_observed_mask

* fix import

* typo

* more typo

* rename example files

* remove example for now

* Update docs/source/en/_toctree.yml
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update modeling_time_series_transformer.py

fix style

* fixed typo

* fix typo and grammer

* fix style
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: default avatarNielsRogge <niels.rogge1@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent cfb777f2
......@@ -375,6 +375,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/main/model_doc/time_series_transformer)** (from HuggingFace).
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -325,6 +325,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/main/model_doc/time_series_transformer)** (from HuggingFace).
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -349,6 +349,7 @@ conda install -c huggingface transformers
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (来自 Microsoft Research) 伴随论文 [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) 由 Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou 发布。
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/main/model_doc/time_series_transformer)** (from HuggingFace).
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (来自 Google/CMU) 伴随论文 [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) 由 Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov 发布。
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。
......
......@@ -361,6 +361,7 @@ conda install -c huggingface transformers
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/main/model_doc/time_series_transformer)** (from HuggingFace).
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -498,6 +498,11 @@
- local: model_doc/trajectory_transformer
title: Trajectory Transformer
title: Reinforcement learning models
- isExpanded: false
sections:
- local: model_doc/time_series_transformer
title: Time Series Transformer
title: Time series models
title: Models
- sections:
- local: internal/modeling_utils
......
......@@ -165,6 +165,7 @@ The documentation is organized into five sections:
1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](model_doc/time_series_transformer)** (from HuggingFace).
1. **[Trajectory Transformer](model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......@@ -310,6 +311,7 @@ Flax), PyTorch, and/or TensorFlow.
| Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Time Series Transformer
<Tip>
This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight
breaking changes to fix it in the future. If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title).
</Tip>
## Overview
The Time Series Transformer model is a vanilla encoder-decoder Transformer for time series forecasting.
Tips:
- Similar to other models in the library, [`TimeSeriesTransformerModel`] is the raw Transformer without any head on top, and [`TimeSeriesTransformerForPrediction`]
adds a distribution head on top of the former, which can be used for time-series forecasting. Note that this is a so-called probabilistic forecasting model, not a
point forecasting model. This means that the model learns a distribution, from which one can sample. The model doesn't directly output values.
- [`TimeSeriesTransformerForPrediction`] consists of 2 blocks: an encoder, which takes a `context_length` of time series values as input (called `past_values`),
and a decoder, which predicts a `prediction_length` of time series values into the future (called `future_values`). During training, one needs to provide
pairs of (`past_values` and `future_values`) to the model.
- In addition to the raw (`past_values` and `future_values`), one typically provides additional features to the model. These can be the following:
- `past_time_features`: temporal features which the model will add to `past_values`. These serve as "positional encodings" for the Transformer encoder.
Examples are "day of the month", "month of the year", etc. as scalar values (and then stacked together as a vector).
e.g. if a given time-series value was obtained on the 11th of August, then one could have [11, 8] as time feature vector (11 being "day of the month", 8 being "month of the year").
- `future_time_features`: temporal features which the model will add to `future_values`. These serve as "positional encodings" for the Transformer decoder.
Examples are "day of the month", "month of the year", etc. as scalar values (and then stacked together as a vector).
e.g. if a given time-series value was obtained on the 11th of August, then one could have [11, 8] as time feature vector (11 being "day of the month", 8 being "month of the year").
- `static_categorical_features`: categorical features which are static over time (i.e., have the same value for all `past_values` and `future_values`).
An example here is the store ID or region ID that identifies a given time-series.
Note that these features need to be known for ALL data points (also those in the future).
- `static_real_features`: real-valued features which are static over time (i.e., have the same value for all `past_values` and `future_values`).
An example here is the image representation of the product for which you have the time-series values (like the [ResNet](resnet) embedding of a "shoe" picture,
if your time-series is about the sales of shoes).
Note that these features need to be known for ALL data points (also those in the future).
- The model is trained using "teacher-forcing", similar to how a Transformer is trained for machine translation. This means that, during training, one shifts the
`future_values` one position to the right as input to the decoder, prepended by the last value of `past_values`. At each time step, the model needs to predict the
next target. So the set-up of training is similar to a GPT model for language, except that there's no notion of `decoder_start_token_id` (we just use the last value
of the context as initial input for the decoder).
- At inference time, we give the final value of the `past_values` as input to the decoder. Next, we can sample from the model to make a prediction at the next time step,
which is then fed to the decoder in order to make the next prediction (also called autoregressive generation).
This model was contributed by [kashif](<https://huggingface.co/kashif).
## TimeSeriesTransformerConfig
[[autodoc]] TimeSeriesTransformerConfig
## TimeSeriesTransformerModel
[[autodoc]] TimeSeriesTransformerModel
- forward
## TimeSeriesTransformerForPrediction
[[autodoc]] TimeSeriesTransformerForPrediction
- forward
\ No newline at end of file
......@@ -281,6 +281,7 @@ extras["vision"] = deps_list("Pillow")
extras["timm"] = deps_list("timm")
extras["codecarbon"] = deps_list("codecarbon")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
deps_list(
......
......@@ -336,6 +336,10 @@ _import_structure = {
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
"models.tapex": ["TapexTokenizer"],
"models.time_series_transformer": [
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig",
],
"models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TrajectoryTransformerConfig",
......@@ -815,6 +819,15 @@ else:
_import_structure["modeling_utils"] = ["PreTrainedModel"]
# PyTorch models structure
_import_structure["models.time_series_transformer"].extend(
[
"TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimeSeriesTransformerForPrediction",
"TimeSeriesTransformerModel",
"TimeSeriesTransformerPreTrainedModel",
]
)
_import_structure["models.albert"].extend(
[
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -3286,6 +3299,10 @@ if TYPE_CHECKING:
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
from .models.tapex import TapexTokenizer
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
)
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrajectoryTransformerConfig,
......@@ -4577,6 +4594,12 @@ if TYPE_CHECKING:
T5PreTrainedModel,
load_tf_weights_in_t5,
)
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel,
)
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TrajectoryTransformerModel,
......
......@@ -139,6 +139,7 @@ from . import (
t5,
tapas,
tapex,
time_series_transformer,
trajectory_transformer,
transfo_xl,
trocr,
......
......@@ -134,6 +134,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("swinv2", "Swinv2Config"),
("t5", "T5Config"),
("tapas", "TapasConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
......@@ -262,6 +263,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("swinv2", "SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
......@@ -412,6 +414,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("t5v1.1", "T5v1.1"),
("tapas", "TAPAS"),
("tapex", "TAPEX"),
("time_series_transformer", "Time Series Transformer"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
......
......@@ -130,6 +130,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("swinv2", "Swinv2Model"),
("t5", "T5Model"),
("tapas", "TapasModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("unispeech", "UniSpeechModel"),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_time_series_transformer": [
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig",
],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_time_series_transformer"] = [
"TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimeSeriesTransformerForPrediction",
"TimeSeriesTransformerModel",
"TimeSeriesTransformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Time Series Transformer model configuration"""
from typing import List, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"huggingface/time-series-transformer-tourism-monthly": (
"https://huggingface.co/huggingface/time-series-transformer-tourism-monthly/resolve/main/config.json"
),
# See all TimeSeriesTransformer models at https://huggingface.co/models?filter=time_series_transformer
}
class TimeSeriesTransformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`TimeSeriesTransformerModel`]. It is used to
instantiate a Time Series Transformer model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Time Series
Transformer
[huggingface/time-series-transformer-tourism-monthly](https://huggingface.co/huggingface/time-series-transformer-tourism-monthly)
architecture.
Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
prediction_length (`int`):
The prediction length for the decoder. In other words, the prediction horizon of the model.
context_length (`int`, *optional*, defaults to `prediction_length`):
The context length for the encoder. If `None`, the context length will be the same as the
`prediction_length`.
distribution_output (`string`, *optional*, defaults to `"student_t"`):
The distribution emission head for the model. Could be either "student_t", "normal" or "negative_binomial".
loss (`string`, *optional*, defaults to `"nll"`):
The loss function for the model corresponding to the `distribution_output` head. For parametric
distributions it is the negative log likelihood (nll) - which currently is the only supported one.
input_size (`int`, *optional*, defaults to 1):
The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of
multivarate targets.
scaling (`bool`, *optional* defaults to `True`):
Whether to scale the input targets.
lags_sequence (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 5, 6, 7]`):
The lags of the input time series as covariates often dictated by the frequency. Default is `[1, 2, 3, 4,
5, 6, 7]`.
num_time_features (`int`, *optional*, defaults to 0):
The number of time features in the input time series.
num_dynamic_real_features (`int`, *optional*, defaults to 0):
The number of dynamic real valued features.
num_static_categorical_features (`int`, *optional*, defaults to 0):
The number of static categorical features.
num_static_real_features (`int`, *optional*, defaults to 0):
The number of static real valued features.
cardinality (`list[int]`, *optional*):
The cardinality (number of different values) for each of the static categorical features. Should be a list
of integers, having the same length as `num_static_categorical_features`. Cannot be `None` if
`num_static_categorical_features` is > 0.
embedding_dimension (`list[int]`, *optional*):
The dimension of the embedding for each of the static categorical features. Should be a list of integers,
having the same length as `num_static_categorical_features`. Cannot be `None` if
`num_static_categorical_features` is > 0.
encoder_layers (`int`, *optional*, defaults to 2):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 2):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 2):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 2):
Number of attention heads for each attention layer in the Transformer decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 32):
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
decoder_ffn_dim (`int`, *optional*, defaults to 32):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and decoder. If string, `"gelu"` and
`"relu"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the encoder, and decoder.
encoder_layerdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for the attention and fully connected layers for each encoder layer.
decoder_layerdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for the attention and fully connected layers for each decoder layer.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability used between the two layers of the feed-forward networks.
num_parallel_samples (`int`, *optional*, defaults to 100):
The number of samples to generate in parallel for each time step of inference.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated normal weight initialization distribution.
use_cache (`bool`, *optional*, defaults to `True`):
Whether to use the past key/values attentions (if applicable to the model) to speed up decoding.
Example:
```python
>>> from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerModel
>>> # Initializing a default Time Series Transformer configuration
>>> configuration = TimeSeriesTransformerConfig()
>>> # Randomly initializing a model from the configuration
>>> model = TimeSeriesTransformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "time_series_transformer"
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "encoder_attention_heads",
"num_hidden_layers": "encoder_layers",
}
def __init__(
self,
input_size: int = 1,
prediction_length: Optional[int] = None,
context_length: Optional[int] = None,
distribution_output: str = "student_t",
loss: str = "nll",
lags_sequence: List[int] = [1, 2, 3, 4, 5, 6, 7],
scaling: bool = True,
num_dynamic_real_features: int = 0,
num_static_categorical_features: int = 0,
num_static_real_features: int = 0,
num_time_features: int = 0,
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
encoder_ffn_dim: int = 32,
decoder_ffn_dim: int = 32,
encoder_attention_heads: int = 2,
decoder_attention_heads: int = 2,
encoder_layers: int = 2,
decoder_layers: int = 2,
is_encoder_decoder: bool = True,
activation_function: str = "gelu",
dropout: float = 0.1,
encoder_layerdrop: float = 0.1,
decoder_layerdrop: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
num_parallel_samples: int = 100,
init_std: float = 0.02,
use_cache=True,
**kwargs
):
# time series specific configuration
self.prediction_length = prediction_length
self.context_length = context_length or prediction_length
self.distribution_output = distribution_output
self.loss = loss
self.input_size = input_size
self.num_time_features = num_time_features
self.lags_sequence = lags_sequence
self.scaling = scaling
self.num_dynamic_real_features = num_dynamic_real_features
self.num_static_real_features = num_static_real_features
self.num_static_categorical_features = num_static_categorical_features
if cardinality and num_static_categorical_features > 0:
if len(cardinality) != num_static_categorical_features:
raise ValueError(
"The cardinality should be a list of the same length as `num_static_categorical_features`"
)
self.cardinality = cardinality
else:
self.cardinality = [1]
if embedding_dimension and num_static_categorical_features > 0:
if len(embedding_dimension) != num_static_categorical_features:
raise ValueError(
"The embedding dimension should be a list of the same length as `num_static_categorical_features`"
)
self.embedding_dimension = embedding_dimension
else:
self.embedding_dimension = [min(50, (cat + 1) // 2) for cat in self.cardinality]
self.num_parallel_samples = num_parallel_samples
# Transformer architecture configuration
self.d_model = input_size * len(lags_sequence) + self._number_of_features
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.activation_function = activation_function
self.init_std = init_std
self.output_attentions = False
self.output_hidden_states = False
self.use_cache = use_cache
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property
def _number_of_features(self) -> int:
return (
sum(self.embedding_dimension)
+ self.num_dynamic_real_features
+ self.num_time_features
+ max(1, self.num_static_real_features) # there is at least one dummy static real feature
+ 1 # the log(scale)
)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Time Series Transformer model."""
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.distributions import (
AffineTransform,
Distribution,
NegativeBinomial,
Normal,
StudentT,
TransformedDistribution,
)
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "TimeSeriesTransformerConfig"
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"huggingface/time-series-transformer-tourism-monthly",
# See all TimeSeriesTransformer models at https://huggingface.co/models?filter=time_series_transformer
]
class AffineTransformed(TransformedDistribution):
def __init__(self, base_distribution: Distribution, loc=None, scale=None):
self.scale = 1.0 if scale is None else scale
self.loc = 0.0 if loc is None else loc
super().__init__(base_distribution, [AffineTransform(self.loc, self.scale)])
@property
def mean(self):
"""
Returns the mean of the distribution.
"""
return self.base_dist.mean * self.scale + self.loc
@property
def variance(self):
"""
Returns the variance of the distribution.
"""
return self.base_dist.variance * self.scale**2
@property
def stddev(self):
"""
Returns the standard deviation of the distribution.
"""
return self.variance.sqrt()
class ParameterProjection(nn.Module):
def __init__(
self,
in_features: int,
args_dim: Dict[str, int],
domain_map: Callable[..., Tuple[torch.Tensor]],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.args_dim = args_dim
self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
self.domain_map = domain_map
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
params_unbounded = [proj(x) for proj in self.proj]
return self.domain_map(*params_unbounded)
class LambdaLayer(nn.Module):
def __init__(self, function):
super().__init__()
self.function = function
def forward(self, x, *args):
return self.function(x, *args)
class DistributionOutput:
distr_cls: type
in_features: int
args_dim: Dict[str, int]
def __init__(self) -> None:
pass
def _base_distribution(self, distr_args):
return self.distr_cls(*distr_args)
def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> Distribution:
distr = self._base_distribution(distr_args)
if loc is None and scale is None:
return distr
else:
return AffineTransformed(distr, loc=loc, scale=scale)
@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event contemplated by the distributions that this object constructs.
"""
raise NotImplementedError()
@property
def event_dim(self) -> int:
r"""
Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
constructs.
"""
return len(self.event_shape)
@property
def value_in_support(self) -> float:
r"""
A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
default 0.0. This value will be used when padding data series.
"""
return 0.0
def get_parameter_projection(self, in_features: int) -> nn.Module:
r"""
Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
"""
return ParameterProjection(
in_features=in_features,
args_dim=self.args_dim,
domain_map=LambdaLayer(self.domain_map),
)
def domain_map(self, *args: torch.Tensor):
r"""
Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
distribution of the right event_shape.
"""
raise NotImplementedError()
@classmethod
def squareplus(cls, x: torch.Tensor) -> torch.Tensor:
r"""
Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
https://twitter.com/jon_barron/status/1387167648669048833
"""
return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0
class StudentTOutput(DistributionOutput):
args_dim: Dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
distr_cls: type = StudentT
@classmethod
def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
scale = cls.squareplus(scale)
df = 2.0 + cls.squareplus(df)
return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)
@property
def event_shape(self) -> Tuple:
return ()
class NormalOutput(DistributionOutput):
args_dim: Dict[str, int] = {"loc": 1, "scale": 1}
distr_cls: type = Normal
@classmethod
def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
scale = cls.squareplus(scale)
return loc.squeeze(-1), scale.squeeze(-1)
@property
def event_shape(self) -> Tuple:
return ()
class NegativeBinomialOutput(DistributionOutput):
args_dim: Dict[str, int] = {"total_count": 1, "logits": 1}
distr_cls: type = NegativeBinomial
@classmethod
def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
total_count = cls.squareplus(total_count)
return total_count.squeeze(-1), logits.squeeze(-1)
def _base_distribution(self, distr_args) -> Distribution:
total_count, logits = distr_args
return self.distr_cls(total_count=total_count, logits=logits)
# Overwrites the parent class method. We cannot scale using the affine
# transformation since negative binomial should return integers. Instead
# we scale the parameters.
def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> Distribution:
total_count, logits = distr_args
if scale is not None:
logits += scale.log()
return NegativeBinomial(total_count=total_count, logits=logits)
@property
def event_shape(self) -> Tuple:
return ()
class FeatureEmbedder(nn.Module):
def __init__(self, cardinalities: List[int], embedding_dims: List[int]) -> None:
super().__init__()
self.num_features = len(cardinalities)
self.embedders = nn.ModuleList([nn.Embedding(c, d) for c, d in zip(cardinalities, embedding_dims)])
def forward(self, features: torch.Tensor) -> torch.Tensor:
if self.num_features > 1:
# we slice the last dimension, giving an array of length
# self.num_features with shape (N,T) or (N)
cat_feature_slices = torch.chunk(features, self.num_features, dim=-1)
else:
cat_feature_slices = [features]
return torch.cat(
[
embed(cat_feature_slice.squeeze(-1))
for embed, cat_feature_slice in zip(self.embedders, cat_feature_slices)
],
dim=-1,
)
class MeanScaler(nn.Module):
"""
Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data
accordingly.
Args:
dim (`int`):
Dimension along which to compute the scale.
keepdim (`bool`, *optional*, defaults to `False`):
Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
minimum_scale (`float`, *optional*, defaults to 1e-10):
Default scale that is used for elements that are constantly zero along dimension `dim`.
"""
def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-10):
super().__init__()
if not dim > 0:
raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0")
self.dim = dim
self.keepdim = keepdim
self.register_buffer("minimum_scale", torch.tensor(minimum_scale))
def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# these will have shape (N, C)
total_weight = weights.sum(dim=self.dim)
weighted_sum = (data.abs() * weights).sum(dim=self.dim)
# first compute a global scale per-dimension
total_observed = total_weight.sum(dim=0)
denominator = torch.max(total_observed, torch.ones_like(total_observed))
default_scale = weighted_sum.sum(dim=0) / denominator
# then compute a per-item, per-dimension scale
denominator = torch.max(total_weight, torch.ones_like(total_weight))
scale = weighted_sum / denominator
# use per-batch scale when no element is observed
# or when the sequence contains only zeros
scale = (
torch.max(
self.minimum_scale,
torch.where(
weighted_sum > torch.zeros_like(weighted_sum),
scale,
default_scale * torch.ones_like(total_weight),
),
)
.detach()
.unsqueeze(dim=self.dim)
)
return data / scale, scale if self.keepdim else scale.squeeze(dim=self.dim)
class NOPScaler(nn.Module):
"""
Assigns a scaling factor equal to 1 along dimension `dim`, and therefore applies no scaling to the input data.
Args:
dim (`int`):
Dimension along which to compute the scale.
keepdim (`bool`, *optional*, defaults to `False`):
Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
"""
def __init__(self, dim: int, keepdim: bool = False):
super().__init__()
self.dim = dim
self.keepdim = keepdim
def forward(self, data: torch.Tensor, observed_indicator: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.ones_like(data).mean(dim=self.dim, keepdim=self.keepdim)
return data, scale
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
"""
Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
Args:
input_tensor (`torch.FloatTensor`):
Input tensor, of which the average must be computed.
weights (`torch.FloatTensor`, *optional*):
Weights tensor, of the same shape as `input_tensor`.
dim (`int`, *optional*):
The dim along which to average `input_tensor`.
Returns:
`torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
"""
if weights is not None:
weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
else:
return input_tensor.mean(dim=dim)
class NegativeLogLikelihood:
"""
Computes the negative log likelihood loss.
Args:
beta (`float`):
Float in range (0, 1). The beta parameter from the paper: "On the Pitfalls of Heteroscedastic Uncertainty
Estimation with Probabilistic Neural Networks" by [Seitzer et al.
2022](https://openreview.net/forum?id=aPOpXlnV1T).
"""
beta: float = 0.0
def __call__(self, input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
nll = -input.log_prob(target)
if self.beta > 0.0:
variance = input.variance
nll = nll * (variance.detach() ** self.beta)
return nll
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@dataclass
class Seq2SeqTimeSeriesModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential
decoding.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
scale: (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Scaling values of each time series' context window which is used to give the model inputs of the same
magnitude and then used to rescale to the original scale.
static_features: (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
Static features of each time series' in a batch which are copied to the covariates at inference time.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None
@dataclass
class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
"""
Base class for model's predictions outputs that also contain the loss as well parameters of the chosen
distribution.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided):
Distributional loss.
params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
Parameters of the chosen distribution.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
scale: (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Scaling values of each time series' context window which is used to give the model inputs of the same
magnitude and then used to rescale to the original scale.
static_features: (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
Static features of each time series' in a batch which are copied to the covariates at inference time.
"""
loss: Optional[torch.FloatTensor] = None
params: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None
@dataclass
class SampleTimeSeriesPredictionOutput(ModelOutput):
sequences: torch.FloatTensor = None
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->TimeSeriesTransformer
class TimeSeriesTransformerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer
class TimeSeriesTransformerEncoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = TimeSeriesTransformerAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
layer_head_mask: torch.FloatTensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer
class TimeSeriesTransformerDecoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = TimeSeriesTransformerAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = TimeSeriesTransformerAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
config_class = TimeSeriesTransformerConfig
base_model_prefix = "model"
main_input_name = "past_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)):
module.gradient_checkpointing = value
TIME_SERIES_TRANSFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`TimeSeriesTransformerConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Past values of the time series, that serve as context in order to predict the future. These values may
contain lags, i.e. additional values from the past which are added in order to serve as "extra context".
The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
`static_categorical_features`, `static_real_features`, `past_time_features`).
The sequence length here is equal to `context_length` + `max(config.lags_sequence)`.
Missing values need to be replaced with zeros.
past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`, *optional*):
Optional time features, which the model internally will add to `past_values`. These could be things like
"month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
could also be so-called "age" features, which basically help the model know "at which point in life" a
time-series is. Age features have small values for distant past time steps and increase monotonically the
more we approach the current time step.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
the position encodings are learned from scratch internally as parameters of the model, the Time Series
Transformer requires to provide additional time features.
The Time Series Transformer only learns additional embeddings for `static_categorical_features`.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
`[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
Optional static categorical features for which the model will learn an embedding, which it will add to the
values of the time series.
Static categorical features are features which have the same value for all time steps (static over time).
A typical example of a static categorical feature is a time series ID.
static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
Optional static real features which the model will add to the values of the time series.
Static real features are features which have the same value for all time steps (static over time).
A typical example of a static real feature is promotion information.
future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)`):
Future values of the time series, that serve as labels for the model. The `future_values` is what the
Transformer needs to learn to output, given the `past_values`.
See the demo notebook and code snippets for details.
Missing values need to be replaced with zeros.
future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`, *optional*):
Optional time features, which the model internally will add to `future_values`. These could be things like
"month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
could also be so-called "age" features, which basically help the model know "at which point in life" a
time-series is. Age features have small values for distant past time steps and increase monotonically the
more we approach the current time step.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
the position encodings are learned from scratch internally as parameters of the model, the Time Series
Transformer requires to provide additional features.
The Time Series Transformer only learns additional embeddings for `static_categorical_features`.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on certain token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Mask to avoid performing attention on certain token indices. By default, a causal mask will be used, to
make sure the model can only look at previous inputs in order to predict the future.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`TimeSeriesTransformerEncoderLayer`].
Args:
config: TimeSeriesTransformerConfig
"""
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.layers = nn.ModuleList([TimeSeriesTransformerEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
[`TimeSeriesTransformerDecoderLayer`]
Args:
config: TimeSeriesTransformerConfig
"""
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.layers = nn.ModuleList([TimeSeriesTransformerDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
Args:
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_shape = inputs_embeds.size()[:-1]
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
hidden_states = inputs_embeds
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"The bare Time Series Transformer Model outputting raw hidden-states without any specific head on top.",
TIME_SERIES_TRANSFORMER_START_DOCSTRING,
)
class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
if config.scaling:
self.scaler = MeanScaler(dim=1, keepdim=True)
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
self.embedder = FeatureEmbedder(
cardinalities=config.cardinality,
embedding_dims=config.embedding_dimension,
)
# transformer encoder-decoder and mask initializer
self.encoder = TimeSeriesTransformerEncoder(config)
self.decoder = TimeSeriesTransformerDecoder(config)
# Initialize weights and apply final processing
self.post_init()
@property
def _past_length(self) -> int:
return self.config.context_length + max(self.config.lags_sequence)
def get_lagged_subsequences(
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor:
"""
Returns lagged subsequences of a given sequence. Returns a tensor of shape (N, S, C, I),
where S = subsequences_length and I = len(indices), containing lagged subsequences. Specifically, lagged[i,
j, :, k] = sequence[i, -indices[k]-S+j, :].
Args:
sequence: Tensor
The sequence from which lagged subsequences should be extracted. Shape: (N, T, C).
subsequences_length : int
Length of the subsequences to be extracted.
shift: int
Shift the lags by this amount back.
"""
sequence_length = sequence.shape[1]
indices = [lag - shift for lag in self.config.lags_sequence]
try:
assert max(indices) + subsequences_length <= sequence_length, (
f"lags cannot go further than history length, found lag {max(indices)} "
f"while history length is only {sequence_length}"
)
except AssertionError as e:
e.args += (max(indices), sequence_length)
raise
lagged_values = []
for lag_index in indices:
begin_index = -lag_index - subsequences_length
end_index = -lag_index if lag_index > 0 else None
lagged_values.append(sequence[:, begin_index:end_index, ...])
return torch.stack(lagged_values, dim=-1)
def create_network_inputs(
self,
past_values: torch.Tensor,
past_time_features: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
):
# time feature
time_feat = (
torch.cat(
(
past_time_features[:, self._past_length - self.config.context_length :, ...],
future_time_features,
),
dim=1,
)
if future_values is not None
else past_time_features[:, self._past_length - self.config.context_length :, ...]
)
# target
if past_observed_mask is None:
past_observed_mask = torch.ones_like(past_values)
context = past_values[:, -self.config.context_length :]
observed_context = past_observed_mask[:, -self.config.context_length :]
_, scale = self.scaler(context, observed_context)
inputs = (
torch.cat((past_values, future_values), dim=1) / scale
if future_values is not None
else past_values / scale
)
inputs_length = (
self._past_length + self.config.prediction_length if future_values is not None else self._past_length
)
try:
assert inputs.shape[1] == inputs_length, (
f"input length {inputs.shape[1]} and dynamic feature lengths {inputs_length} does not match",
)
except AssertionError as e:
e.args += (inputs.shape[1], inputs_length)
raise
subsequences_length = (
self.config.context_length + self.config.prediction_length
if future_values is not None
else self.config.context_length
)
# embeddings
embedded_cat = self.embedder(static_categorical_features)
static_feat = torch.cat(
(embedded_cat, static_real_features, scale.log()),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1)
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
# sequence = torch.cat((prior_input, inputs), dim=1)
lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
return transformer_inputs, scale, static_feat
def enc_dec_outputs(self, transformer_inputs):
enc_input = transformer_inputs[:, : self.config.context_length, ...]
dec_input = transformer_inputs[:, self.config.context_length :, ...]
encoder_outputs = self.encoder(inputs_embeds=enc_input)
decoder_outputs = self.decoder(
inputs_embeds=dec_input, encoder_hidden_states=encoder_outputs.last_hidden_state
)
return encoder_outputs, decoder_outputs
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqTimeSeriesModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import TimeSeriesTransformerModel
>>> import torch
>>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/tst-base")
>>> inputs = dict()
>>> batch_size = 2
>>> cardinality = 5
>>> num_time_features = 10
>>> content_length = 8
>>> prediction_length = 2
>>> lags_sequence = [2, 3]
>>> past_length = context_length + max(lags_sequence)
>>> # encoder inputs
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality)
>>> inputs["static_real_features"] = torch.randn([batch_size, 1])
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features])
>>> inputs["past_values"] = torch.randn([batch_size, past_length])
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length])
>>> # decoder inputs
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features])
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_inputs, scale, static_feat = self.create_network_inputs(
past_values=past_values,
past_time_features=past_time_features,
past_observed_mask=past_observed_mask,
static_categorical_features=static_categorical_features,
static_real_features=static_real_features,
future_values=future_values,
future_time_features=future_time_features,
)
if encoder_outputs is None:
enc_input = transformer_inputs[:, : self.config.context_length, ...]
encoder_outputs = self.encoder(
inputs_embeds=enc_input,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
dec_input = transformer_inputs[:, self.config.context_length :, ...]
decoder_outputs = self.decoder(
inputs_embeds=dec_input,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs + (scale, static_feat)
return Seq2SeqTimeSeriesModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
scale=scale,
static_features=static_feat,
)
@add_start_docstrings(
"The Time Series Transformer Model with a distribution head on top for time-series forecasting.",
TIME_SERIES_TRANSFORMER_START_DOCSTRING,
)
class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
self.model = TimeSeriesTransformerModel(config)
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput()
elif config.distribution_output == "normal":
self.distribution_output = NormalOutput()
elif config.distribution_output == "negative_binomial":
self.distribution_output = NegativeBinomialOutput()
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model)
self.target_shape = self.distribution_output.event_shape
if config.loss == "nll":
self.loss = NegativeLogLikelihood()
else:
raise ValueError(f"Unknown loss function {config.loss}")
# Initialize weights of distribution_output and apply final processing
self.post_init()
def output_params(self, dec_output):
return self.parameter_projection(dec_output)
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
@torch.jit.ignore
def output_distribution(self, params, scale=None, trailing_n=None) -> torch.distributions.Distribution:
sliced_params = params
if trailing_n is not None:
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distribution_output.distribution(sliced_params, scale=scale)
@add_start_docstrings_to_model_forward(TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqTimeSeriesModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
future_observed_mask: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns:
future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
This mask is used to filter out missing values for the final loss calculation.
Examples:
```python
>>> from transformers import TimeSeriesTransformerForPrediction
>>> import torch
>>> model = TimeSeriesTransformerForPrediction.from_pretrained("huggingface/tst-base")
>>> inputs = dict()
>>> batch_size = 2
>>> cardinality = 5
>>> num_time_features = 10
>>> content_length = 8
>>> prediction_length = 2
>>> lags_sequence = [2, 3]
>>> past_length = context_length + max(lags_sequence)
>>> # encoder inputs
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality)
>>> inputs["static_real_features"] = torch.randn([batch_size, 1])
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features])
>>> inputs["past_values"] = torch.randn([batch_size, past_length])
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length])
>>> # decoder inputs
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features])
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
>>> outputs = model(**inputs)
>>> loss = outputs.loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if future_values is not None:
use_cache = False
outputs = self.model(
past_values=past_values,
past_time_features=past_time_features,
past_observed_mask=past_observed_mask,
static_categorical_features=static_categorical_features,
static_real_features=static_real_features,
future_values=future_values,
future_time_features=future_time_features,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
use_cache=use_cache,
return_dict=return_dict,
)
prediction_loss = None
params = None
if future_values is not None:
params = self.output_params(outputs[0]) # outputs.last_hidden_state
distribution = self.output_distribution(params, outputs[-2]) # outputs.scale
loss = self.loss(distribution, future_values)
if future_observed_mask is None:
future_observed_mask = torch.ones_like(future_values)
if len(self.target_shape) == 0:
loss_weights = future_observed_mask
else:
loss_weights = future_observed_mask.min(dim=-1, keepdim=False)
prediction_loss = weighted_average(loss, weights=loss_weights)
if not return_dict:
outputs = ((params,) + outputs[1:]) if params is not None else outputs[1:]
return ((prediction_loss,) + outputs) if prediction_loss is not None else outputs
return Seq2SeqTimeSeriesPredictionOutput(
loss=prediction_loss,
params=params,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
scale=outputs.scale,
static_features=outputs.static_features,
)
@torch.no_grad()
def generate(
self,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
past_time_features: torch.Tensor,
past_values: torch.Tensor,
past_observed_mask: torch.Tensor,
future_time_features: Optional[torch.Tensor],
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.Tensor:
outputs = self(
static_categorical_features=static_categorical_features,
static_real_features=static_real_features,
past_time_features=past_time_features,
past_values=past_values,
past_observed_mask=past_observed_mask,
future_time_features=future_time_features,
future_values=None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
use_cache=True,
)
decoder = self.model.get_decoder()
enc_last_hidden = outputs.encoder_last_hidden_state
scale = outputs.scale
static_feat = outputs.static_features
num_parallel_samples = self.config.num_parallel_samples
repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_past_values = past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) / repeated_scale
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1)
features = torch.cat((expanded_static_feat, future_time_features), dim=-1)
repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_enc_last_hidden = enc_last_hidden.repeat_interleave(repeats=num_parallel_samples, dim=0)
future_samples = []
# greedy decoding
for k in range(self.config.prediction_length):
lagged_sequence = self.model.get_lagged_subsequences(
sequence=repeated_past_values,
subsequences_length=1 + k,
shift=1,
)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1)
dec_output = decoder(inputs_embeds=decoder_input, encoder_hidden_states=repeated_enc_last_hidden)
dec_last_hidden = dec_output.last_hidden_state
params = self.parameter_projection(dec_last_hidden[:, -1:])
distr = self.output_distribution(params, scale=repeated_scale)
next_sample = distr.sample()
repeated_past_values = torch.cat((repeated_past_values, next_sample / repeated_scale), dim=1)
future_samples.append(next_sample)
concat_future_samples = torch.cat(future_samples, dim=1)
return SampleTimeSeriesPredictionOutput(
sequences=concat_future_samples.reshape(
(-1, num_parallel_samples, self.config.prediction_length) + self.target_shape,
)
)
......@@ -4811,6 +4811,30 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"])
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TimeSeriesTransformerForPrediction(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TimeSeriesTransformerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TimeSeriesTransformerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch TimeSeriesTransformer model. """
import inspect
import tempfile
import unittest
from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
TOLERANCE = 1e-4
if is_torch_available():
import torch
from transformers import (
TimeSeriesTransformerConfig,
TimeSeriesTransformerForPrediction,
TimeSeriesTransformerModel,
)
from transformers.models.time_series_transformer.modeling_time_series_transformer import (
TimeSeriesTransformerDecoder,
TimeSeriesTransformerEncoder,
)
@require_torch
class TimeSeriesTransformerModelTester:
def __init__(
self,
parent,
batch_size=13,
prediction_length=7,
context_length=14,
cardinality=19,
embedding_dimension=5,
num_time_features=4,
is_training=True,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
lags_sequence=[1, 2, 3, 4, 5],
):
self.parent = parent
self.batch_size = batch_size
self.prediction_length = prediction_length
self.context_length = context_length
self.cardinality = cardinality
self.num_time_features = num_time_features
self.lags_sequence = lags_sequence
self.embedding_dimension = embedding_dimension
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.encoder_seq_length = context_length
self.decoder_seq_length = prediction_length
def get_config(self):
return TimeSeriesTransformerConfig(
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
prediction_length=self.prediction_length,
context_length=self.context_length,
lags_sequence=self.lags_sequence,
num_time_features=self.num_time_features,
num_static_categorical_features=1,
cardinality=[self.cardinality],
embedding_dimension=[self.embedding_dimension],
)
def prepare_time_series_transformer_inputs_dict(self, config):
_past_length = config.context_length + max(config.lags_sequence)
static_categorical_features = ids_tensor([self.batch_size, 1], config.cardinality[0])
static_real_features = floats_tensor([self.batch_size, 1])
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
future_values = floats_tensor([self.batch_size, config.prediction_length])
inputs_dict = {
"past_values": past_values,
"static_categorical_features": static_categorical_features,
"static_real_features": static_real_features,
"past_time_features": past_time_features,
"past_observed_mask": past_observed_mask,
"future_time_features": future_time_features,
"future_values": future_values,
}
return inputs_dict
def prepare_config_and_inputs(self):
config = self.get_config()
inputs_dict = self.prepare_time_series_transformer_inputs_dict(config)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = TimeSeriesTransformerModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = TimeSeriesTransformerEncoder.from_pretrained(tmpdirname).to(torch_device)
transformer_inputs, _, _ = model.create_network_inputs(**inputs_dict)
enc_input = transformer_inputs[:, : config.context_length, ...]
dec_input = transformer_inputs[:, config.context_length :, ...]
encoder_last_hidden_state_2 = encoder(inputs_embeds=enc_input)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = TimeSeriesTransformerDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
inputs_embeds=dec_input,
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@require_torch
class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TimeSeriesTransformerModel, TimeSeriesTransformerForPrediction) if is_torch_available() else ()
)
all_generative_model_classes = (TimeSeriesTransformerForPrediction,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torchscript = False
test_inputs_embeds = False
test_model_common_attributes = False
def setUp(self):
self.model_tester = TimeSeriesTransformerModelTester(self)
self.config_tester = ConfigTester(self, config_class=TimeSeriesTransformerConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_save_load_strict(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# Ignore since we have no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
# # Input is 'static_categorical_features' not 'input_ids'
def test_model_main_input_name(self):
model_signature = inspect.signature(getattr(TimeSeriesTransformerModel, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(TimeSeriesTransformerModel.main_input_name, observed_main_input_name)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = [
"past_values",
"past_time_features",
"past_observed_mask",
"static_categorical_features",
"static_real_features",
"future_values",
"future_time_features",
]
expected_arg_names.extend(
[
"future_observed_mask",
"decoder_attention_mask",
"head_mask",
"decoder_head_mask",
"cross_attn_head_mask",
"encoder_outputs",
"past_key_values",
"output_hidden_states",
"output_attentions",
"use_cache",
"return_dict",
]
if "future_observed_mask" in arg_names
else [
"decoder_attention_mask",
"head_mask",
"decoder_head_mask",
"cross_attn_head_mask",
"encoder_outputs",
"past_key_values",
"output_hidden_states",
"output_attentions",
"use_cache",
"return_dict",
]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
)
out_len = len(outputs)
correct_outlen = 6
if "last_hidden_state" in outputs:
correct_outlen += 1
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
if "loss" in outputs:
correct_outlen += 1
if "params" in outputs:
correct_outlen += 1
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_seq_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_seq_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 2, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
)
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
batch = torch.load(file, map_location=torch_device)
return batch
@require_torch
@slow
class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase):
def test_inference_no_head(self):
model = TimeSeriesTransformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
torch_device
)
batch = prepare_batch()
with torch.no_grad():
output = model(
past_values=batch["past_values"],
past_time_features=batch["past_time_features"],
past_observed_mask=batch["past_observed_mask"],
static_categorical_features=batch["static_categorical_features"],
static_real_features=batch["static_real_features"],
future_values=batch["future_values"],
future_time_features=batch["future_time_features"],
)[0]
expected_shape = torch.Size((64, model.config.prediction_length, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.3125, -1.2884, -1.1118], [-0.5801, -1.4907, -0.7782], [0.0849, -1.6557, -0.9755]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self):
model = TimeSeriesTransformerForPrediction.from_pretrained(
"huggingface/time-series-transformer-tourism-monthly"
).to(torch_device)
batch = prepare_batch("val-batch.pt")
with torch.no_grad():
output = model(
past_values=batch["past_values"],
past_time_features=batch["past_time_features"],
past_observed_mask=batch["past_observed_mask"],
static_categorical_features=batch["static_categorical_features"],
static_real_features=batch["static_real_features"],
future_time_features=batch["future_time_features"],
)[1]
expected_shape = torch.Size((64, model.config.prediction_length, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.9127, -0.2056, -0.5259], [1.0572, 1.4104, -0.1964], [0.1358, 2.0348, 0.5739]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_seq_to_seq_generation(self):
model = TimeSeriesTransformerForPrediction.from_pretrained(
"huggingface/time-series-transformer-tourism-monthly"
).to(torch_device)
batch = prepare_batch("val-batch.pt")
with torch.no_grad():
outputs = model.generate(
static_categorical_features=batch["static_categorical_features"],
static_real_features=batch["static_real_features"],
past_time_features=batch["past_time_features"],
past_values=batch["past_values"],
future_time_features=batch["future_time_features"],
past_observed_mask=batch["past_observed_mask"],
)
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.prediction_length))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor([2289.5203, 2778.3054, 4648.1313], device=torch_device)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -3:], expected_slice, rtol=1e-1))
......@@ -46,6 +46,8 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"TimeSeriesTransformerEncoder", # Building part of bigger (tested) model.
"TimeSeriesTransformerDecoder", # Building part of bigger (tested) model.
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
......@@ -132,6 +134,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"TimeSeriesTransformerForPrediction",
"PegasusXEncoder",
"PegasusXDecoder",
"PegasusXDecoderWrapper",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment