Unverified commit b242d0f2 authored by Arindam Jati, committed by GitHub

[Time series] Add PatchTSMixer (#26247)



* patchtsmixer initial commit

* x,y->context_values,target_values, unittest added

* cleanup code

* minor

* return hidden states

* model tests, partial integration tests

* ettm notebook temporary

* minor

* config mask bug fix, tests updated

* final ETT notebooks

* add selfattn

* init

* added docstrings

* PatchTSMixerForPretraining -> PatchTSMixerForMaskPretraining

* functionality tests added

* add start and input docstrings

* docstring edits

* testcase edits

* minor changes

* docstring error fixed

* ran make fixup

* finalize integration tests and docs

* minor

* cleaned gitignore

* added dataclass decorator, ran black formatter

* ran ruff

* formatting

* add slow decorator

* renamed in_Channel to input_size and default to 1

* shorten dataclass names

* use smaller model for testing

* moved the 3 heads to the modeling file

* use scalers instead of revin

* support forecast_channel_indices

* fix regression scaling

* undo reg. scaling

* removed unneeded classes

* forgot missing

* add more layers

* add copied positional_encoding

* use patchmask from patchtst

* removed dependency on layers directory

* formatting

* set seed

* removed unused imports

* fixed forward signature test

* adding distributional head for PatchTSMixerForecasting

* add generate to forecast

* testcases for generate

* add generate and distributional head for regression

* raise Exception for negative values for negative binomial distribution

* formatting changes

* remove copied from patchtst and add TODO for test passing

* make copies

* doc edits

* minor changes

* format issues

* minor changes

* minor changes

* format docstring

* change some class names to PatchTSMixer + class name

Transpose to PatchTSMixerTranspose
GatedAttention to PatchTSMixerGatedAttention

* change NormLayer to PatchTSMixerNormLayer

* change MLP to PatchTSMixerMLP

* change PatchMixer to PatchMixerBlock, FeatureMixer to FeatureMixerBlock

* change ChannelFeatureMixer to ChannelFeatureMixerBlock

* change PatchMasking to PatchTSMixerMasking

* change Patchify to PatchTSMixerPatchify

* list to `list`

* fix docstrings

* formatting

* change bs to batch_size, edit forecast_masking

* edit random_masking

* change variable name and update docstring in PatchTSMixerMasking

* change variable name and update docstring in InjectScalerStatistics4D

* update forward call in PatchTSMixerTranspose

* change variable name and update docstring in PatchTSMixerNormLayer

* change variable name and update docstring in PatchTSMixerMLP

* change variable name and update docstring in ChannelFeatureMixerBlock

* formatting

* formatting issues

* docstring issue

* fixed observed_mask type in docstrings

* use FloatTensor type

* formatting

* fix rescaling issue in forecasting, fixed integration tests

* add docstring from decorator

* fix docstring

* Update README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/configuration_patchtsmixer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/configuration_patchtsmixer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* PatchTSMixerChannelFeatureMixerBlock

* formatting

* ForPretraining

* use num_labels instead of n_classes

* remove commented out code

* docstring fixed

* nn.functional used instead of one letter F

* x_tmp renamed

* one letter variable x removed from forward calls

* one letter variable y removed

* remove commented code

* rename patch_size, in_channels, PatchTSMixerBackbone

* add config to heads

* add config to heads tests

* code refactoring to use config instead of passing individual params

* docstring fixes part 1

* docstring fixes part 2

* removed logger.debug

* context_values -> past_values

* formatting changes

* pe -> positional_encoding

* removed unused target variable

* self.mode logic fixed

* formatting change

* edit docstring and var name

* change n_targets to num_targets

* rename input_size to num_input_channels

* add head names with prefix PatchTSMixer

* edit docstring in PatchTSMixerForRegression

* fix var name change in testcases

* add PatchTSMixerAttention

* return dict for all exposed classes, test cases added

* format

* move loss function to forward call

* make style

* adding return dict/tuple

* make repo-consistency

* remove flatten mode

* code refactoring

* rename data

* remove PatchTSMixer and keep only PatchTSMixerEncoder

* docstring fixes

* removed unused code

* format

* format

* remove contiguous and formatting changes

* remove model description from config

* replace asserts with ValueError

* remove nn.Sequential from PatchTSMixerNormLayer

* replace if-else with map

* remove all nn.Sequential

* format

* formatting

* fix gradient_checkpointing error after merge, and formatting

* make fix-copies

* remove comments

* reshape

* doesn't support gradient checkpointing

* correct Patchify

* masking updates

* batchnorm copy from

* format checks

* scaler edits

* remove comments

* format changes

* remove self.config

* correct class PatchTSMixerMLP(nn.Module):

* make fix

* doc updates

* fix-copies

* scaler class correction

* doc edits

* scaler edits

* update readme with links

* injectstatistics add

* fix-copies

* add norm_eps option to LayerNorm

* format changes

* fix copies

* correct make copies

* use parametrize

* fix doc string

* add docs to toctree

* make style

* doc segmenting

* docstring edit

* change forecast to prediction

* edit doc

* doc edits

* remove PatchTSMixerTranspose

* add PatchTSMixerPositionalEncoding and init position_enc

* remove positional_encoding

* edit forecast_masking, remove forecast_mask_ratios

* fix broken code

* var rename target_values -> future_values

* num_features -> d_model

* fix broken code after master merge

* repo consistency

* use positional embedding

* prediction_logits -> prediction_outputs, make fix-copies

* uncommented @slow

* minor changes

* loss first in tuple

* tuple and dict same ordering

* style edits

* minor changes

* dict/tuple consistent enablement

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix formatting

* formatting

* usage tip

* test on cpu only

* add sample usage

* change PatchTSMixerForClassification to PatchTSMixerForTimeSeriesClassification

* push changes

* fix copies

* std scaling set to True by default

* minor changes

* style changes

---------
Co-authored-by: Arindam Jati <arindam.jati@ibm.com>
Co-authored-by: vijaye12 <vijaye12@in.ibm.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: nnguyen <nnguyen@us.ibm.com>
Co-authored-by: vijaye12 <vijaykr.e@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Nam Nguyen <namctin@gmail.com>
Co-authored-by: Wesley Gifford <79663411+wgifford@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent e5c12c03
@@ -440,6 +440,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -415,6 +415,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -389,6 +389,7 @@ conda install -c huggingface transformers
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -449,6 +449,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -364,6 +364,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -388,6 +388,7 @@ conda install -c huggingface transformers
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, and Peter J. Liu.
...
@@ -400,6 +400,7 @@ conda install -c huggingface transformers
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
 1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
+1. **[PatchTSMixer](https://huggingface.co/docs/transformers/main/model_doc/patchtsmixer)** (from IBM Research) released with the paper [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[PatchTST](https://huggingface.co/docs/transformers/main/model_doc/patchtst)** (from IBM) released with the paper [A Time Series is Worth 64 Words: Long-term Forecasting with Transformers](https://arxiv.org/pdf/2211.14730.pdf) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.
 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
 1. **[PEGASUS-X](https://huggingface.co/docs/transformers/model_doc/pegasus_x)** (from Google) released with the paper [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao, Peter J. Liu.
...
@@ -757,6 +757,8 @@
     title: Autoformer
   - local: model_doc/informer
     title: Informer
+  - local: model_doc/patchtsmixer
+    title: PatchTSMixer
   - local: model_doc/patchtst
     title: PatchTST
   - local: model_doc/time_series_transformer
...
@@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [OPT](model_doc/opt) | ✅ | ✅ | ✅ |
 | [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ |
 | [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ |
+| [PatchTSMixer](model_doc/patchtsmixer) | ✅ | ❌ | ❌ |
 | [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ |
 | [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ |
 | [PEGASUS-X](model_doc/pegasus_x) | ✅ | ❌ | ❌ |
...
<!--Copyright 2023 IBM and HuggingFace Inc. team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# PatchTSMixer
## Overview
The PatchTSMixer model was proposed in [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong and Jayant Kalagnanam.
PatchTSMixer is a lightweight time-series modeling approach based on the MLP-Mixer architecture. This HuggingFace implementation provides lightweight mixing across patches, channels, and hidden features for effective multivariate time-series modeling. It also supports attention mechanisms, ranging from simple gated attention to more complex self-attention blocks, that can be configured as needed. The model can be pretrained and subsequently used for downstream tasks such as forecasting, classification, and regression.
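The attention flavor is a configuration choice. A minimal sketch follows; the `gated_attn` and `self_attn` flag names are assumptions inferred from this PR's `PatchTSMixerGatedAttention` and self-attention commits, so check [`PatchTSMixerConfig`] for the exact fields:

```python
from transformers import PatchTSMixerConfig, PatchTSMixerModel

# `gated_attn`/`self_attn` are assumed flag names based on this PR's commit
# history; the backbone mixes across patches, channels, and hidden features
# either way.
config = PatchTSMixerConfig(
    context_length=512,
    num_input_channels=7,
    gated_attn=True,   # simple gated attention inside each mixer block
    self_attn=False,   # set to True to use self-attention blocks instead
)
model = PatchTSMixerModel(config)
```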
The abstract from the paper is the following:
*TSMixer is a lightweight neural architecture exclusively composed of multi-layer perceptron (MLP) modules designed for multivariate forecasting and representation learning on patched time series. Our model draws inspiration from the success of MLP-Mixer models in computer vision. We demonstrate the challenges involved in adapting Vision MLP-Mixer for time series and introduce empirically validated components to enhance accuracy. This includes a novel design paradigm of attaching online reconciliation heads to the MLP-Mixer backbone, for explicitly modeling the time-series properties such as hierarchy and channel-correlations. We also propose a Hybrid channel modeling approach to effectively handle noisy channel interactions and generalization across diverse datasets, a common challenge in existing patch channel-mixing methods. Additionally, a simple gated attention mechanism is introduced in the backbone to prioritize important features. By incorporating these lightweight components, we significantly enhance the learning capability of simple MLP structures, outperforming complex Transformer models with minimal computing usage. Moreover, TSMixer's modular design enables compatibility with both supervised and masked self-supervised learning methods, making it a promising building block for time-series Foundation Models. TSMixer outperforms state-of-the-art MLP and Transformer models in forecasting by a considerable margin of 8-60%. It also outperforms the latest strong benchmarks of Patch-Transformer models (by 1-2%) with a significant reduction in memory and runtime (2-3X).*
This model was contributed by [ajati](https://huggingface.co/ajati), [vijaye12](https://huggingface.co/vijaye12),
[gsinthong](https://huggingface.co/gsinthong), [namctin](https://huggingface.co/namctin),
[wmgifford](https://huggingface.co/wmgifford), [kashif](https://huggingface.co/kashif).
## Sample usage
```python
from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction
from transformers import Trainer, TrainingArguments

config = PatchTSMixerConfig(context_length=512, prediction_length=96)
model = PatchTSMixerForPrediction(config)

# train_dataset, valid_dataset and test_dataset are assumed to be already
# prepared datasets of fixed-length context/forecast windows.
training_args = TrainingArguments(output_dir="patchtsmixer-checkpoints")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)
trainer.train()
results = trainer.evaluate(test_dataset)
```
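This PR also adds distributional heads and a `generate` method to [`PatchTSMixerForPrediction`] for probabilistic forecasting. A hedged sketch: the `loss`, `distribution_output`, and `past_values` names below follow this PR's distributional-head commits and the `context_values -> past_values` rename, but treat them as assumptions rather than the definitive API:

```python
import torch

from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction

# Assumed config fields for a distributional head: negative log-likelihood
# loss with a Student-T output distribution (this PR also mentions a
# negative binomial option, restricted to non-negative values).
config = PatchTSMixerConfig(
    context_length=512,
    prediction_length=96,
    num_input_channels=7,
    loss="nll",
    distribution_output="student_t",
)
model = PatchTSMixerForPrediction(config)

past_values = torch.randn(2, 512, 7)  # (batch, context_length, num_input_channels)
samples = model.generate(past_values=past_values)  # draws forecast sample paths
print(samples.sequences.shape)  # expected: (batch, num_parallel_samples, prediction_length, num_input_channels)
```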
## Usage tips
The model can also be used for time series classification and time series regression. See the respective [`PatchTSMixerForTimeSeriesClassification`] and [`PatchTSMixerForRegression`] classes.
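A minimal sketch of instantiating these heads; `num_targets` follows the `n_targets -> num_targets` rename in this PR and is assumed here to size both the classification and regression outputs:

```python
from transformers import (
    PatchTSMixerConfig,
    PatchTSMixerForRegression,
    PatchTSMixerForTimeSeriesClassification,
)

# `num_targets` is assumed to set the number of classes / regression targets.
config = PatchTSMixerConfig(context_length=512, num_input_channels=7, num_targets=3)
classifier = PatchTSMixerForTimeSeriesClassification(config)
regressor = PatchTSMixerForRegression(config)
```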
## PatchTSMixerConfig
[[autodoc]] PatchTSMixerConfig
## PatchTSMixerModel
[[autodoc]] PatchTSMixerModel
- forward
## PatchTSMixerForPrediction
[[autodoc]] PatchTSMixerForPrediction
- forward
## PatchTSMixerForTimeSeriesClassification
[[autodoc]] PatchTSMixerForTimeSeriesClassification
- forward
## PatchTSMixerForPretraining
[[autodoc]] PatchTSMixerForPretraining
- forward
## PatchTSMixerForRegression
[[autodoc]] PatchTSMixerForRegression
- forward
@@ -186,16 +186,28 @@ _import_structure = {
         "WordpieceTokenizer",
     ],
     "models.bert_generation": ["BertGenerationConfig"],
-    "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"],
+    "models.bert_japanese": [
+        "BertJapaneseTokenizer",
+        "CharacterTokenizer",
+        "MecabTokenizer",
+    ],
     "models.bertweet": ["BertweetTokenizer"],
     "models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"],
     "models.bigbird_pegasus": [
         "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "BigBirdPegasusConfig",
     ],
-    "models.biogpt": ["BIOGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BioGptConfig", "BioGptTokenizer"],
+    "models.biogpt": [
+        "BIOGPT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BioGptConfig",
+        "BioGptTokenizer",
+    ],
     "models.bit": ["BIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BitConfig"],
-    "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"],
+    "models.blenderbot": [
+        "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BlenderbotConfig",
+        "BlenderbotTokenizer",
+    ],
     "models.blenderbot_small": [
         "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "BlenderbotSmallConfig",
@@ -223,10 +235,18 @@ _import_structure = {
         "BridgeTowerTextConfig",
         "BridgeTowerVisionConfig",
     ],
-    "models.bros": ["BROS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BrosConfig", "BrosProcessor"],
+    "models.bros": [
+        "BROS_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BrosConfig",
+        "BrosProcessor",
+    ],
     "models.byt5": ["ByT5Tokenizer"],
     "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
-    "models.canine": ["CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP", "CanineConfig", "CanineTokenizer"],
+    "models.canine": [
+        "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CanineConfig",
+        "CanineTokenizer",
+    ],
     "models.chinese_clip": [
         "CHINESE_CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "ChineseCLIPConfig",
@@ -266,14 +286,36 @@ _import_structure = {
         "ClvpTokenizer",
     ],
     "models.code_llama": [],
-    "models.codegen": ["CODEGEN_PRETRAINED_CONFIG_ARCHIVE_MAP", "CodeGenConfig", "CodeGenTokenizer"],
-    "models.conditional_detr": ["CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConditionalDetrConfig"],
-    "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
+    "models.codegen": [
+        "CODEGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CodeGenConfig",
+        "CodeGenTokenizer",
+    ],
+    "models.conditional_detr": [
+        "CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "ConditionalDetrConfig",
+    ],
+    "models.convbert": [
+        "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "ConvBertConfig",
+        "ConvBertTokenizer",
+    ],
     "models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"],
-    "models.convnextv2": ["CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextV2Config"],
+    "models.convnextv2": [
+        "CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "ConvNextV2Config",
+    ],
     "models.cpm": [],
-    "models.cpmant": ["CPMANT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CpmAntConfig", "CpmAntTokenizer"],
-    "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
+    "models.cpmant": [
+        "CPMANT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CpmAntConfig",
+        "CpmAntTokenizer",
+    ],
+    "models.ctrl": [
+        "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CTRLConfig",
+        "CTRLTokenizer",
+    ],
     "models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"],
     "models.data2vec": [
         "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -282,10 +324,23 @@ _import_structure = {
         "Data2VecTextConfig",
         "Data2VecVisionConfig",
     ],
-    "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
-    "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
-    "models.decision_transformer": ["DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "DecisionTransformerConfig"],
-    "models.deformable_detr": ["DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeformableDetrConfig"],
+    "models.deberta": [
+        "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DebertaConfig",
+        "DebertaTokenizer",
+    ],
+    "models.deberta_v2": [
+        "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DebertaV2Config",
+    ],
+    "models.decision_transformer": [
+        "DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DecisionTransformerConfig",
+    ],
+    "models.deformable_detr": [
+        "DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DeformableDetrConfig",
+    ],
     "models.deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
     "models.deprecated": [],
     "models.deprecated.bort": [],
@@ -296,7 +351,10 @@ _import_structure = {
         "MCTCTProcessor",
     ],
     "models.deprecated.mmbt": ["MMBTConfig"],
-    "models.deprecated.open_llama": ["OPEN_LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenLlamaConfig"],
+    "models.deprecated.open_llama": [
+        "OPEN_LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "OpenLlamaConfig",
+    ],
     "models.deprecated.retribert": [
         "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "RetriBertConfig",
@@ -319,9 +377,17 @@ _import_structure = {
     "models.dialogpt": [],
     "models.dinat": ["DINAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DinatConfig"],
     "models.dinov2": ["DINOV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Dinov2Config"],
-    "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"],
+    "models.distilbert": [
+        "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DistilBertConfig",
+        "DistilBertTokenizer",
+    ],
     "models.dit": [],
-    "models.donut": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutProcessor", "DonutSwinConfig"],
+    "models.donut": [
+        "DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DonutProcessor",
+        "DonutSwinConfig",
+    ],
     "models.dpr": [
         "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "DPRConfig",
@@ -331,9 +397,19 @@ _import_structure = {
         "DPRReaderTokenizer",
     ],
     "models.dpt": ["DPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DPTConfig"],
-    "models.efficientformer": ["EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "EfficientFormerConfig"],
-    "models.efficientnet": ["EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "EfficientNetConfig"],
-    "models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"],
+    "models.efficientformer": [
+        "EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "EfficientFormerConfig",
+    ],
+    "models.efficientnet": [
+        "EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "EfficientNetConfig",
+    ],
+    "models.electra": [
+        "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "ElectraConfig",
+        "ElectraTokenizer",
+    ],
     "models.encodec": [
         "ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "EncodecConfig",
@@ -347,7 +423,11 @@ _import_structure = {
     "models.ernie_m": ["ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP", "ErnieMConfig"],
     "models.esm": ["ESM_PRETRAINED_CONFIG_ARCHIVE_MAP", "EsmConfig", "EsmTokenizer"],
     "models.falcon": ["FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP", "FalconConfig"],
-    "models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"],
+    "models.flaubert": [
+        "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "FlaubertConfig",
+        "FlaubertTokenizer",
+    ],
     "models.flava": [
         "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "FlavaConfig",
@@ -358,16 +438,39 @@ _import_structure = {
     ],
     "models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"],
     "models.focalnet": ["FOCALNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FocalNetConfig"],
-    "models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"],
-    "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
+    "models.fsmt": [
+        "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "FSMTConfig",
+        "FSMTTokenizer",
+    ],
+    "models.funnel": [
+        "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "FunnelConfig",
+        "FunnelTokenizer",
+    ],
     "models.fuyu": ["FUYU_PRETRAINED_CONFIG_ARCHIVE_MAP", "FuyuConfig"],
-    "models.git": ["GIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "GitConfig", "GitProcessor", "GitVisionConfig"],
+    "models.git": [
+        "GIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "GitConfig",
+        "GitProcessor",
+        "GitVisionConfig",
+    ],
     "models.glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"],
-    "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
-    "models.gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"],
+    "models.gpt2": [
+        "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "GPT2Config",
+        "GPT2Tokenizer",
+    ],
+    "models.gpt_bigcode": [
+        "GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "GPTBigCodeConfig",
+    ],
     "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
     "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"],
-    "models.gpt_neox_japanese": ["GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXJapaneseConfig"],
+    "models.gpt_neox_japanese": [
+        "GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "GPTNeoXJapaneseConfig",
+    ],
     "models.gpt_sw3": [],
     "models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
     "models.gptsan_japanese": [
@@ -375,7 +478,10 @@ _import_structure = {
         "GPTSanJapaneseConfig",
         "GPTSanJapaneseTokenizer",
     ],
-    "models.graphormer": ["GRAPHORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "GraphormerConfig"],
+    "models.graphormer": [
+        "GRAPHORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "GraphormerConfig",
+    ],
     "models.groupvit": [
         "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "GroupViTConfig",
@@ -410,7 +516,11 @@ _import_structure = {
         "Kosmos2Config",
         "Kosmos2Processor",
     ],
-    "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
+    "models.layoutlm": [
+        "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LayoutLMConfig",
+        "LayoutLMTokenizer",
+    ],
     "models.layoutlmv2": [
         "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "LayoutLMv2Config",
@@ -432,10 +542,22 @@ _import_structure = {
     "models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
     "models.lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"],
     "models.llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
-    "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
+    "models.longformer": [
+        "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LongformerConfig",
+        "LongformerTokenizer",
+    ],
     "models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
-    "models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
-    "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
+    "models.luke": [
+        "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LukeConfig",
+        "LukeTokenizer",
+    ],
+    "models.lxmert": [
+        "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LxmertConfig",
+        "LxmertTokenizer",
+    ],
     "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
     "models.marian": ["MarianConfig"],
     "models.markuplm": [
@@ -449,21 +571,50 @@ _import_structure = {
         "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Mask2FormerConfig",
     ],
-    "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"],
+    "models.maskformer": [
+        "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MaskFormerConfig",
+        "MaskFormerSwinConfig",
+    ],
     "models.mbart": ["MBartConfig"],
     "models.mbart50": [],
     "models.mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig"],
-    "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
+    "models.megatron_bert": [
+        "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MegatronBertConfig",
+    ],
     "models.megatron_gpt2": [],
-    "models.mgp_str": ["MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP", "MgpstrConfig", "MgpstrProcessor", "MgpstrTokenizer"],
+    "models.mgp_str": [
+        "MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MgpstrConfig",
+        "MgpstrProcessor",
+        "MgpstrTokenizer",
+    ],
     "models.mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"],
     "models.mluke": [],
-    "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
-    "models.mobilenet_v1": ["MOBILENET_V1_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileNetV1Config"],
-    "models.mobilenet_v2": ["MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileNetV2Config"],
+    "models.mobilebert": [
+        "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MobileBertConfig",
+        "MobileBertTokenizer",
+    ],
+    "models.mobilenet_v1": [
+        "MOBILENET_V1_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MobileNetV1Config",
+    ],
+    "models.mobilenet_v2": [
+        "MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MobileNetV2Config",
+    ],
     "models.mobilevit": ["MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTConfig"],
-    "models.mobilevitv2": ["MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTV2Config"],
-    "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
+    "models.mobilevitv2": [
+        "MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MobileViTV2Config",
+    ],
+    "models.mpnet": [
+        "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "MPNetConfig",
+        "MPNetTokenizer",
+    ],
     "models.mpt": ["MPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MptConfig"],
     "models.mra": ["MRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MraConfig"],
     "models.mt5": ["MT5Config"],
@@ -482,8 +633,16 @@ _import_structure = {
         "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "NystromformerConfig",
     ],
-    "models.oneformer": ["ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "OneFormerConfig", "OneFormerProcessor"],
-    "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"],
+    "models.oneformer": [
+        "ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "OneFormerConfig",
+        "OneFormerProcessor",
+    ],
+    "models.openai": [
+        "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"OpenAIGPTConfig",
"OpenAIGPTTokenizer",
],
"models.opt": ["OPTConfig"], "models.opt": ["OPTConfig"],
"models.owlv2": [ "models.owlv2": [
"OWLV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "OWLV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
...@@ -499,10 +658,22 @@ _import_structure = { ...@@ -499,10 +658,22 @@ _import_structure = {
"OwlViTTextConfig", "OwlViTTextConfig",
"OwlViTVisionConfig", "OwlViTVisionConfig",
], ],
"models.patchtsmixer": [
"PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PatchTSMixerConfig",
],
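# A hedged reading of the new entry just above (editorial sketch, not part of the
# diff): keys in `_import_structure` are submodule paths and values are the public
# names each submodule exports, so
#
#     _import_structure["models.patchtsmixer"]
#     # -> ["PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PatchTSMixerConfig"]
#
# is what lets `transformers.PatchTSMixerConfig` resolve lazily at runtime.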
"models.patchtst": ["PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP", "PatchTSTConfig"], "models.patchtst": ["PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP", "PatchTSTConfig"],
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.pegasus": [
"PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PegasusConfig",
"PegasusTokenizer",
],
"models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig"], "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig"],
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], "models.perceiver": [
"PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PerceiverConfig",
"PerceiverTokenizer",
],
"models.persimmon": ["PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP", "PersimmonConfig"], "models.persimmon": ["PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP", "PersimmonConfig"],
"models.phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"], "models.phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
"models.phobert": ["PhobertTokenizer"], "models.phobert": ["PhobertTokenizer"],
...@@ -514,24 +685,50 @@ _import_structure = { ...@@ -514,24 +685,50 @@ _import_structure = {
"Pix2StructVisionConfig", "Pix2StructVisionConfig",
], ],
"models.plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"], "models.plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
"models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], "models.poolformer": [
"POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PoolFormerConfig",
],
"models.pop2piano": [ "models.pop2piano": [
"POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP", "POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Pop2PianoConfig", "Pop2PianoConfig",
], ],
"models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], "models.prophetnet": [
"PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ProphetNetConfig",
"ProphetNetTokenizer",
],
"models.pvt": ["PVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "PvtConfig"], "models.pvt": ["PVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "PvtConfig"],
"models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
"models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"],
"models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"], "models.realm": [
"REALM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"RealmConfig",
"RealmTokenizer",
],
"models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
"models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"], "models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"],
"models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"], "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
"models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"], "models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"], "models.roberta": [
"models.roberta_prelayernorm": ["ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaPreLayerNormConfig"], "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
"models.roc_bert": ["ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoCBertConfig", "RoCBertTokenizer"], "RobertaConfig",
"models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"], "RobertaTokenizer",
],
"models.roberta_prelayernorm": [
"ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"RobertaPreLayerNormConfig",
],
"models.roc_bert": [
"ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"RoCBertConfig",
"RoCBertTokenizer",
],
"models.roformer": [
"ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"RoFormerConfig",
"RoFormerTokenizer",
],
"models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"],
"models.sam": [ "models.sam": [
"SAM_PRETRAINED_CONFIG_ARCHIVE_MAP", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP",
...@@ -575,21 +772,45 @@ _import_structure = { ...@@ -575,21 +772,45 @@ _import_structure = {
"SpeechT5HifiGanConfig", "SpeechT5HifiGanConfig",
"SpeechT5Processor", "SpeechT5Processor",
], ],
"models.splinter": ["SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SplinterConfig", "SplinterTokenizer"], "models.splinter": [
"models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"models.swiftformer": ["SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwiftFormerConfig"], "SplinterConfig",
"SplinterTokenizer",
],
"models.squeezebert": [
"SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SqueezeBertConfig",
"SqueezeBertTokenizer",
],
"models.swiftformer": [
"SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SwiftFormerConfig",
],
"models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"], "models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"],
"models.swin2sr": ["SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swin2SRConfig"], "models.swin2sr": ["SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swin2SRConfig"],
"models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"], "models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"],
"models.switch_transformers": ["SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig"], "models.switch_transformers": [
"SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SwitchTransformersConfig",
],
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.table_transformer": ["TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TableTransformerConfig"], "models.table_transformer": [
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], "TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TableTransformerConfig",
],
"models.tapas": [
"TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TapasConfig",
"TapasTokenizer",
],
"models.time_series_transformer": [ "models.time_series_transformer": [
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig", "TimeSeriesTransformerConfig",
], ],
"models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"], "models.timesformer": [
"TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimesformerConfig",
],
"models.timm_backbone": ["TimmBackboneConfig"], "models.timm_backbone": ["TimmBackboneConfig"],
"models.trocr": [ "models.trocr": [
"TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", "TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP",
...@@ -631,10 +852,19 @@ _import_structure = { ...@@ -631,10 +852,19 @@ _import_structure = {
"ViltProcessor", "ViltProcessor",
], ],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"], "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"], "models.vision_text_dual_encoder": [
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"], "VisionTextDualEncoderConfig",
"VisionTextDualEncoderProcessor",
],
"models.visual_bert": [
"VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"VisualBertConfig",
],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.vit_hybrid": ["VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTHybridConfig"], "models.vit_hybrid": [
"VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ViTHybridConfig",
],
"models.vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"], "models.vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"],
"models.vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"], "models.vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"],
"models.vitdet": ["VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitDetConfig"], "models.vitdet": ["VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitDetConfig"],
...@@ -682,9 +912,18 @@ _import_structure = { ...@@ -682,9 +912,18 @@ _import_structure = {
], ],
"models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"], "models.xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_prophetnet": [
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
"models.xlm_roberta_xl": ["XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaXLConfig"], "XLMProphetNetConfig",
],
"models.xlm_roberta": [
"XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
"XLMRobertaConfig",
],
"models.xlm_roberta_xl": [
"XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"XLMRobertaXLConfig",
],
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"], "models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
"models.xmod": ["XMOD_PRETRAINED_CONFIG_ARCHIVE_MAP", "XmodConfig"], "models.xmod": ["XMOD_PRETRAINED_CONFIG_ARCHIVE_MAP", "XmodConfig"],
"models.yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"], "models.yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
...@@ -760,7 +999,13 @@ _import_structure = { ...@@ -760,7 +999,13 @@ _import_structure = {
"TrainerControl", "TrainerControl",
"TrainerState", "TrainerState",
], ],
"trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"], "trainer_utils": [
"EvalPrediction",
"IntervalStrategy",
"SchedulerType",
"enable_full_determinism",
"set_seed",
],
"training_args": ["TrainingArguments"], "training_args": ["TrainingArguments"],
"training_args_seq2seq": ["Seq2SeqTrainingArguments"], "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
"training_args_tf": ["TFTrainingArguments"], "training_args_tf": ["TFTrainingArguments"],
...@@ -885,7 +1130,11 @@ else: ...@@ -885,7 +1130,11 @@ else:
_import_structure["models.deprecated.retribert"].append("RetriBertTokenizerFast") _import_structure["models.deprecated.retribert"].append("RetriBertTokenizerFast")
_import_structure["models.distilbert"].append("DistilBertTokenizerFast") _import_structure["models.distilbert"].append("DistilBertTokenizerFast")
_import_structure["models.dpr"].extend( _import_structure["models.dpr"].extend(
["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"] [
"DPRContextEncoderTokenizerFast",
"DPRQuestionEncoderTokenizerFast",
"DPRReaderTokenizerFast",
]
)
_import_structure["models.electra"].append("ElectraTokenizerFast")
_import_structure["models.fnet"].append("FNetTokenizerFast")
@@ -939,7 +1188,10 @@ except OptionalDependencyNotAvailable:
name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
]
else:
_import_structure["convert_slow_tokenizer"] = [
"SLOW_TO_FAST_CONVERTERS",
"convert_slow_tokenizer",
]
# Tensorflow-text-specific objects
try:
@@ -1657,10 +1909,19 @@ else:
)
_import_structure["models.deprecated.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
_import_structure["models.deprecated.open_llama"].extend(
[
"OpenLlamaForCausalLM",
"OpenLlamaForSequenceClassification",
"OpenLlamaModel",
"OpenLlamaPreTrainedModel",
]
)
_import_structure["models.deprecated.retribert"].extend(
[
"RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"RetriBertModel",
"RetriBertPreTrainedModel",
]
)
_import_structure["models.deprecated.trajectory_transformer"].extend(
[
@@ -2165,7 +2426,12 @@ else:
]
)
_import_structure["models.llama"].extend(
[
"LlamaForCausalLM",
"LlamaForSequenceClassification",
"LlamaModel",
"LlamaPreTrainedModel",
]
)
_import_structure["models.longformer"].extend(
[
@@ -2298,7 +2564,12 @@ else:
]
)
_import_structure["models.mistral"].extend(
[
"MistralForCausalLM",
"MistralForSequenceClassification",
"MistralModel",
"MistralPreTrainedModel",
]
)
_import_structure["models.mobilebert"].extend(
[
@@ -2515,6 +2786,17 @@ else:
"OwlViTVisionModel",
]
)
_import_structure["models.patchtsmixer"].extend(
[
"PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PatchTSMixerForPrediction",
"PatchTSMixerForPretraining",
"PatchTSMixerForRegression",
"PatchTSMixerForTimeSeriesClassification",
"PatchTSMixerModel",
"PatchTSMixerPreTrainedModel",
]
)
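# Minimal usage sketch for the classes registered above (editorial, not part of
# this diff; the argument values are illustrative, not defaults mandated by the PR):
#
#     import torch
#     from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction
#
#     config = PatchTSMixerConfig(context_length=32, prediction_length=16, num_input_channels=3)
#     model = PatchTSMixerForPrediction(config)
#     past_values = torch.randn(2, 32, 3)  # (batch, context_length, num_input_channels)
#     outputs = model(past_values=past_values)  # outputs.prediction_outputs: (2, 16, 3)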
_import_structure["models.patchtst"].extend( _import_structure["models.patchtst"].extend(
[ [
"PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST", "PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -2527,7 +2809,12 @@ else: ...@@ -2527,7 +2809,12 @@ else:
] ]
) )
_import_structure["models.pegasus"].extend( _import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] [
"PegasusForCausalLM",
"PegasusForConditionalGeneration",
"PegasusModel",
"PegasusPreTrainedModel",
]
)
_import_structure["models.pegasus_x"].extend(
[
@@ -2553,7 +2840,12 @@ else:
]
)
_import_structure["models.persimmon"].extend(
[
"PersimmonForCausalLM",
"PersimmonForSequenceClassification",
"PersimmonModel",
"PersimmonPreTrainedModel",
]
)
_import_structure["models.phi"].extend(
[
@@ -2635,7 +2927,12 @@ else:
]
)
_import_structure["models.rag"].extend(
[
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
)
_import_structure["models.realm"].extend(
[
@@ -2961,7 +3258,11 @@ else:
)
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
_import_structure["models.trocr"].extend(
[
"TROCR_PRETRAINED_MODEL_ARCHIVE_LIST",
"TrOCRForCausalLM",
"TrOCRPreTrainedModel",
]
)
_import_structure["models.tvlt"].extend(
[
@@ -3298,7 +3599,11 @@ else:
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
]
_import_structure["pytorch_utils"] = [
"Conv1D",
"apply_chunking_to_forward",
"prune_layer",
]
_import_structure["sagemaker"] = [] _import_structure["sagemaker"] = []
_import_structure["time_series_utils"] = [] _import_structure["time_series_utils"] = []
_import_structure["trainer"] = ["Trainer"] _import_structure["trainer"] = ["Trainer"]
...@@ -3411,7 +3716,12 @@ else: ...@@ -3411,7 +3716,12 @@ else:
] ]
) )
_import_structure["models.bart"].extend( _import_structure["models.bart"].extend(
["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"] [
"TFBartForConditionalGeneration",
"TFBartForSequenceClassification",
"TFBartModel",
"TFBartPretrainedModel",
]
)
_import_structure["models.bert"].extend(
[
@@ -3431,10 +3741,18 @@ else:
]
)
_import_structure["models.blenderbot"].extend(
[
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
)
_import_structure["models.blenderbot_small"].extend(
[
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.blip"].extend(
[
@@ -3795,7 +4113,11 @@ else:
]
)
_import_structure["models.pegasus"].extend(
[
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
)
_import_structure["models.rag"].extend(
[
@@ -4010,7 +4332,12 @@ else:
"TFXLNetPreTrainedModel",
]
)
_import_structure["optimization_tf"] = [
"AdamWeightDecay",
"GradientAccumulator",
"WarmUp",
"create_optimizer",
]
_import_structure["tf_utils"] = [] _import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"] _import_structure["trainer_tf"] = ["TFTrainer"]
...@@ -4025,7 +4352,9 @@ try: ...@@ -4025,7 +4352,9 @@ try:
): ):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from .utils import dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects from .utils import (
dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects,
)
_import_structure["utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects"] = [ _import_structure["utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects"] = [
name name
...@@ -4164,7 +4493,11 @@ else: ...@@ -4164,7 +4493,11 @@ else:
] ]
) )
_import_structure["models.blenderbot"].extend( _import_structure["models.blenderbot"].extend(
["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"] [
"FlaxBlenderbotForConditionalGeneration",
"FlaxBlenderbotModel",
"FlaxBlenderbotPreTrainedModel",
]
)
_import_structure["models.blenderbot_small"].extend(
[
@@ -4222,7 +4555,11 @@ else:
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.longt5"].extend(
[
"FlaxLongT5ForConditionalGeneration",
"FlaxLongT5Model",
"FlaxLongT5PreTrainedModel",
]
)
_import_structure["models.marian"].extend(
[
@@ -4256,10 +4593,18 @@ else:
]
)
_import_structure["models.regnet"].extend(
[
"FlaxRegNetForImageClassification",
"FlaxRegNetModel",
"FlaxRegNetPreTrainedModel",
]
)
_import_structure["models.resnet"].extend(
[
"FlaxResNetForImageClassification",
"FlaxResNetModel",
"FlaxResNetPreTrainedModel",
]
)
_import_structure["models.roberta"].extend(
[
@@ -4298,13 +4643,23 @@ else:
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(
[
"FlaxT5EncoderModel",
"FlaxT5ForConditionalGeneration",
"FlaxT5Model",
"FlaxT5PreTrainedModel",
]
)
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
[
"FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining",
"FlaxWav2Vec2Model",
"FlaxWav2Vec2PreTrainedModel",
]
)
_import_structure["models.whisper"].extend(
[
@@ -4465,13 +4820,28 @@ if TYPE_CHECKING:
WordpieceTokenizer,
)
from .models.bert_generation import BertGenerationConfig
from .models.bert_japanese import (
BertJapaneseTokenizer,
CharacterTokenizer,
MecabTokenizer,
)
from .models.bertweet import BertweetTokenizer
from .models.big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from .models.bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
BigBirdPegasusConfig,
)
from .models.biogpt import (
BIOGPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BioGptConfig,
BioGptTokenizer,
)
from .models.bit import BIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BitConfig
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BlenderbotConfig,
BlenderbotTokenizer,
)
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
BlenderbotSmallConfig,
@@ -4499,10 +4869,21 @@ if TYPE_CHECKING:
BridgeTowerTextConfig,
BridgeTowerVisionConfig,
)
from .models.bros import (
BROS_PRETRAINED_CONFIG_ARCHIVE_MAP,
BrosConfig,
BrosProcessor,
)
from .models.byt5 import ByT5Tokenizer
from .models.camembert import (
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CamembertConfig,
)
from .models.canine import (
CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP,
CanineConfig,
CanineTokenizer,
)
from .models.chinese_clip import (
CHINESE_CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
ChineseCLIPConfig,
@@ -4541,13 +4922,35 @@ if TYPE_CHECKING:
ClvpProcessor,
ClvpTokenizer,
)
from .models.codegen import (
CODEGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
CodeGenConfig,
CodeGenTokenizer,
)
from .models.conditional_detr import (
CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
ConditionalDetrConfig,
)
from .models.convbert import (
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ConvBertConfig,
ConvBertTokenizer,
)
from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
from .models.convnextv2 import (
CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
ConvNextV2Config,
)
from .models.cpmant import (
CPMANT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CpmAntConfig,
CpmAntTokenizer,
)
from .models.ctrl import (
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig,
CTRLTokenizer,
)
from .models.cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig
from .models.data2vec import (
DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4556,13 +4959,23 @@ if TYPE_CHECKING:
Data2VecTextConfig,
Data2VecVisionConfig,
)
from .models.deberta import (
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DebertaConfig,
DebertaTokenizer,
)
from .models.deberta_v2 import (
DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
DebertaV2Config,
)
from .models.decision_transformer import (
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
DecisionTransformerConfig,
)
from .models.deformable_detr import (
DEFORMABLE_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
DeformableDetrConfig,
)
from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from .models.deprecated.mctct import (
MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4571,7 +4984,10 @@ if TYPE_CHECKING:
MCTCTProcessor,
)
from .models.deprecated.mmbt import MMBTConfig
from .models.deprecated.open_llama import (
OPEN_LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenLlamaConfig,
)
from .models.deprecated.retribert import (
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RetriBertConfig,
@@ -4593,8 +5009,16 @@ if TYPE_CHECKING:
from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from .models.dinat import DINAT_PRETRAINED_CONFIG_ARCHIVE_MAP, DinatConfig
from .models.dinov2 import DINOV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Dinov2Config
from .models.distilbert import (
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig,
DistilBertTokenizer,
)
from .models.donut import (
DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP,
DonutProcessor,
DonutSwinConfig,
)
from .models.dpr import (
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPRConfig,
@@ -4604,9 +5028,19 @@ if TYPE_CHECKING:
DPRReaderTokenizer,
)
from .models.dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig
from .models.efficientformer import (
EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
EfficientFormerConfig,
)
from .models.efficientnet import (
EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
EfficientNetConfig,
)
from .models.electra import (
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
ElectraConfig,
ElectraTokenizer,
)
from .models.encodec import (
ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP,
EncodecConfig,
@@ -4617,7 +5051,11 @@ if TYPE_CHECKING:
from .models.ernie_m import ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP, ErnieMConfig
from .models.esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig, EsmTokenizer
from .models.falcon import FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP, FalconConfig
from .models.flaubert import (
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FlaubertConfig,
FlaubertTokenizer,
)
from .models.flava import (
FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FlavaConfig,
@@ -4628,23 +5066,49 @@ if TYPE_CHECKING:
)
from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig
from .models.focalnet import FOCALNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FocalNetConfig
from .models.fsmt import (
FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FSMTConfig,
FSMTTokenizer,
)
from .models.funnel import (
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
FunnelConfig,
FunnelTokenizer,
)
from .models.fuyu import FUYU_PRETRAINED_CONFIG_ARCHIVE_MAP, FuyuConfig
from .models.git import (
GIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GitConfig,
GitProcessor,
GitVisionConfig,
)
from .models.glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig
from .models.gpt2 import (
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2Config,
GPT2Tokenizer,
)
from .models.gpt_bigcode import (
GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPTBigCodeConfig,
)
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig
from .models.gpt_neox_japanese import (
GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPTNeoXJapaneseConfig,
)
from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
from .models.gptsan_japanese import (
GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPTSanJapaneseConfig,
GPTSanJapaneseTokenizer,
)
from .models.graphormer import (
GRAPHORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
GraphormerConfig,
)
from .models.groupvit import (
GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GroupViTConfig,
@@ -4679,7 +5143,11 @@ if TYPE_CHECKING:
Kosmos2Config,
Kosmos2Processor,
)
from .models.layoutlm import (
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
LayoutLMConfig,
LayoutLMTokenizer,
)
from .models.layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
LayoutLMv2Config,
@@ -4701,10 +5169,22 @@ if TYPE_CHECKING:
from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
from .models.lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig
from .models.llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
from .models.longformer import (
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
LongformerConfig,
LongformerTokenizer,
)
from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
from .models.luke import (
LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
LukeConfig,
LukeTokenizer,
)
from .models.lxmert import (
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LxmertConfig,
LxmertTokenizer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.marian import MarianConfig
from .models.markuplm import (
@@ -4714,19 +5194,54 @@ if TYPE_CHECKING:
MarkupLMProcessor,
MarkupLMTokenizer,
)
from .models.mask2former import (
MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
Mask2FormerConfig,
)
from .models.maskformer import (
MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
MaskFormerConfig,
MaskFormerSwinConfig,
)
from .models.mbart import MBartConfig
from .models.mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MegatronBertConfig,
)
from .models.mgp_str import (
MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP,
MgpstrConfig,
MgpstrProcessor,
MgpstrTokenizer,
)
from .models.mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileBertConfig,
MobileBertTokenizer,
)
from .models.mobilenet_v1 import (
MOBILENET_V1_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileNetV1Config,
)
from .models.mobilenet_v2 import (
MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileNetV2Config,
)
from .models.mobilevit import (
MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileViTConfig,
)
from .models.mobilevitv2 import (
MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileViTV2Config,
)
from .models.mpnet import (
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNetConfig,
MPNetTokenizer,
)
from .models.mpt import MPT_PRETRAINED_CONFIG_ARCHIVE_MAP, MptConfig
from .models.mra import MRA_PRETRAINED_CONFIG_ARCHIVE_MAP, MraConfig
from .models.mt5 import MT5Config
@@ -4740,9 +5255,20 @@ if TYPE_CHECKING:
from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig
from .models.nllb_moe import NLLB_MOE_PRETRAINED_CONFIG_ARCHIVE_MAP, NllbMoeConfig
from .models.nougat import NougatProcessor
from .models.nystromformer import (
NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
NystromformerConfig,
)
from .models.oneformer import (
ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
OneFormerConfig,
OneFormerProcessor,
)
from .models.openai import (
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig,
OpenAIGPTTokenizer,
)
from .models.opt import OPTConfig
from .models.owlv2 import (
OWLV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4758,11 +5284,29 @@ if TYPE_CHECKING:
OwlViTTextConfig,
OwlViTVisionConfig,
)
from .models.patchtsmixer import (
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PatchTSMixerConfig,
)
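# Hedged sketch of how this TYPE_CHECKING block pairs with `_import_structure`
# (editorial, not part of the diff): static checkers see the real import above,
# while at runtime the module is swapped for a lazy proxy, roughly (the exact
# call sits at the bottom of this file, outside the hunks shown here):
#
#     import sys
#     sys.modules[__name__] = _LazyModule(
#         __name__, globals()["__file__"], _import_structure, module_spec=__spec__
#     )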
from .models.patchtst import PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP, PatchTSTConfig
from .models.pegasus import (
PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
PegasusConfig,
PegasusTokenizer,
)
from .models.pegasus_x import (
PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP,
PegasusXConfig,
)
from .models.perceiver import (
PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PerceiverConfig,
PerceiverTokenizer,
)
from .models.persimmon import (
PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP,
PersimmonConfig,
)
from .models.phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig
from .models.phobert import PhobertTokenizer
from .models.pix2struct import (
@@ -4773,27 +5317,50 @@ if TYPE_CHECKING:
Pix2StructVisionConfig,
)
from .models.plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
from .models.poolformer import (
POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PoolFormerConfig,
)
from .models.pop2piano import (
POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP,
Pop2PianoConfig,
)
from .models.prophetnet import (
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
ProphetNetConfig,
ProphetNetTokenizer,
)
from .models.pvt import PVT_PRETRAINED_CONFIG_ARCHIVE_MAP, PvtConfig
from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
from .models.rag import RagConfig, RagRetriever, RagTokenizer
from .models.realm import (
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP,
RealmConfig,
RealmTokenizer,
)
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .models.regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
from .models.roberta import (
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig,
RobertaTokenizer,
)
from .models.roberta_prelayernorm import (
ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaPreLayerNormConfig,
)
from .models.roc_bert import (
ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RoCBertConfig,
RoCBertTokenizer,
)
from .models.roformer import (
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
RoFormerConfig,
RoFormerTokenizer,
)
from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig
from .models.sam import (
SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4813,7 +5380,10 @@ if TYPE_CHECKING:
SEAMLESS_M4T_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
SeamlessM4Tv2Config,
)
from .models.segformer import (
SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
SegformerConfig,
)
from .models.sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
from .models.sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
@@ -4837,32 +5407,71 @@ if TYPE_CHECKING:
SpeechT5HifiGanConfig,
SpeechT5Processor,
)
from .models.splinter import (
SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP,
SplinterConfig,
SplinterTokenizer,
)
from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SqueezeBertConfig,
SqueezeBertTokenizer,
)
from .models.swiftformer import (
SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
SwiftFormerConfig,
)
from .models.swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
from .models.swin2sr import SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP, Swin2SRConfig
from .models.swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config
from .models.switch_transformers import (
SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP,
SwitchTransformersConfig,
)
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.table_transformer import (
TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TableTransformerConfig,
)
from .models.tapas import (
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
TapasConfig,
TapasTokenizer,
)
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
)
from .models.timesformer import (
TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimesformerConfig,
)
from .models.timm_backbone import TimmBackboneConfig
from .models.trocr import (
TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrOCRConfig,
TrOCRProcessor,
)
from .models.tvlt import (
TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP,
TvltConfig,
TvltFeatureExtractor,
TvltProcessor,
)
from .models.tvp import (
TVP_PRETRAINED_CONFIG_ARCHIVE_MAP,
TvpConfig,
TvpProcessor,
)
from .models.umt5 import UMT5Config
from .models.unispeech import (
UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP,
UniSpeechConfig,
)
from .models.unispeech_sat import (
UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP,
UniSpeechSatConfig,
)
from .models.univnet import (
UNIVNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
UnivNetConfig,
@@ -4878,10 +5487,19 @@ if TYPE_CHECKING:
ViltProcessor,
)
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.vision_text_dual_encoder import (
VisionTextDualEncoderConfig,
VisionTextDualEncoderProcessor,
)
from .models.visual_bert import (
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
VisualBertConfig,
)
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .models.vit_hybrid import (
VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP,
ViTHybridConfig,
)
from .models.vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig from .models.vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
from .models.vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig from .models.vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig
from .models.vitdet import VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP, VitDetConfig from .models.vitdet import VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP, VitDetConfig
...@@ -4900,7 +5518,10 @@ if TYPE_CHECKING: ...@@ -4900,7 +5518,10 @@ if TYPE_CHECKING:
Wav2Vec2Processor,
Wav2Vec2Tokenizer,
)
from .models.wav2vec2_conformer import (
WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2ConformerConfig,
)
from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
@@ -4920,9 +5541,18 @@ if TYPE_CHECKING:
)
from .models.xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
from .models.xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMProphetNetConfig,
)
from .models.xlm_roberta import (
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMRobertaConfig,
)
from .models.xlm_roberta_xl import (
XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMRobertaXLConfig,
)
from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
from .models.xmod import XMOD_PRETRAINED_CONFIG_ARCHIVE_MAP, XmodConfig
from .models.yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
@@ -5004,7 +5634,13 @@ if TYPE_CHECKING:
TrainerControl,
TrainerState,
)
from .trainer_utils import (
EvalPrediction,
IntervalStrategy,
SchedulerType,
enable_full_determinism,
set_seed,
)
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
@@ -5120,7 +5756,11 @@ if TYPE_CHECKING:
from .models.deberta_v2 import DebertaV2TokenizerFast
from .models.deprecated.retribert import RetriBertTokenizerFast
from .models.distilbert import DistilBertTokenizerFast
from .models.dpr import (
DPRContextEncoderTokenizerFast,
DPRQuestionEncoderTokenizerFast,
DPRReaderTokenizerFast,
)
from .models.electra import ElectraTokenizerFast
from .models.fnet import FNetTokenizerFast
from .models.funnel import FunnelTokenizerFast
@@ -5168,7 +5808,10 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
from .utils.dummies_sentencepiece_and_tokenizers_objects import *
else:
from .convert_slow_tokenizer import (
SLOW_TO_FAST_CONVERTERS,
convert_slow_tokenizer,
)
try:
if not is_tensorflow_text_available():
@@ -5198,11 +5841,20 @@ if TYPE_CHECKING:
from .models.bit import BitImageProcessor
from .models.blip import BlipImageProcessor
from .models.bridgetower import BridgeTowerImageProcessor
from .models.chinese_clip import (
ChineseCLIPFeatureExtractor,
ChineseCLIPImageProcessor,
)
from .models.clip import CLIPFeatureExtractor, CLIPImageProcessor
from .models.conditional_detr import (
ConditionalDetrFeatureExtractor,
ConditionalDetrImageProcessor,
)
from .models.convnext import ConvNextFeatureExtractor, ConvNextImageProcessor
from .models.deformable_detr import (
DeformableDetrFeatureExtractor,
DeformableDetrImageProcessor,
)
from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
from .models.deta import DetaImageProcessor
from .models.detr import DetrFeatureExtractor, DetrImageProcessor
@@ -5210,18 +5862,37 @@ if TYPE_CHECKING:
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
from .models.efficientformer import EfficientFormerImageProcessor
from .models.efficientnet import EfficientNetImageProcessor
from .models.flava import (
FlavaFeatureExtractor,
FlavaImageProcessor,
FlavaProcessor,
)
from .models.fuyu import FuyuImageProcessor, FuyuProcessor
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
from .models.idefics import IdeficsImageProcessor
from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
from .models.layoutlmv2 import (
LayoutLMv2FeatureExtractor,
LayoutLMv2ImageProcessor,
)
from .models.layoutlmv3 import (
LayoutLMv3FeatureExtractor,
LayoutLMv3ImageProcessor,
)
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.mask2former import Mask2FormerImageProcessor
from .models.maskformer import (
MaskFormerFeatureExtractor,
MaskFormerImageProcessor,
)
from .models.mobilenet_v1 import (
MobileNetV1FeatureExtractor,
MobileNetV1ImageProcessor,
)
from .models.mobilenet_v2 import (
MobileNetV2FeatureExtractor,
MobileNetV2ImageProcessor,
)
from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor
from .models.nougat import NougatImageProcessor
from .models.oneformer import OneFormerImageProcessor
@@ -5229,7 +5900,10 @@ if TYPE_CHECKING:
from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
from .models.pix2struct import Pix2StructImageProcessor
from .models.poolformer import (
PoolFormerFeatureExtractor,
PoolFormerImageProcessor,
)
from .models.pvt import PvtImageProcessor
from .models.sam import SamImageProcessor
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
@@ -5767,7 +6441,11 @@ if TYPE_CHECKING:
MCTCTModel,
MCTCTPreTrainedModel,
)
from .models.deprecated.mmbt import (
MMBTForClassification,
MMBTModel,
ModalEmbeddings,
)
from .models.deprecated.open_llama import (
OpenLlamaForCausalLM,
OpenLlamaForSequenceClassification,
@@ -5836,7 +6514,11 @@ if TYPE_CHECKING:
DistilBertModel,
DistilBertPreTrainedModel,
)
from .models.donut import (
DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
DonutSwinModel,
DonutSwinPreTrainedModel,
)
from .models.dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -5972,7 +6654,11 @@ if TYPE_CHECKING:
FocalNetModel,
FocalNetPreTrainedModel,
)
from .models.fsmt import (
FSMTForConditionalGeneration,
FSMTModel,
PretrainedFSMTModel,
)
from .models.funnel import (
FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
FunnelBaseModel,
@@ -6182,7 +6868,12 @@ if TYPE_CHECKING:
LiltModel,
LiltPreTrainedModel,
)
from .models.llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaPreTrainedModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@@ -6470,6 +7161,15 @@ if TYPE_CHECKING:
OwlViTTextModel,
OwlViTVisionModel,
)
from .models.patchtsmixer import (
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST,
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
PatchTSMixerForRegression,
PatchTSMixerForTimeSeriesClassification,
PatchTSMixerModel,
PatchTSMixerPreTrainedModel,
)
from .models.patchtst import (
PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST,
PatchTSTForClassification,
@@ -6573,7 +7273,12 @@ if TYPE_CHECKING:
QDQBertPreTrainedModel,
load_tf_weights_in_qdqbert,
)
from .models.rag import (
RagModel,
RagPreTrainedModel,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from .models.realm import (
REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
RealmEmbedder,
@@ -6736,7 +7441,10 @@ if TYPE_CHECKING:
Speech2TextModel,
Speech2TextPreTrainedModel,
)
from .models.speech_to_text_2 import (
Speech2Text2ForCausalLM,
Speech2Text2PreTrainedModel,
)
from .models.speecht5 import (
SPEECHT5_PRETRAINED_MODEL_ARCHIVE_LIST,
SpeechT5ForSpeechToSpeech,
@@ -6839,7 +7547,11 @@ if TYPE_CHECKING:
TimesformerPreTrainedModel,
)
from .models.timm_backbone import TimmBackbone
from .models.trocr import (
TROCR_PRETRAINED_MODEL_ARCHIVE_LIST,
TrOCRForCausalLM,
TrOCRPreTrainedModel,
)
from .models.tvlt import (
TVLT_PRETRAINED_MODEL_ARCHIVE_LIST,
TvltForAudioVisualClassification,
@@ -6880,7 +7592,10 @@ if TYPE_CHECKING:
UniSpeechSatPreTrainedModel,
)
from .models.univnet import UNIVNET_PRETRAINED_MODEL_ARCHIVE_LIST, UnivNetModel
from .models.upernet import (
UperNetForSemanticSegmentation,
UperNetPreTrainedModel,
)
from .models.videomae import (
VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST,
VideoMAEForPreTraining,
@@ -7005,7 +7720,12 @@ if TYPE_CHECKING:
XCLIPTextModel,
XCLIPVisionModel,
)
from .models.xglm import (
XGLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XGLMForCausalLM,
XGLMModel,
XGLMPreTrainedModel,
)
from .models.xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,
@@ -7142,7 +7862,12 @@ if TYPE_CHECKING:
tf_top_k_top_p_filtering,
)
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
shape_list,
)
# TensorFlow model imports
from .models.albert import (
@@ -7273,7 +7998,11 @@ if TYPE_CHECKING:
TFConvBertModel,
TFConvBertPreTrainedModel,
)
from .models.convnext import (
TFConvNextForImageClassification,
TFConvNextModel,
TFConvNextPreTrainedModel,
)
from .models.convnextv2 import (
TFConvNextV2ForImageClassification,
TFConvNextV2Model,
@@ -7452,7 +8181,11 @@ if TYPE_CHECKING:
TFLayoutLMv3Model,
TFLayoutLMv3PreTrainedModel,
)
from .models.led import (
TFLEDForConditionalGeneration,
TFLEDModel,
TFLEDPreTrainedModel,
)
from .models.longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
@@ -7472,8 +8205,16 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import (
TFMarianModel,
TFMarianMTModel,
TFMarianPreTrainedModel,
)
from .models.mbart import (
TFMBartForConditionalGeneration,
TFMBartModel,
TFMBartPreTrainedModel,
)
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
@@ -7505,7 +8246,11 @@ if TYPE_CHECKING:
TFMPNetModel,
TFMPNetPreTrainedModel,
)
from .models.mt5 import (
TFMT5EncoderModel,
TFMT5ForConditionalGeneration,
TFMT5Model,
)
from .models.openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
@@ -7516,8 +8261,17 @@ if TYPE_CHECKING:
TFOpenAIGPTPreTrainedModel,
)
from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import (
TFPegasusForConditionalGeneration,
TFPegasusModel,
TFPegasusPreTrainedModel,
)
from .models.rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
from .models.regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
@@ -7621,8 +8375,16 @@ if TYPE_CHECKING:
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import TFVisionTextDualEncoderModel
from .models.vit import (
TFViTForImageClassification,
TFViTModel,
TFViTPreTrainedModel,
)
from .models.vit_mae import (
TFViTMAEForPreTraining,
TFViTMAEModel,
TFViTMAEPreTrainedModel,
)
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
@@ -7677,7 +8439,12 @@ if TYPE_CHECKING:
)
# Optimization
from .optimization_tf import (
AdamWeightDecay,
GradientAccumulator,
WarmUp,
create_optimizer,
)
# Trainer
from .trainer_tf import TFTrainer
@@ -7694,7 +8461,11 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
from .utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects import *
else:
from .models.pop2piano import (
Pop2PianoFeatureExtractor,
Pop2PianoProcessor,
Pop2PianoTokenizer,
)
try:
if not is_flax_available():
@@ -7810,7 +8581,11 @@ if TYPE_CHECKING:
FlaxBlenderbotSmallModel,
FlaxBlenderbotSmallPreTrainedModel,
)
from .models.bloom import (
FlaxBloomForCausalLM,
FlaxBloomModel,
FlaxBloomPreTrainedModel,
)
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
@@ -7841,11 +8616,31 @@ if TYPE_CHECKING:
FlaxElectraPreTrainedModel,
)
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gpt2 import (
FlaxGPT2LMHeadModel,
FlaxGPT2Model,
FlaxGPT2PreTrainedModel,
)
from .models.gpt_neo import (
FlaxGPTNeoForCausalLM,
FlaxGPTNeoModel,
FlaxGPTNeoPreTrainedModel,
)
from .models.gptj import (
FlaxGPTJForCausalLM,
FlaxGPTJModel,
FlaxGPTJPreTrainedModel,
)
from .models.longt5 import (
FlaxLongT5ForConditionalGeneration,
FlaxLongT5Model,
FlaxLongT5PreTrainedModel,
)
from .models.marian import (
FlaxMarianModel,
FlaxMarianMTModel,
FlaxMarianPreTrainedModel,
)
from .models.mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
@@ -7853,11 +8648,27 @@ if TYPE_CHECKING:
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
from .models.mt5 import (
FlaxMT5EncoderModel,
FlaxMT5ForConditionalGeneration,
FlaxMT5Model,
)
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import (
FlaxPegasusForConditionalGeneration,
FlaxPegasusModel,
FlaxPegasusPreTrainedModel,
)
from .models.regnet import (
FlaxRegNetForImageClassification,
FlaxRegNetModel,
FlaxRegNetPreTrainedModel,
)
from .models.resnet import (
FlaxResNetForImageClassification,
FlaxResNetModel,
FlaxResNetPreTrainedModel,
)
from .models.roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
@@ -7888,10 +8699,19 @@ if TYPE_CHECKING:
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
from .models.t5 import (
FlaxT5EncoderModel,
FlaxT5ForConditionalGeneration,
FlaxT5Model,
FlaxT5PreTrainedModel,
)
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import (
FlaxViTForImageClassification,
FlaxViTModel,
FlaxViTPreTrainedModel,
)
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
@@ -7904,7 +8724,11 @@ if TYPE_CHECKING:
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.xglm import (
FlaxXGLMForCausalLM,
FlaxXGLMModel,
FlaxXGLMPreTrainedModel,
)
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaxXLMRobertaForCausalLM,
...
@@ -158,6 +158,7 @@ from . import (
opt,
owlv2,
owlvit,
patchtsmixer,
patchtst,
pegasus,
pegasus_x,
...
@@ -164,6 +164,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("opt", "OPTConfig"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pegasus", "PegasusConfig"),
("pegasus_x", "PegasusXConfig"),
@@ -380,6 +381,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("owlv2", "OWLV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("owlvit", "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("patchtsmixer", "PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("patchtst", "PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus_x", "PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -616,6 +618,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("opt", "OPT"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"),
("pegasus", "Pegasus"),
("pegasus_x", "PEGASUS-X"),
...
@@ -18,7 +18,12 @@ import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
_BaseAutoModelClass,
_LazyAutoMapping,
auto_class_update,
)
from .configuration_auto import CONFIG_MAPPING_NAMES
@@ -157,6 +162,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("opt", "OPTModel"),
("owlv2", "Owlv2Model"),
("owlvit", "OwlViTModel"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pegasus", "PegasusModel"),
("pegasus_x", "PegasusXModel"),
@@ -483,7 +489,10 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("convnextv2", "ConvNextV2ForImageClassification"),
("cvt", "CvtForImageClassification"),
("data2vec-vision", "Data2VecVisionForImageClassification"),
(
"deit",
("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
),
("dinat", "DinatForImageClassification"),
("dinov2", "Dinov2ForImageClassification"),
(
@@ -496,7 +505,10 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",
("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
),
("mobilenet_v1", "MobileNetV1ForImageClassification"),
("mobilenet_v2", "MobileNetV2ForImageClassification"),
("mobilevit", "MobileViTForImageClassification"),
@@ -1140,12 +1152,14 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
("patchtst", "PatchTSTForClassification"),
]
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
[
("patchtsmixer", "PatchTSMixerForRegression"),
("patchtst", "PatchTSTForRegression"),
]
)
@@ -1305,7 +1319,9 @@ class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="t5-base",
)
...
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_patchtsmixer": [
"PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PatchTSMixerConfig",
],
}
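# The modeling symbols are registered in _import_structure only when torch is available;
# the _LazyModule assignment at the bottom of this file defers the actual imports until
# an attribute is first accessed.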
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_patchtsmixer"] = [
"PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PatchTSMixerPreTrainedModel",
"PatchTSMixerModel",
"PatchTSMixerForPretraining",
"PatchTSMixerForPrediction",
"PatchTSMixerForTimeSeriesClassification",
"PatchTSMixerForRegression",
]
if TYPE_CHECKING:
from .configuration_patchtsmixer import (
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PatchTSMixerConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_patchtsmixer import (
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST,
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
PatchTSMixerForRegression,
PatchTSMixerForTimeSeriesClassification,
PatchTSMixerModel,
PatchTSMixerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PatchTSMixer model configuration"""
from typing import List, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"ibm/patchtsmixer-etth1-pretrain": "https://huggingface.co/ibm/patchtsmixer-etth1-pretrain/resolve/main/config.json",
}
class PatchTSMixerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PatchTSMixerModel`]. It is used to instantiate a
PatchTSMixer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the PatchTSMixer
[ibm/patchtsmixer-etth1-pretrain](https://huggingface.co/ibm/patchtsmixer-etth1-pretrain) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
context_length (`int`, *optional*, defaults to 32):
The context/history length for the input sequence.
patch_len (`int`, *optional*, defaults to 8):
The patch length for the input sequence.
num_input_channels (`int`, *optional*, defaults to 1):
Number of input variates. For univariate time series, set it to 1.
patch_stride (`int`, *optional*, defaults to 8):
Determines the overlap between two consecutive patches. Set it to `patch_len` (or greater) for non-overlapping patches.
num_parallel_samples (`int`, *optional*, defaults to 100):
The number of samples to generate in parallel for probabilistic forecast.
d_model (`int`, *optional*, defaults to 8):
Hidden dimension of the model. Recommended to set it as a multiple of `patch_len` (i.e. 2-5x of `patch_len`). A larger value indicates a more complex model.
expansion_factor (`int`, *optional*, defaults to 2):
Expansion factor to use inside the MLP. Recommended range is 2-5. A larger value indicates a more complex model.
num_layers (`int`, *optional*, defaults to 3):
Number of layers to use. Recommended range is 3-15. A larger value indicates a more complex model.
dropout (`float`, *optional*, defaults to 0.2):
The dropout probability for the `PatchTSMixer` backbone. Recommended range is 0.2-0.7.
mode (`str`, *optional*, defaults to `"common_channel"`):
Mixer mode. Determines how to process the channels. Allowed values: `"common_channel"`, `"mix_channel"`. In `"common_channel"` mode, channels are modelled independently with no explicit channel-mixing; channel mixing happens implicitly via shared weights across channels (preferred first approach). In `"mix_channel"` mode, explicit channel-mixing is performed in addition to patch and feature mixing (preferred when channel correlations are very important to model).
gated_attn (`bool`, *optional*, defaults to `True`):
Enable Gated Attention.
norm_mlp (`str`, *optional*, defaults to `"LayerNorm"`):
Normalization layer (BatchNorm or LayerNorm).
self_attn (`bool`, *optional*, defaults to `False`):
Enable tiny self-attention across patches. This can be enabled when the output of vanilla PatchTSMixer with gated attention is not satisfactory. Enabling this leads to explicit pairwise attention and modelling across patches.
self_attn_heads (`int`, *optional*, defaults to 1):
Number of self-attention heads. Works only when `self_attn` is set to `True`.
use_positional_encoding (`bool`, *optional*, defaults to `False`):
Enable the use of positional embedding for the tiny self-attention layers. Works only when `self_attn` is set to `True`.
positional_encoding_type (`str`, *optional*, defaults to `"sincos"`):
Positional encodings. Options `"random"` and `"sincos"` are supported. Works only when `use_positional_encoding` is set to `True`.
scaling (`string` or `bool`, *optional*, defaults to `"std"`):
Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the
scaler is set to "mean".
loss (`string`, *optional*, defaults to `"mse"`):
The loss function for the model corresponding to the `distribution_output` head. For parametric distributions it is the negative log likelihood (`"nll"`) and for point estimates it is the mean squared error (`"mse"`).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated normal weight initialization distribution.
post_init (`bool`, *optional*, defaults to `False`):
Whether to use the custom weight initialization from the `transformers` library, or the default initialization in PyTorch. Setting it to `False` performs PyTorch weight initialization.
norm_eps (`float`, *optional*, defaults to 1e-05):
A value added to the denominator for numerical stability of normalization.
mask_type (`str`, *optional*, defaults to `"random"`):
Type of masking to use for Masked Pretraining mode. Allowed values are "random", "forecast". In Random
masking, points are masked randomly. In Forecast masking, points are masked towards the end.
random_mask_ratio (`float`, *optional*, defaults to 0.5):
Masking ratio to use when `mask_type` is `"random"`. A higher value indicates more masking.
num_forecast_mask_patches (`int` or `list`, *optional*, defaults to `[2]`):
Number of patches to be masked at the end of each batch sample. If it is an integer, all the samples in the
batch will have the same number of masked patches. If it is a list, samples in the batch will be randomly
masked by numbers defined in the list. This argument is only used for forecast pretraining.
mask_value (`int`, *optional*, defaults to 0):
Mask value to use.
masked_loss (`bool`, *optional*, defaults to `True`):
Whether to compute the pretraining loss only at the masked portions, or on the entire output.
channel_consistent_masking (`bool`, *optional*, defaults to `True`):
When `True`, masking will be the same across all channels of a timeseries. Otherwise, masking positions will vary across channels.
unmasked_channel_indices (`list`, *optional*):
Channels that are not masked during pretraining.
head_dropout (`float`, *optional*, defaults to 0.2):
The dropout probability for the `PatchTSMixer` head.
distribution_output (`string`, *optional*, defaults to `"student_t"`):
The distribution emission head for the model when loss is "nll". Could be either "student_t", "normal" or
"negative_binomial".
prediction_length (`int`, *optional*, defaults to 16):
Number of time steps to forecast for a forecasting task. Also known as the Forecast Horizon.
prediction_channel_indices (`list`, *optional*):
List of channel indices to forecast. If None, forecast all channels. Target data is expected to have all
channels and we explicitly filter the channels in prediction and target before loss computation.
num_targets (`int`, *optional*, defaults to 3):
Number of targets (dimensionality of the regressed variable) for a regression task.
output_range (`list`, *optional*):
Output range to restrict for the regression task. Defaults to None.
head_aggregation (`str`, *optional*, defaults to `"max_pool"`):
Aggregation mode to enable for classification or regression task. Allowed values are `None`, "use_last",
"max_pool", "avg_pool".
Example:
```python
>>> from transformers import PatchTSMixerConfig, PatchTSMixerModel
>>> # Initializing a default PatchTSMixer configuration
>>> configuration = PatchTSMixerConfig()
>>> # Randomly initializing a model (with random weights) from the configuration
>>> model = PatchTSMixerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "patchtsmixer"
attribute_map = {
"hidden_size": "d_model",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
# Time series specific configuration
context_length: int = 32,
patch_len: int = 8,
num_input_channels: int = 1,
patch_stride: int = 8,
num_parallel_samples: int = 100,
# General model configuration
d_model: int = 8,
expansion_factor: int = 2,
num_layers: int = 3,
dropout: float = 0.2,
mode: str = "common_channel",
gated_attn: bool = True,
norm_mlp: str = "LayerNorm",
self_attn: bool = False,
self_attn_heads: int = 1,
use_positional_encoding: bool = False,
positional_encoding_type: str = "sincos",
scaling: Optional[Union[str, bool]] = "std",
loss: str = "mse",
init_std: float = 0.02,
post_init: bool = False,
norm_eps: float = 1e-5,
# Pretrain model configuration
mask_type: str = "random",
random_mask_ratio: float = 0.5,
num_forecast_mask_patches: Optional[Union[List[int], int]] = [2],
mask_value: int = 0,
masked_loss: bool = True,
channel_consistent_masking: bool = True,
unmasked_channel_indices: Optional[List[int]] = None,
# General head configuration
head_dropout: float = 0.2,
distribution_output: str = "student_t",
# Prediction head configuration
prediction_length: int = 16,
prediction_channel_indices: Optional[List[int]] = None,
# Classification/Regression configuration
num_targets: int = 3,
output_range: Optional[List[float]] = None,
head_aggregation: str = "max_pool",
**kwargs,
):
self.num_input_channels = num_input_channels
self.context_length = context_length
self.patch_length = patch_len
self.patch_stride = patch_stride
self.d_model = d_model
self.expansion_factor = expansion_factor
self.num_layers = num_layers
self.dropout = dropout
self.mode = mode
self.gated_attn = gated_attn
self.norm_mlp = norm_mlp
self.scaling = scaling
self.head_dropout = head_dropout
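# Number of patches obtained by sliding a window of length patch_len with stride patch_stride
# over the context window, e.g. the defaults context_length=32, patch_len=8, patch_stride=8
# give (32 - 8) // 8 + 1 = 4 patches.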
self.num_patches = (max(context_length, patch_len) - patch_len) // patch_stride + 1
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.num_forecast_mask_patches = num_forecast_mask_patches
self.mask_value = mask_value
self.channel_consistent_masking = channel_consistent_masking
self.masked_loss = masked_loss
self.patch_last = True
self.use_positional_encoding = use_positional_encoding
self.positional_encoding_type = positional_encoding_type
self.prediction_length = prediction_length
self.prediction_channel_indices = prediction_channel_indices
self.num_targets = num_targets
self.output_range = output_range
self.head_aggregation = head_aggregation
self.self_attn = self_attn
self.self_attn_heads = self_attn_heads
self.init_std = init_std
self.post_init = post_init
self.distribution_output = distribution_output
self.loss = loss
self.num_parallel_samples = num_parallel_samples
self.unmasked_channel_indices = unmasked_channel_indices
self.norm_eps = norm_eps
super().__init__(**kwargs)
# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PatchTSMixer model."""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_patchtsmixer import PatchTSMixerConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PatchTSMixerConfig"
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"ibm/patchtsmixer-etth1-pretrain",
# See all PatchTSMixer models at https://huggingface.co/models?filter=patchtsmixer
]
PATCHTSMIXER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`PatchTSMixerConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
mask_input (`bool`, *optional*, defaults to `False`):
If `True`, masking will be enabled; otherwise, masking is disabled.
"""
PATCHTSMIXER_INPUTS_DOCSTRING = r"""
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
Context values of the time series. For a pretraining task, this denotes the input time series to predict
the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
for classification or regression tasks, it denotes the appropriate context values of the time series.
For univariate time series, the `num_input_channels` dimension should be 1. For multivariate time series, it is
greater than 1.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class PatchTSMixerGatedAttention(nn.Module):
"""
Module that applies gated attention to input data.
Args:
in_size (`int`): The input size.
out_size (`int`): The output size.
"""
def __init__(self, in_size: int, out_size: int):
super().__init__()
self.attn_layer = nn.Linear(in_size, out_size)
self.attn_softmax = nn.Softmax(dim=-1)
def forward(self, inputs):
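# Compute gate values via a softmax over the last (feature) dimension and rescale the inputs elementwise.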
attn_weight = self.attn_softmax(self.attn_layer(inputs))
inputs = inputs * attn_weight
return inputs
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer
class PatchTSMixerBatchNorm(nn.Module):
"""
Compute batch normalization over the sequence length (time) dimension.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
def forward(self, inputs: torch.Tensor):
"""
Parameters:
inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
input for Batch norm calculation
Returns:
`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
"""
output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length)
output = self.batchnorm(output)
return output.transpose(1, 2)
class PatchTSMixerPositionalEncoding(nn.Module):
"""
Class for positional encoding
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
# positional encoding: [num_patches x d_model]
if config.use_positional_encoding:
self.position_enc = self._init_pe(config)
else:
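# Without positional encoding, fall back to a zero-initialized (still trainable) parameter,
# which leaves the patch embeddings unchanged at initialization.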
self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))
@staticmethod
def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
# Positional encoding
if config.positional_encoding_type == "random":
position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
elif config.positional_encoding_type == "sincos":
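# Standard sinusoidal encoding: sin on even indices, cos on odd indices, with frequencies
# decaying as 10000^(-2i/d_model); the table is then shifted to zero mean and divided by
# 10 * std, and frozen (requires_grad=False).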
position_enc = torch.zeros(config.num_patches, config.d_model)
position = torch.arange(0, config.num_patches).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
position_enc[:, 0::2] = torch.sin(position * div_term)
position_enc[:, 1::2] = torch.cos(position * div_term)
position_enc = position_enc - position_enc.mean()
position_enc = position_enc / (position_enc.std() * 10)
position_enc = nn.Parameter(position_enc, requires_grad=False)
else:
raise ValueError(
f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
)
return position_enc
def forward(self, patch_input: torch.Tensor):
# patch_input: [bs x num_channels x num_patches x d_model]
hidden_state = patch_input + self.position_enc
return hidden_state
class PatchTSMixerNormLayer(nn.Module):
"""Normalization block
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm_mlp = config.norm_mlp
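# Any norm_mlp value containing "batch" (case-insensitive) selects batch normalization over the
# patch (time) dimension; every other value falls back to LayerNorm over d_model.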
if "batch" in config.norm_mlp.lower():
self.norm = PatchTSMixerBatchNorm(config)
else:
self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Input to the normalization layer.
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
"""
if "batch" in self.norm_mlp.lower():
# reshape the data
inputs_reshaped = torch.reshape(
inputs,
(
inputs.shape[0] * inputs.shape[1],
inputs.shape[2],
inputs.shape[3],
),
) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model]
inputs_reshaped = self.norm(inputs_reshaped)
# put back data to the original shape
inputs = torch.reshape(inputs_reshaped, inputs.shape)
else:
inputs = self.norm(inputs)
return inputs
class PatchTSMixerMLP(nn.Module):
def __init__(self, in_features, out_features, config):
super().__init__()
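# The hidden layer is expansion_factor times wider than the input, mirroring the
# feed-forward expansion used in Transformer and MLP-Mixer blocks.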
num_hidden = in_features * config.expansion_factor
self.fc1 = nn.Linear(in_features, num_hidden)
self.dropout1 = nn.Dropout(config.dropout)
self.fc2 = nn.Linear(num_hidden, out_features)
self.dropout2 = nn.Dropout(config.dropout)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Input to the MLP layer.
Returns:
`torch.Tensor` of the same shape as `inputs`
"""
inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
inputs = self.fc2(inputs)
inputs = self.dropout2(inputs)
return inputs
class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
"""This module mixes the features in the channel dimension.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.num_input_channels,
out_features=config.num_input_channels,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(
in_size=config.num_input_channels, out_size=config.num_input_channels
)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Input to the channel feature mixer block.
Returns:
`torch.Tensor` of the same shape as `inputs`
"""
residual = inputs
inputs = self.norm(inputs)
inputs = inputs.permute(0, 3, 2, 1)
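# inputs: (batch_size, d_model, num_patches, num_channels); channels are moved to the last
# dimension so the MLP mixes across channels.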
if self.gated_attn:
inputs = self.gating_block(inputs)
inputs = self.mlp(inputs)
inputs = inputs.permute(0, 3, 2, 1)
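# permute back to (batch_size, num_channels, num_patches, d_model)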
out = inputs + residual
return out
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer
class PatchTSMixerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[PatchTSMixerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
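# A minimal sketch (dimensions made up for the demo) of the shape flow through
# `PatchTSMixerAttention`. `PatchMixerBlock` below runs it over the flattened
# (batch * channels) patch sequence, so the input here mimics that layout.
def _demo_attention_shapes():
    import torch

    attn = PatchTSMixerAttention(embed_dim=16, num_heads=4)
    # (batch_size * num_channels, num_patches, d_model)
    hidden = torch.randn(6, 12, 16)
    out, weights, _ = attn(hidden, output_attentions=True)
    assert out.shape == (6, 12, 16)
    assert weights.shape == (6, 4, 12, 12)  # (bsz, num_heads, tgt_len, src_len)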
class PatchMixerBlock(nn.Module):
"""This module mixes the patch dimension.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.self_attn = config.self_attn
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.num_patches,
out_features=config.num_patches,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
if config.self_attn:
self.self_attn_layer = PatchTSMixerAttention(
embed_dim=config.d_model,
num_heads=config.self_attn_heads,
dropout=config.dropout,
)
self.norm_attn = PatchTSMixerNormLayer(config)
def forward(self, hidden_state):
"""
Args:
hidden_state (`torch.Tensor`): Input tensor.
Returns:
`torch.Tensor`: Transformed tensor.
"""
residual = hidden_state
hidden_state = self.norm(hidden_state)
if self.self_attn:
batch_size, n_vars, num_patches, d_model = hidden_state.shape
hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)
# Transpose so that num_patches is the last dimension
hidden_state = hidden_state.transpose(2, 3)
hidden_state = self.mlp(hidden_state)
if self.gated_attn:
hidden_state = self.gating_block(hidden_state)
# Transpose back
hidden_state = hidden_state.transpose(2, 3)
if self.self_attn:
hidden_state = self.norm_attn(hidden_state + x_attn)
out = hidden_state + residual
return out
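# A minimal sketch of the transpose trick used above: `nn.Linear` only mixes the
# last dimension, so swapping `num_patches` into the last position lets the MLP
# mix information across patches while leaving channels and features untouched.
# Dimensions are made up for the demo.
def _demo_patch_mixing_transpose():
    import torch
    from torch import nn

    batch_size, n_vars, num_patches, d_model = 2, 3, 8, 16
    hidden = torch.randn(batch_size, n_vars, num_patches, d_model)
    mix_patches = nn.Linear(num_patches, num_patches)
    mixed = mix_patches(hidden.transpose(2, 3)).transpose(2, 3)
    assert mixed.shape == hidden.shape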
class FeatureMixerBlock(nn.Module):
"""This module mixes the hidden feature dimension.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.d_model,
out_features=config.d_model,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)
def forward(self, hidden: torch.Tensor):
"""
Args:
hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
Input tensor to the layer.
Returns:
`torch.Tensor`: Transformed tensor.
"""
residual = hidden
hidden = self.norm(hidden)
hidden = self.mlp(hidden)
if self.gated_attn:
hidden = self.gating_block(hidden)
out = hidden + residual
return out
class PatchTSMixerLayer(nn.Module):
"""
The `PatchTSMixer` layer that does all three kinds of mixing.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.patch_mixer = PatchMixerBlock(config=config)
self.feature_mixer = FeatureMixerBlock(config=config)
self.mode = config.mode
if config.mode == "mix_channel":
self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)
def forward(self, hidden: torch.Tensor):
"""
Args:
hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
Input tensor to the layer.
Returns:
`torch.Tensor`: Transformed tensor.
"""
if self.mode == "mix_channel":
hidden = self.channel_feature_mixer(hidden)
hidden = self.patch_mixer(hidden)
hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model)
return hidden
class PatchTSMixerBlock(nn.Module):
"""The main computing framework of the `PatchTSMixer` model.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
num_layers = config.num_layers
self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])
def forward(self, hidden_state, output_hidden_states: bool = False):
"""
Args:
hidden_state (`torch.Tensor`): The input tensor.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether to output the hidden states as well.
Returns:
`torch.Tensor`: The embedding. When `output_hidden_states` is `True`, a `list` of all intermediate
hidden states is returned as well.
"""
all_hidden_states = []
embedding = hidden_state
for mod in self.mixers:
embedding = mod(embedding)
if output_hidden_states:
all_hidden_states.append(embedding)
if output_hidden_states:
return embedding, all_hidden_states
else:
return embedding, None
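# A minimal usage sketch of `PatchTSMixerBlock`, assuming a small ad-hoc config;
# per-layer hidden states are collected when `output_hidden_states` is set.
def _demo_mixer_block():
    import torch

    config = PatchTSMixerConfig(
        context_length=32, patch_length=8, num_input_channels=3, d_model=16, num_layers=2
    )
    block = PatchTSMixerBlock(config=config)
    hidden = torch.randn(2, 3, config.num_patches, config.d_model)
    embedding, all_hidden = block(hidden, output_hidden_states=True)
    assert embedding.shape == hidden.shape
    assert len(all_hidden) == config.num_layers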
class PatchTSMixerForPredictionHead(nn.Module):
"""Prediction Head for Forecasting
Args:
config (`PatchTSMixerConfig`, *required*): Configuration.
"""
def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
super().__init__()
self.prediction_channel_indices = config.prediction_channel_indices
if self.prediction_channel_indices is not None:
self.prediction_channel_indices.sort()
self.dropout_layer = nn.Dropout(config.head_dropout)
if distribution_output is None:
self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
else:
self.base_forecast_block = distribution_output.get_parameter_projection(
config.num_patches * config.d_model
)
self.flatten = nn.Flatten(start_dim=-2)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size, num_patches, d_model)` in `flatten` mode
or `(batch_size, n_vars, num_patches, d_model)` in `common_channel`/`mix_channel` mode):
Input hidden features.
Returns:
`torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.
"""
hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model]
hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model]
forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length]
if isinstance(forecast, tuple):
forecast = tuple(z.transpose(-1, -2) for z in forecast)
else:
forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars]
if self.prediction_channel_indices is not None:
if isinstance(forecast, tuple):
forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
else:
forecast = forecast[..., self.prediction_channel_indices] # [batch_size x prediction_length x n_vars]
return forecast
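# A minimal sketch of the shape arithmetic in the prediction head above: the
# last two axes (num_patches * d_model) are flattened per channel, projected to
# `prediction_length`, and channels are moved last. Dimensions are made up.
def _demo_prediction_head_shapes():
    import torch
    from torch import nn

    batch_size, n_vars, num_patches, d_model, prediction_length = 2, 3, 4, 16, 24
    hidden = torch.randn(batch_size, n_vars, num_patches, d_model)
    flatten = nn.Flatten(start_dim=-2)
    project = nn.Linear(num_patches * d_model, prediction_length)
    forecast = project(flatten(hidden)).transpose(-1, -2)
    assert forecast.shape == (batch_size, prediction_length, n_vars)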
class PatchTSMixerLinearHead(nn.Module):
"""Linear head for Classification and Regression.
Args:
config (`PatchTSMixerConfig`, *required*):
"""
def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
super().__init__()
self.head_aggregation = config.head_aggregation
self.output_range = config.output_range
if config.head_aggregation is None:
mul_factor = config.num_patches
else:
mul_factor = 1
self.distribution_output = distribution_output
if distribution_output is None:
self.projection = nn.Linear(
config.d_model * config.num_input_channels * mul_factor,
config.num_targets,
)
else:
self.projection = distribution_output.get_parameter_projection(
config.d_model * config.num_input_channels * mul_factor
)
if config.head_aggregation is None:
self.flatten = nn.Flatten(start_dim=-3)
else:
self.flatten = nn.Flatten(start_dim=-2)
self.dropout = nn.Dropout(config.head_dropout)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size, num_patches, d_model)` in `flatten` mode
or `(batch_size, n_vars, num_patches, d_model)` in `common_channel`/`mix_channel` mode):
Input hidden features.
Returns:
`torch.Tensor` of shape `(batch_size, num_targets)`.
"""
# batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch
hidden_features = hidden_features.transpose(-1, -2)
if self.head_aggregation == "use_last":
# batch_size x d_model (flatten) or # batch_size x n_vars x d_model (common_channel)
hidden_features = hidden_features[..., -1]
elif self.head_aggregation == "max_pool":
# batch_size x n_vars x d_model or batch_size x d_model
hidden_features = hidden_features.max(dim=-1).values
elif self.head_aggregation == "avg_pool":
# batch_size x n_vars x d_model or batch_size x d_model
hidden_features = hidden_features.mean(dim=-1)
if self.flatten:
hidden_features = self.flatten(hidden_features)
hidden_features = self.dropout(hidden_features)
hidden_features = self.projection(hidden_features) # batch_size x num_targets
if (self.distribution_output is None) and (self.output_range is not None):
hidden_features = (
torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
)
return hidden_features
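# A minimal sketch of the three head aggregation modes above: after moving the
# patch axis last, each mode reduces it to a single vector per channel before
# the final projection. Dimensions are made up for the demo.
def _demo_head_aggregation():
    import torch

    batch_size, n_vars, num_patches, d_model = 2, 3, 4, 16
    hidden = torch.randn(batch_size, n_vars, num_patches, d_model).transpose(-1, -2)
    last = hidden[..., -1]            # "use_last"
    maxp = hidden.max(dim=-1).values  # "max_pool"
    avgp = hidden.mean(dim=-1)        # "avg_pool"
    assert last.shape == maxp.shape == avgp.shape == (batch_size, n_vars, d_model)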
class PatchTSMixerPreTrainedModel(PreTrainedModel):
# Weight initialization
config_class = PatchTSMixerConfig
base_model_prefix = "model"
main_input_name = "past_values"
supports_gradient_checkpointing = False
def _init_weights(self, module):
"""Initialize weights"""
if isinstance(module, PatchTSMixerPositionalEncoding):
# initialize positional encoding
if self.config.positional_encoding_type == "random":
nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, PatchTSMixerBatchNorm):
module.batchnorm.bias.data.zero_()
module.batchnorm.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.bias is not None:
module.bias.data.zero_()
class PatchTSMixerPretrainHead(nn.Module):
"""Pretraining head.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dropout_layer = nn.Dropout(config.head_dropout)
self.base_pt_block = nn.Linear(config.d_model, config.patch_length)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size, num_patches, d_model)` in `flatten` mode
or `(batch_size, n_vars, num_patches, d_model)` in `common_channel`/`mix_channel` mode):
Input hidden features.
Returns:
`torch.Tensor` of shape `(batch_size, n_vars, num_patches, patch_length)`.
"""
hidden_features = self.dropout_layer(hidden_features)
forecast = self.base_pt_block(hidden_features) # [batch_size x n_vars x num_patch x patch_length]
return forecast
# Copied from transformers.models.patchtst.modeling_patchtst.random_masking
def random_masking(
inputs: torch.Tensor,
mask_ratio: float,
unmasked_channel_indices: list = None,
channel_consistent_masking: bool = False,
mask_value: int = 0,
):
"""random_masking: Mask the input considering the control variables.
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
The input tensor to mask.
mask_ratio (`float`):
Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
unmasked_channel_indices (list, *optional*):
Indices of channels that will not be masked.
channel_consistent_masking (bool, *optional*, defaults to `False`):
When true, masking will be the same across all channels of a timeseries. Otherwise, masking positions will vary
across channels.
mask_value (int, *optional*, defaults to 0):
Define the value of masked patches for pretraining.
Returns:
`tuple(torch.Tensor)`: `inputs_mask`, the masked input of the same shape as `inputs`, and a mask tensor of
shape `[bs x c x n]`.
"""
if mask_ratio < 0 or mask_ratio >= 1:
raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
batch_size, num_channels, sequence_length, num_features = inputs.shape
device = inputs.device
len_keep = int(sequence_length * (1 - mask_ratio))
if channel_consistent_masking:
noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L
noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time
else:
# noise in [0, 1], bs x num_channels x L
noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
# mask: [bs x num_channels x num_patch]
mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
mask[:, :, :len_keep] = 0
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L]
mask = torch.gather(mask, dim=-1, index=ids_restore)
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length]
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
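# A minimal sketch of `random_masking` behaviour on made-up dimensions: roughly
# `mask_ratio` of the patches per channel end up masked, and channels listed in
# `unmasked_channel_indices` are left intact.
def _demo_random_masking():
    import torch

    inputs = torch.randn(2, 3, 10, 5)  # (bs, num_channels, num_patches, patch_length)
    masked, mask = random_masking(inputs, mask_ratio=0.4, unmasked_channel_indices=[0])
    assert masked.shape == inputs.shape and mask.shape == (2, 3, 10)
    assert mask[:, 0].sum() == 0       # channel 0 is never masked
    assert int(mask[0, 1].sum()) == 4  # 40% of the 10 patches are masked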
# Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking
def forecast_masking(
inputs: torch.Tensor,
num_forecast_mask_patches: Union[list, int],
unmasked_channel_indices: list = None,
mask_value: int = 0,
):
"""Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_len)`
num_forecast_mask_patches (`list` or `int`):
Number of patches to be masked at the end of each batch sample, e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
Indices of channels that are not masked.
mask_value (`int`, *optional*, defaults to 0):
Values in the masked patches will be filled by `mask_value`.
Returns:
`tuple(torch.Tensor)`: `inputs_mask`, the masked input of the same shape as `inputs`, and a mask tensor of
shape `(bs, num_channels, num_patch)`.
"""
if isinstance(num_forecast_mask_patches, int):
num_forecast_mask_patches = [num_forecast_mask_patches]
forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
batch_size, num_channels, sequence_length, num_features = inputs.shape
mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
t_list = []
total_length = 0
total_ratio = sum(forecast_mask_ratios)
for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
if patch_length <= 0 or patch_length >= sequence_length:
raise ValueError(
f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
)
temp_len = int(batch_size * ratio / total_ratio)
t_list.append([patch_length, ratio, temp_len])
total_length += temp_len
t_list = sorted(t_list, key=lambda x: x[2])
if total_length < batch_size:
t_list[0][2] = t_list[0][2] + (batch_size - total_length)
elif total_length > batch_size:
t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
batch1 = 0
for patch_len, _, temp_len in t_list:
batch2 = batch1 + temp_len
mask[batch1:batch2, :, -patch_len:] = 1
batch1 = batch2
perm = torch.randperm(mask.shape[0])
mask = mask[perm]
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len]
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
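# A minimal sketch of `forecast_masking` on made-up dimensions: the *last*
# patches of each sample are masked, mimicking a forecast horizon during
# pretraining.
def _demo_forecast_masking():
    import torch

    inputs = torch.randn(2, 3, 10, 5)  # (bs, num_channels, num_patches, patch_length)
    masked, mask = forecast_masking(inputs, num_forecast_mask_patches=3)
    assert torch.all(mask[:, :, -3:] == 1) and torch.all(mask[:, :, :-3] == 0)
    assert torch.all(masked[:, :, -3:, :] == 0)  # masked patches filled with `mask_value`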
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer
class PatchTSMixerPatchify(nn.Module):
"""
A class to patchify the time series sequence into different patches
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.sequence_length = config.context_length
self.patch_length = config.patch_length
self.patch_stride = config.patch_stride
if self.sequence_length <= self.patch_length:
raise ValueError(
f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
)
# get the number of patches
self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
self.sequence_start = self.sequence_length - new_sequence_length
def forward(self, past_values: torch.Tensor):
"""
Parameters:
past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
Input for patchification
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
"""
sequence_length = past_values.shape[-2]
if sequence_length != self.sequence_length:
raise ValueError(
f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
)
# output: [bs x new_sequence_length x num_channels]
output = past_values[:, self.sequence_start :, :]
# output: [bs x num_patches x num_input_channels x patch_length]
output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
# output: [bs x num_input_channels x num_patches x patch_length]
output = output.transpose(-2, -3).contiguous()
return output
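# A minimal sketch of the patchify arithmetic above: with context_length=32,
# patch_length=8 and patch_stride=8, the series is cut into
# (32 - 8) // 8 + 1 = 4 non-overlapping patches per channel.
def _demo_patchify():
    import torch

    config = PatchTSMixerConfig(context_length=32, patch_length=8, patch_stride=8, num_input_channels=3)
    patchify = PatchTSMixerPatchify(config)
    past_values = torch.randn(2, 32, 3)  # (batch_size, sequence_length, num_channels)
    patches = patchify(past_values)
    assert patches.shape == (2, 3, 4, 8)  # (bs, num_channels, num_patches, patch_length)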
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer
class PatchTSMixerMasking(nn.Module):
"""
Class to perform random or forecast masking.
Parameters:
config (`PatchTSMixerConfig`): model config
Returns:
x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.random_mask_ratio = config.random_mask_ratio
self.channel_consistent_masking = config.channel_consistent_masking
self.mask_type = config.mask_type
self.num_forecast_mask_patches = config.num_forecast_mask_patches
self.unmasked_channel_indices = config.unmasked_channel_indices
self.mask_value = config.mask_value
if self.unmasked_channel_indices is not None:
self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
def forward(self, patch_input: torch.Tensor):
"""
Parameters:
patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
Patch input
Return:
masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
if self.mask_type == "random":
masked_input, mask = random_masking(
inputs=patch_input,
mask_ratio=self.random_mask_ratio,
unmasked_channel_indices=self.unmasked_channel_indices,
channel_consistent_masking=self.channel_consistent_masking,
mask_value=self.mask_value,
)
elif self.mask_type == "forecast":
masked_input, mask = forecast_masking(
inputs=patch_input,
num_forecast_mask_patches=self.num_forecast_mask_patches,
unmasked_channel_indices=self.unmasked_channel_indices,
mask_value=self.mask_value,
)
else:
raise ValueError(f"Invalid mask type {self.mask_type}.")
# mask: [bs x num_input_channels x num_patch]
mask = mask.bool()
return masked_input, mask
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer
class PatchTSMixerStdScaler(nn.Module):
"""
Standardizes features by computing the mean and standard deviation along the first dimension, and then
normalizes the data by subtracting the mean and dividing by the standard deviation.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Input data for the scale calculation.
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Boolean indicator of observed values; the scale is computed only over observed positions.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
denominator = denominator.clamp_min(1.0)
loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
scale = torch.sqrt(variance + self.minimum_scale)
return (data - loc) / scale, loc, scale
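# A minimal sketch of the standard scaler above: with a fully observed input,
# every channel comes out with (approximately) zero mean and unit variance
# along the time dimension, and `loc`/`scale` keep one entry per channel.
def _demo_std_scaler():
    import torch

    config = PatchTSMixerConfig(context_length=32, patch_length=8, num_input_channels=3)
    scaler = PatchTSMixerStdScaler(config)
    data = torch.randn(2, 32, 3) * 5 + 10
    scaled, loc, scale = scaler(data, observed_indicator=torch.ones_like(data))
    assert scaled.mean(dim=1).abs().max() < 1e-4
    assert loc.shape == scale.shape == (2, 1, 3)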
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer
class PatchTSMixerMeanScaler(nn.Module):
"""
Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
accordingly.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Input data for the scale calculation.
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Boolean indicator of observed values; the scale is computed only over observed positions.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
num_observed = observed_indicator.sum(self.dim, keepdim=True)
scale = ts_sum / torch.clamp(num_observed, min=1)
# If `default_scale` is provided, we use it, otherwise we use the scale
# of the batch.
if self.default_scale is None:
batch_sum = ts_sum.sum(dim=0)
batch_observations = torch.clamp(num_observed.sum(0), min=1)
default_scale = torch.squeeze(batch_sum / batch_observations)
else:
default_scale = self.default_scale * torch.ones_like(scale)
# apply default scale where there are no observations
scale = torch.where(num_observed > 0, scale, default_scale)
# ensure the scale is at least `self.minimum_scale`
scale = torch.clamp(scale, min=self.minimum_scale)
scaled_data = data / scale
if not self.keepdim:
scale = scale.squeeze(dim=self.dim)
return scaled_data, torch.zeros_like(scale), scale
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer
class PatchTSMixerNOPScaler(nn.Module):
"""
Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Input data, returned unchanged since no scaling is applied.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
return data, loc, scale
@dataclass
class PatchTSMixerEncoderOutput(ModelOutput):
"""
Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Hidden-state at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
"""
Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.use_return_dict = config.use_return_dict
self.patcher = nn.Linear(config.patch_length, config.d_model)
if config.use_positional_encoding:
self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
else:
self.positional_encoder = None
self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@replace_return_docstrings(output_type=PatchTSMixerEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PatchTSMixerEncoderOutput]:
r"""
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
Context values of the time series. For a pretraining task, this denotes the input time series to
predict the masked portion. For a forecasting task, this denotes the history/past time series values.
Similarly, for classification or regression tasks, it denotes the appropriate context values of the
time series.
For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
it is greater than 1.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
`torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
# flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
patches = self.patcher(past_values)
# add positional encoder
if self.positional_encoder is not None:
patches = self.positional_encoder(patches)
last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)
if not return_dict:
return tuple(
v
for v in [
last_hidden_state,
hidden_states,
]
)
return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
@dataclass
class PatchTSMixerModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Hidden-state at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
Patched input data to the model.
mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*):
Bool Tensor indicating True in masked patches and False otherwise.
loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
Mean of the context window per channel. Used to denormalize the outputs outside the model, if
reversible instance normalization (RevIN) is enabled.
scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
Standard deviation of the context window per channel. Used to denormalize the outputs outside the model,
if reversible instance normalization (RevIN) is enabled.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
patch_input: torch.FloatTensor = None
mask: Optional[torch.FloatTensor] = None
loc: Optional[torch.FloatTensor] = None
scale: Optional[torch.FloatTensor] = None
@add_start_docstrings(
"The PatchTSMixer Model for time-series forecasting.",
PATCHTSMIXER_START_DOCSTRING,
)
class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
super().__init__(config)
self.use_return_dict = config.use_return_dict
self.encoder = PatchTSMixerEncoder(config)
self.patching = PatchTSMixerPatchify(config)
if mask_input is True:
self.masking = PatchTSMixerMasking(config)
else:
self.masking = None
if config.scaling == "mean":
self.scaler = PatchTSMixerMeanScaler(config)
elif config.scaling == "std" or config.scaling is True:
self.scaler = PatchTSMixerStdScaler(config)
else:
self.scaler = PatchTSMixerNOPScaler(config)
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> PatchTSMixerModelOutput:
r"""
observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
Returns:
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
mask = None
if observed_mask is None:
observed_mask = torch.ones_like(past_values)
scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)
patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length]
enc_input = patched_x
if self.masking is not None:
enc_input, mask = self.masking(patched_x)
# enc_input: [batch_size x num_input_channels x num_patch x patch_length]
# mask: [batch_size x num_input_channels x num_patch]
encoder_output = self.encoder(
enc_input,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(encoder_output, tuple):
encoder_output = PatchTSMixerEncoderOutput(*encoder_output)
if not return_dict:
return tuple(
v
for v in [
encoder_output.last_hidden_state,
encoder_output.hidden_states,
patched_x,
mask,
loc,
scale,
]
)
return PatchTSMixerModelOutput(
last_hidden_state=encoder_output.last_hidden_state,
hidden_states=encoder_output.hidden_states,
patch_input=patched_x,
mask=mask,
loc=loc,
scale=scale,
)
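# A minimal end-to-end usage sketch of the backbone with a randomly initialized
# model, assuming the public classes exported by this PR; values are dummy data
# just to show the shapes involved.
def _demo_model_usage():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
        d_model=16,
        num_layers=2,
    )
    model = PatchTSMixerModel(config)
    past_values = torch.randn(4, config.context_length, config.num_input_channels)
    output = model(past_values)
    # (batch_size, num_channels, num_patches, d_model)
    assert output.last_hidden_state.shape == (4, 3, config.num_patches, config.d_model)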
@dataclass
class PatchTSMixerForPreTrainingOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForPretraining`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
Prediction output from the pretrain head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
loss (*optional*, returned when `return_loss` is `True`, `torch.FloatTensor` of shape `()`):
Total loss.
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for mask pretraining.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config, mask_input=True)
self.head = PatchTSMixerPretrainHead(config=config)
self.masked_loss = config.masked_loss
self.use_return_dict = config.use_return_dict
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForPreTrainingOutput:
r"""
observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
if self.masked_loss is True:
loss = torch.nn.MSELoss(reduction="none")
else:
loss = torch.nn.MSELoss(reduction="mean")
# past_values: tensor [batch_size x context_length x num_input_channels]
model_output = self.model(
past_values,
observed_mask=observed_mask,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # x.last_hidden_state: [batch_size x nvars x num_patch x d_model]
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
x_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x nvars x num_patch x patch_length]
if return_loss is True:
loss_val = loss(x_hat, model_output.patch_input)
else:
loss_val = None
# calculate masked_loss
if self.masked_loss is True and loss_val is not None:
loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
if not return_dict:
return tuple(
v
for v in [
loss_val,
x_hat,
model_output.last_hidden_state,
model_output.hidden_states,
]
)
return PatchTSMixerForPreTrainingOutput(
loss=loss_val,
prediction_outputs=x_hat, # tensor [batch_size x nvars x num_patch x patch_length]
last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
)
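# A minimal usage sketch of masked pretraining with a randomly initialized
# model, assuming the public classes exported by this PR: the head reconstructs
# the masked patches and the loss is averaged over masked positions only.
def _demo_pretraining():
    import torch

    config = PatchTSMixerConfig(context_length=32, patch_length=8, num_input_channels=3, mask_type="random")
    model = PatchTSMixerForPretraining(config)
    past_values = torch.randn(4, config.context_length, config.num_input_channels)
    output = model(past_values)
    assert output.loss is not None
    # (batch_size, num_channels, num_patches, patch_length)
    assert output.prediction_outputs.shape == (4, 3, config.num_patches, config.patch_length)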
@dataclass
class PatchTSMixerForPredictionOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForPrediction`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
Prediction output from the forecast head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `future_values` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
Input mean
scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
Input std dev
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
loc: torch.FloatTensor = None
scale: torch.FloatTensor = None
@dataclass
class SamplePatchTSMixerPredictionOutput(ModelOutput):
"""
Base class for a time series model's prediction outputs that contains the sampled values from the chosen
distribution.
Args:
sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
Sampled values from the chosen distribution.
"""
sequences: torch.FloatTensor = None
@dataclass
class SamplePatchTSMixerRegressionOutput(ModelOutput):
"""
Base class for a time series model's prediction outputs that contains the sampled values from the chosen
distribution.
Args:
sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
Sampled values from the chosen distribution.
"""
sequences: torch.FloatTensor = None
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
"""
Computes the negative log likelihood loss from input distribution with respect to target.
"""
return -input.log_prob(target)
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
"""
Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
Args:
input_tensor (`torch.FloatTensor`):
Input tensor, of which the average must be computed.
weights (`torch.FloatTensor`, *optional*):
Weights tensor, of the same shape as `input_tensor`.
dim (`int`, *optional*):
The dim along which to average `input_tensor`.
Returns:
`torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
"""
if weights is not None:
weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
else:
return input_tensor.mean(dim=dim)
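# A minimal sketch of `weighted_average`: zero-weight entries are excluded
# instead of propagating NaNs, which is what makes it safe for masked losses.
def _demo_weighted_average():
    import torch

    values = torch.tensor([1.0, float("nan"), 3.0])
    weights = torch.tensor([1.0, 0.0, 1.0])
    result = weighted_average(values, weights=weights)
    assert torch.isclose(result, torch.tensor(2.0))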
class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for forecasting application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.loss = config.loss
self.use_return_dict = config.use_return_dict
self.prediction_channel_indices = config.prediction_channel_indices
self.num_parallel_samples = config.num_parallel_samples
if config.loss == "mse":
self.distribution_output = None
else:
dim = config.prediction_length
distribution_output_map = {
"student_t": StudentTOutput,
"normal": NormalOutput,
"negative_binomial": NegativeBinomialOutput,
}
output_class = distribution_output_map.get(config.distribution_output, None)
if output_class is not None:
self.distribution_output = output_class(dim=dim)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.model = PatchTSMixerModel(config)
self.head = PatchTSMixerForPredictionHead(
config=config,
distribution_output=self.distribution_output,
)
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForPredictionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForPredictionOutput:
r"""
observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series that serve as labels for the model. The `future_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
required for a pretraining task.
For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if you want
to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
pass the target data with all channels, as channel filtering for both prediction and target will be
applied before the loss computation.
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
"""
if self.loss == "mse":
loss = nn.MSELoss(reduction="mean")
elif self.loss == "nll":
loss = nll
else:
raise ValueError("Invalid loss function: Allowed values: mse and nll")
return_dict = return_dict if return_dict is not None else self.use_return_dict
# past_values: tensor [batch_size x context_length x num_input_channels]
model_output = self.model(
past_values,
observed_mask=observed_mask,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # model_output: [batch_size x nvars x num_patch x d_model]
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
# tensor [batch_size x prediction_length x num_input_channels]
y_hat = self.head(model_output.last_hidden_state)
loss_val = None
if self.prediction_channel_indices is not None:
if self.distribution_output:
distribution = self.distribution_output.distribution(
y_hat,
loc=model_output.loc[..., self.prediction_channel_indices],
scale=model_output.scale[..., self.prediction_channel_indices],
)
if future_values is not None and return_loss is True:
loss_val = loss(
distribution,
future_values[..., self.prediction_channel_indices],
)
# take average of the loss
loss_val = weighted_average(loss_val)
else:
y_hat = (
y_hat * model_output.scale[..., self.prediction_channel_indices]
+ model_output.loc[..., self.prediction_channel_indices]
)
if future_values is not None and return_loss is True:
loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices])
else:
if self.distribution_output:
distribution = self.distribution_output.distribution(
y_hat, loc=model_output.loc, scale=model_output.scale
)
if future_values is not None and return_loss is True:
loss_val = loss(distribution, future_values)
loss_val = weighted_average(loss_val)
else:
y_hat = y_hat * model_output.scale + model_output.loc
if future_values is not None and return_loss is True:
loss_val = loss(y_hat, future_values)
if self.prediction_channel_indices is not None:
loc = model_output.loc[..., self.prediction_channel_indices]
scale = model_output.scale[..., self.prediction_channel_indices]
else:
loc = model_output.loc
scale = model_output.scale
if not return_dict:
return tuple(
v
for v in [
loss_val,
y_hat,
model_output.last_hidden_state,
model_output.hidden_states,
loc,
scale,
]
)
return PatchTSMixerForPredictionOutput(
loss=loss_val,
prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels]
last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
loc=loc,
scale=scale,
)
def generate(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
) -> SamplePatchTSMixerPredictionOutput:
"""
Generate sequences of sample predictions from a model with a probability distribution head.
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serve as context for predicting the future.
observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
Return:
[`SamplePatchTSMixerPredictionOutput`] where the output `sequences` tensor has shape `(batch_size,
number of samples, prediction_length, num_input_channels)`.
"""
# get number of samples
num_parallel_samples = self.num_parallel_samples
# get model output
outputs = self(
past_values=past_values,
future_values=None,
observed_mask=observed_mask,
output_hidden_states=False,
)
# get distribution
distribution = self.distribution_output.distribution(
outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
)
# get samples: list of [batch_size x prediction_length x num_channels]
samples = [distribution.sample() for _ in range(num_parallel_samples)]
# stack tensors
samples = torch.stack(samples, dim=1) # [batch_size x num_samples x prediction_length x num_channels]
return SamplePatchTSMixerPredictionOutput(sequences=samples)
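# A minimal usage sketch of probabilistic forecasting with a randomly
# initialized model, assuming the public classes exported by this PR. With an
# "nll" loss the head parameterizes a distribution, and `generate` draws
# `num_parallel_samples` trajectories; a median over samples gives a point forecast.
def _demo_probabilistic_forecast():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        prediction_length=8,
        patch_length=8,
        num_input_channels=3,
        loss="nll",
        distribution_output="student_t",
        num_parallel_samples=20,
    )
    model = PatchTSMixerForPrediction(config)
    past_values = torch.randn(4, config.context_length, config.num_input_channels)
    samples = model.generate(past_values=past_values).sequences
    assert samples.shape == (4, 20, config.prediction_length, 3)
    point_forecast = samples.median(dim=1).values  # (batch_size, prediction_length, num_channels)
    assert point_forecast.shape == (4, config.prediction_length, 3)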
@dataclass
class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForTimeSeriesClassification`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
Prediction output from the classification head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for classification application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config)
self.head = PatchTSMixerLinearHead(
config=config,
)
self.use_return_dict = config.use_return_dict
if config.scaling in ["std", "mean", True]:
self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
else:
self.inject_scale = None
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=PatchTSMixerForTimeSeriesClassificationOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForTimeSeriesClassificationOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series that serve as labels for the model. The `future_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
required for a pretraining task.
For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if you want
to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
pass the target data with all channels, as channel filtering for both prediction and target will be
applied before the loss computation.
For a classification task, it has a shape of `(batch_size,)`.
For a regression task, it has a shape of `(batch_size, num_targets)`.
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
"""
loss = torch.nn.CrossEntropyLoss()
return_dict = return_dict if return_dict is not None else self.use_return_dict
model_output = self.model(
past_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # x: [batch_size x nvars x num_patch x d_model]
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
if self.inject_scale is not None:
model_output.last_hidden_state = self.inject_scale(
model_output.last_hidden_state,
loc=model_output.loc,
scale=model_output.scale,
) # x: [batch_size x nvars x num_patch x d_model]
y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels]
if future_values is not None and return_loss is True:
loss_val = loss(y_hat, future_values)
else:
loss_val = None
if not return_dict:
return tuple(
v
for v in [
loss_val,
y_hat,
model_output.last_hidden_state,
model_output.hidden_states,
]
)
return PatchTSMixerForTimeSeriesClassificationOutput(
loss=loss_val,
prediction_outputs=y_hat, # tensor [batch_size x n_labels]
last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
)
@dataclass
class PatchTSMixerForRegressionOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForRegression`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Prediction output from the regression head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class InjectScalerStatistics4D(nn.Module):
def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
super().__init__()
self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
self.map_scale_expansion = nn.Linear(2, 2 * expansion)
self.map_scale_compression = nn.Linear(2 * expansion, 2)
self.num_patches = num_patches
def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
Returns:
`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
"""
mean = loc.transpose(-1, -2) # [batch_size x n_channels x 1 ]
mean = mean.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
mean = mean.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
stdev = scale.transpose(-1, -2) # [batch_size x n_channels x 1 ]
stdev = stdev.unsqueeze(-2) # [batch_size x n_channels x 1 x 1]
stdev = stdev.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1]
concat_stats = torch.cat([mean, stdev], dim=-1) # [batch_size x n_channels x num_patch x 2]
concat_stats = self.map_scale_expansion(concat_stats) # [batch_size x n_channels x num_patch x (2*expansion)]
concat_stats = self.map_scale_compression(concat_stats) # [batch_size x n_channels x num_patch x 2]
inputs = torch.cat([inputs, concat_stats], dim=-1) # [batch_size x channels x num_patch x d_model+2]
inputs = self.inverse_trans_expansion(inputs) # [batch_size x channels x num_patch x (expansion*d_model)]
inputs = self.inverse_trans_compression(inputs) # [batch_size x channels x num_patch x d_model]
return inputs
class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for regression application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config)
self.loss = config.loss
self.distribution_output = config.distribution_output
self.use_return_dict = config.use_return_dict
self.num_parallel_samples = config.num_parallel_samples
if config.loss == "mse":
self.distribution_output = None
else:
distribution_output_map = {
"student_t": StudentTOutput,
"normal": NormalOutput,
"negative_binomial": NegativeBinomialOutput,
}
output_class = distribution_output_map.get(config.distribution_output)
if output_class is not None:
self.distribution_output = output_class(dim=config.num_targets)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
if config.scaling in ["std", "mean", True]:
self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
else:
self.inject_scale = None
self.head = PatchTSMixerLinearHead(
config=config,
distribution_output=self.distribution_output,
)
# Initialize weights and apply final processing
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForRegressionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForRegressionOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series that serve as labels for the model. The `future_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
required for a pretraining task.
For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if you want
to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
pass the target data with all channels, as channel filtering for both prediction and target will be
applied before the loss computation.
For a classification task, it has a shape of `(batch_size,)`.
For a regression task, it has a shape of `(batch_size, num_targets)`.
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
"""
if self.loss == "mse":
loss = nn.MSELoss(reduction="mean")
elif self.loss == "nll":
loss = nll
else:
raise ValueError("Invalid loss function: Allowed values: mse and nll")
return_dict = return_dict if return_dict is not None else self.use_return_dict
model_output = self.model(
past_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # model_output: [batch_size x nvars x num_patch x d_model]
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
if self.inject_scale is not None:
model_output.last_hidden_state = self.inject_scale(
model_output.last_hidden_state,
loc=model_output.loc,
scale=model_output.scale,
) # x: [batch_size x nvars x num_patch x d_model]
y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets]
if future_values is not None and return_loss is True:
if self.distribution_output:
                if isinstance(self.distribution_output, NegativeBinomialOutput) and torch.any(future_values < 0):
                    raise ValueError("future_values cannot be negative for negative_binomial distribution.")
distribution = self.distribution_output.distribution(y_hat)
loss_val = loss(distribution, future_values)
# take average of the loss
loss_val = weighted_average(loss_val)
else:
loss_val = loss(y_hat, future_values)
else:
loss_val = None
if not return_dict:
return tuple(
v
for v in [
loss_val,
y_hat,
model_output.last_hidden_state,
model_output.hidden_states,
]
)
return PatchTSMixerForRegressionOutput(
loss=loss_val,
prediction_outputs=y_hat, # tensor [batch_size x num_targets]
last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
)
def generate(
self,
past_values: torch.Tensor,
) -> SamplePatchTSMixerRegressionOutput:
"""
Generate sequences of sample predictions from a model with a probability distribution head.
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serve as context for predicting the future.
Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor has shape `(batch_size,
            num_parallel_samples, num_targets)`.
"""
# get number of samples
num_parallel_samples = self.num_parallel_samples
# get model output
outputs = self(
past_values=past_values,
future_values=None,
output_hidden_states=False,
)
# get distribution
distribution = self.distribution_output.distribution(outputs.prediction_outputs)
# get samples
samples = [
distribution.sample() for _ in range(num_parallel_samples)
] # samples: list of [batch_size x num_targets]
# stack tensors
samples = torch.stack(samples, dim=1) # [batch_size x num_samples x num_targets]
return SamplePatchTSMixerRegressionOutput(sequences=samples)
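# Usage sketch (illustrative, untrained weights; config values borrowed from the test suite
# below — adjust to your data):
#
#     >>> import torch
#     >>> from transformers import PatchTSMixerConfig, PatchTSMixerForRegression
#
#     >>> config = PatchTSMixerConfig(
#     ...     context_length=32, patch_length=8, patch_stride=8,
#     ...     num_input_channels=3, num_targets=3,
#     ... )
#     >>> model = PatchTSMixerForRegression(config)
#     >>> past_values = torch.rand(2, config.context_length, config.num_input_channels)
#     >>> targets = torch.rand(2, config.num_targets)
#     >>> outputs = model(past_values, future_values=targets)
#     >>> outputs.prediction_outputs.shape
#     torch.Size([2, 3])
#
# With a distributional head (`loss="nll"`), `generate` draws `num_parallel_samples` samples
# per series:
#
#     >>> config = PatchTSMixerConfig(
#     ...     context_length=32, patch_length=8, patch_stride=8, num_input_channels=3,
#     ...     num_targets=3, loss="nll", distribution_output="normal", num_parallel_samples=4,
#     ... )
#     >>> model = PatchTSMixerForRegression(config)
#     >>> samples = model.generate(past_values)
#     >>> samples.sequences.shape
#     torch.Size([2, 4, 3])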
@@ -6074,6 +6074,51 @@ class OwlViTVisionModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class PatchTSMixerForPrediction(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PatchTSMixerForPretraining(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PatchTSMixerForRegression(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PatchTSMixerForTimeSeriesClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PatchTSMixerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PatchTSMixerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = None
# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch PatchTSMixer model. """
import inspect
import itertools
import random
import tempfile
import unittest
from typing import Dict, List, Optional, Tuple, Union
from huggingface_hub import hf_hub_download
from parameterized import parameterized
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
TOLERANCE = 1e-4
if is_torch_available():
import torch
from transformers import (
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
PatchTSMixerConfig,
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
PatchTSMixerForRegression,
PatchTSMixerForTimeSeriesClassification,
PatchTSMixerModel,
)
from transformers.models.patchtsmixer.modeling_patchtsmixer import (
PatchTSMixerEncoder,
PatchTSMixerForPredictionHead,
PatchTSMixerForPredictionOutput,
PatchTSMixerForRegressionOutput,
PatchTSMixerForTimeSeriesClassificationOutput,
PatchTSMixerLinearHead,
PatchTSMixerPretrainHead,
)
@require_torch
class PatchTSMixerModelTester:
def __init__(
self,
context_length: int = 32,
patch_length: int = 8,
num_input_channels: int = 3,
patch_stride: int = 8,
# d_model: int = 128,
hidden_size: int = 8,
# num_layers: int = 8,
num_hidden_layers: int = 2,
expansion_factor: int = 2,
dropout: float = 0.5,
mode: str = "common_channel",
gated_attn: bool = True,
norm_mlp="LayerNorm",
swin_hier: int = 0,
# masking related
mask_type: str = "forecast",
random_mask_ratio=0.5,
mask_patches: list = [2, 3],
forecast_mask_ratios: list = [1, 1],
mask_value=0,
masked_loss: bool = False,
mask_mode: str = "mask_before_encoder",
channel_consistent_masking: bool = True,
scaling: Optional[Union[str, bool]] = "std",
# Head related
head_dropout: float = 0.2,
# forecast related
prediction_length: int = 16,
        out_channels: Optional[int] = None,
# Classification/regression related
# num_labels: int = 3,
num_targets: int = 3,
        output_range: Optional[list] = None,
        head_aggregation: Optional[str] = None,
# Trainer related
batch_size=13,
is_training=True,
seed_number=42,
post_init=True,
num_parallel_samples=4,
):
self.num_input_channels = num_input_channels
self.context_length = context_length
self.patch_length = patch_length
self.patch_stride = patch_stride
# self.d_model = d_model
self.hidden_size = hidden_size
self.expansion_factor = expansion_factor
# self.num_layers = num_layers
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.mode = mode
self.gated_attn = gated_attn
self.norm_mlp = norm_mlp
self.swin_hier = swin_hier
self.scaling = scaling
self.head_dropout = head_dropout
# masking related
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.mask_patches = mask_patches
self.forecast_mask_ratios = forecast_mask_ratios
self.mask_value = mask_value
self.channel_consistent_masking = channel_consistent_masking
self.mask_mode = mask_mode
self.masked_loss = masked_loss
# patching related
self.patch_last = True
# forecast related
self.prediction_length = prediction_length
self.out_channels = out_channels
# classification/regression related
# self.num_labels = num_labels
self.num_targets = num_targets
self.output_range = output_range
self.head_aggregation = head_aggregation
# Trainer related
self.batch_size = batch_size
self.is_training = is_training
self.seed_number = seed_number
self.post_init = post_init
self.num_parallel_samples = num_parallel_samples
def get_config(self):
config_ = PatchTSMixerConfig(
num_input_channels=self.num_input_channels,
context_length=self.context_length,
patch_length=self.patch_length,
patch_stride=self.patch_stride,
# d_model = self.d_model,
d_model=self.hidden_size,
expansion_factor=self.expansion_factor,
# num_layers = self.num_layers,
num_layers=self.num_hidden_layers,
dropout=self.dropout,
mode=self.mode,
gated_attn=self.gated_attn,
norm_mlp=self.norm_mlp,
swin_hier=self.swin_hier,
scaling=self.scaling,
head_dropout=self.head_dropout,
mask_type=self.mask_type,
random_mask_ratio=self.random_mask_ratio,
mask_patches=self.mask_patches,
forecast_mask_ratios=self.forecast_mask_ratios,
mask_value=self.mask_value,
channel_consistent_masking=self.channel_consistent_masking,
mask_mode=self.mask_mode,
masked_loss=self.masked_loss,
prediction_length=self.prediction_length,
out_channels=self.out_channels,
# num_labels=self.num_labels,
num_targets=self.num_targets,
output_range=self.output_range,
head_aggregation=self.head_aggregation,
post_init=self.post_init,
)
self.num_patches = config_.num_patches
return config_
def prepare_patchtsmixer_inputs_dict(self, config):
_past_length = config.context_length
# bs, n_vars, num_patch, patch_length
# [bs x context_length x n_vars]
past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])
future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels])
inputs_dict = {
"past_values": past_values,
"future_values": future_values,
}
return inputs_dict
def prepare_config_and_inputs(self):
config = self.get_config()
inputs_dict = self.prepare_patchtsmixer_inputs_dict(config)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
@require_torch
class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
PatchTSMixerModel,
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
PatchTSMixerForTimeSeriesClassification,
PatchTSMixerForRegression,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (
(PatchTSMixerForPrediction, PatchTSMixerForPretraining) if is_torch_available() else ()
)
pipeline_model_mapping = {"feature-extraction": PatchTSMixerModel} if is_torch_available() else {}
is_encoder_decoder = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torchscript = False
test_inputs_embeds = False
test_model_common_attributes = False
test_resize_embeddings = True
test_resize_position_embeddings = False
test_mismatched_shapes = True
test_model_parallel = False
has_attentions = False
def setUp(self):
self.model_tester = PatchTSMixerModelTester()
self.config_tester = ConfigTester(
self,
config_class=PatchTSMixerConfig,
has_text_modality=False,
prediction_length=self.model_tester.prediction_length,
common_properties=["hidden_size", "expansion_factor", "num_hidden_layers"],
)
def test_config(self):
self.config_tester.run_common_tests()
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
# if classification model:
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
# inputs_dict["labels"] = labels
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
# inputs_dict["labels"] = labels
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
elif model_class in [PatchTSMixerModel, PatchTSMixerForPretraining]:
inputs_dict.pop("future_values")
inputs_dict["output_hidden_states"] = True
return inputs_dict
def test_save_load_strict(self):
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester,
"expected_num_hidden_layers",
self.model_tester.num_hidden_layers,
)
self.assertEqual(len(hidden_states), expected_num_layers)
expected_hidden_size = self.model_tester.hidden_size
self.assertEqual(hidden_states[0].shape[-1], expected_hidden_size)
num_patch = self.model_tester.num_patches
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patch, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip("No tokens embeddings")
def test_resize_tokens_embeddings(self):
pass
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
output_ = model(**dict_inputs, return_dict=True, **additional_kwargs)
attributes_ = vars(output_)
dict_output = tuple(attributes_.values())
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=1e-5,
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
tuple_inputs.update({"output_hidden_states": False})
dict_inputs.update({"output_hidden_states": False})
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
tuple_inputs.update({"output_hidden_states": False})
dict_inputs.update({"output_hidden_states": False})
check_equivalence(
model,
tuple_inputs,
dict_inputs,
)
def test_model_main_input_name(self):
model_signature = inspect.signature(getattr(PatchTSMixerModel, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(PatchTSMixerModel.main_input_name, observed_main_input_name)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names_with_target = [
"past_values",
"observed_mask",
"future_values",
"output_hidden_states",
"return_loss",
]
expected_arg_names_without_target = [
"past_values",
"observed_mask",
"output_hidden_states",
]
expected_arg_names = expected_arg_names_with_target
if model_class == PatchTSMixerForPretraining:
expected_arg_names = expected_arg_names_without_target + ["return_loss"]
if model_class == PatchTSMixerModel:
expected_arg_names = expected_arg_names_without_target
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
):
expected_arg_names.remove("observed_mask")
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
# TODO: Make repo public
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
batch = torch.load(file, map_location=torch_device)
return batch
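# Note: the downloaded file stores a mapping of tensors with at least `past_values`
# (and `future_values` for the forecasting batches), as consumed by the tests below.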
@require_torch
@slow
class PatchTSMixerModelIntegrationTests(unittest.TestCase):
def test_pretrain_head(self):
model = PatchTSMixerForPretraining.from_pretrained("ibm/patchtsmixer-etth1-pretrain").to(torch_device)
batch = prepare_batch()
torch.manual_seed(0)
with torch.no_grad():
output = model(past_values=batch["past_values"].to(torch_device)).prediction_outputs
num_patch = (
max(model.config.context_length, model.config.patch_length) - model.config.patch_length
) // model.config.patch_stride + 1
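        # e.g. context_length=512, patch_length=8, patch_stride=8 gives (512 - 8) // 8 + 1 = 64 patches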
expected_shape = torch.Size(
[
32,
model.config.num_input_channels,
num_patch,
model.config.patch_length,
]
)
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor([[[[0.1870]],[[-1.5819]],[[-0.0991]],[[-1.2609]],[[0.5633]],[[-0.5723]],[[0.3387]],]],device=torch_device) # fmt: skip
self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE))
def test_forecasting_head(self):
model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-forecasting").to(torch_device)
batch = prepare_batch(file="forecast_batch.pt")
model.eval()
torch.manual_seed(0)
with torch.no_grad():
output = model(
past_values=batch["past_values"].to(torch_device),
future_values=batch["future_values"].to(torch_device),
).prediction_outputs
expected_shape = torch.Size([32, model.config.prediction_length, model.config.num_input_channels])
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.4271, -0.0651, 0.4656, 0.7104, -0.3085, -1.9658, 0.4560]],
device=torch_device,
)
self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE))
def test_prediction_generation(self):
torch_device = "cpu"
model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-generate").to(torch_device)
batch = prepare_batch(file="forecast_batch.pt")
model.eval()
torch.manual_seed(0)
with torch.no_grad():
outputs = model.generate(past_values=batch["past_values"].to(torch_device))
expected_shape = torch.Size((32, 1, model.config.prediction_length, model.config.num_input_channels))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor(
[[0.0091, -0.3625, -0.0887, 0.6544, -0.4100, -2.3124, 0.3376]],
device=torch_device,
)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
@require_torch
class PatchTSMixerFunctionalTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Setup method: Called once before test-cases execution"""
cls.params = {}
cls.params.update(
context_length=32,
patch_length=8,
num_input_channels=3,
patch_stride=8,
d_model=4,
expansion_factor=2,
num_layers=3,
dropout=0.2,
mode="common_channel", # common_channel, mix_channel
gated_attn=True,
norm_mlp="LayerNorm",
mask_type="random",
random_mask_ratio=0.5,
mask_patches=[2, 3],
forecast_mask_ratios=[1, 1],
mask_value=0,
masked_loss=True,
channel_consistent_masking=True,
head_dropout=0.2,
prediction_length=64,
out_channels=None,
# num_labels=3,
num_targets=3,
output_range=None,
head_aggregation=None,
scaling="std",
use_positional_encoding=False,
positional_encoding="sincos",
self_attn=False,
self_attn_heads=1,
num_parallel_samples=4,
)
cls.num_patches = (
max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"]
) // cls.params["patch_stride"] + 1
        batch_size = 2
cls.data = torch.rand(
batch_size,
cls.params["context_length"],
cls.params["num_input_channels"],
)
cls.enc_data = torch.rand(
batch_size,
cls.params["num_input_channels"],
cls.num_patches,
cls.params["patch_length"],
)
cls.enc_output = torch.rand(
batch_size,
cls.params["num_input_channels"],
cls.num_patches,
cls.params["d_model"],
)
cls.flat_enc_output = torch.rand(
batch_size,
cls.num_patches,
cls.params["d_model"],
)
cls.correct_pred_output = torch.rand(
batch_size,
cls.params["prediction_length"],
cls.params["num_input_channels"],
)
cls.correct_regression_output = torch.rand(batch_size, cls.params["num_targets"])
cls.correct_pretrain_output = torch.rand(
batch_size,
cls.params["num_input_channels"],
cls.num_patches,
cls.params["patch_length"],
)
cls.correct_forecast_output = torch.rand(
batch_size,
cls.params["prediction_length"],
cls.params["num_input_channels"],
)
cls.correct_sel_forecast_output = torch.rand(batch_size, cls.params["prediction_length"], 2)
cls.correct_classification_output = torch.rand(
batch_size,
cls.params["num_targets"],
)
cls.correct_classification_classes = torch.randint(0, cls.params["num_targets"], (batch_size,))
def test_patchtsmixer_encoder(self):
config = PatchTSMixerConfig(**self.__class__.params)
enc = PatchTSMixerEncoder(config)
output = enc(self.__class__.enc_data)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
def test_patchmodel(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerModel(config)
output = mdl(self.__class__.data)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
self.assertEqual(output.patch_input.shape, self.__class__.enc_data.shape)
def test_pretrainhead(self):
config = PatchTSMixerConfig(**self.__class__.params)
head = PatchTSMixerPretrainHead(
config=config,
)
output = head(self.__class__.enc_output)
self.assertEqual(output.shape, self.__class__.correct_pretrain_output.shape)
def test_pretrain_full(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForPretraining(config)
output = mdl(self.__class__.data)
self.assertEqual(
output.prediction_outputs.shape,
self.__class__.correct_pretrain_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
def test_pretrain_full_with_return_dict(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForPretraining(config)
output = mdl(self.__class__.data, return_dict=False)
self.assertEqual(output[1].shape, self.__class__.correct_pretrain_output.shape)
self.assertEqual(output[2].shape, self.__class__.enc_output.shape)
        self.assertTrue(output[0].item() < 100)
def test_forecast_head(self):
config = PatchTSMixerConfig(**self.__class__.params)
head = PatchTSMixerForPredictionHead(
config=config,
)
# output = head(self.__class__.enc_output, raw_data = self.__class__.correct_pretrain_output)
output = head(self.__class__.enc_output)
self.assertEqual(output.shape, self.__class__.correct_forecast_output.shape)
def check_module(
self,
task,
params=None,
output_hidden_states=True,
):
config = PatchTSMixerConfig(**params)
if task == "forecast":
mdl = PatchTSMixerForPrediction(config)
target_input = self.__class__.correct_forecast_output
if config.prediction_channel_indices is not None:
target_output = self.__class__.correct_sel_forecast_output
else:
target_output = target_input
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
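            # (batch_size, prediction_length, n_channels) -> (batch_size, num_parallel_samples, prediction_length, n_channels)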
elif task == "classification":
mdl = PatchTSMixerForTimeSeriesClassification(config)
target_input = self.__class__.correct_classification_classes
target_output = self.__class__.correct_classification_output
elif task == "regression":
mdl = PatchTSMixerForRegression(config)
target_input = self.__class__.correct_regression_output
target_output = self.__class__.correct_regression_output
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
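            # (batch_size, num_targets) -> (batch_size, num_parallel_samples, num_targets)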
elif task == "pretrain":
mdl = PatchTSMixerForPretraining(config)
target_input = None
target_output = self.__class__.correct_pretrain_output
        else:
            raise ValueError(f"Invalid task: {task}. Allowed values: 'forecast', 'classification', 'regression', 'pretrain'.")
enc_output = self.__class__.enc_output
if target_input is None:
output = mdl(self.__class__.data, output_hidden_states=output_hidden_states)
else:
output = mdl(
self.__class__.data,
future_values=target_input,
output_hidden_states=output_hidden_states,
)
if isinstance(output.prediction_outputs, tuple):
for t in output.prediction_outputs:
self.assertEqual(t.shape, target_output.shape)
else:
self.assertEqual(output.prediction_outputs.shape, target_output.shape)
self.assertEqual(output.last_hidden_state.shape, enc_output.shape)
if output_hidden_states is True:
self.assertEqual(len(output.hidden_states), params["num_layers"])
else:
            self.assertIsNone(output.hidden_states)
        self.assertTrue(output.loss.item() < 100)
if config.loss == "nll" and task in ["forecast", "regression"]:
samples = mdl.generate(self.__class__.data)
self.assertEqual(samples.sequences.shape, ref_samples.shape)
@parameterized.expand(
list(
itertools.product(
["common_channel", "mix_channel"],
[True, False],
[True, False, "mean", "std"],
[True, False],
[None, [0, 2]],
["mse", "nll"],
)
)
)
def test_forecast(self, mode, self_attn, scaling, gated_attn, prediction_channel_indices, loss):
params = self.__class__.params.copy()
params.update(
mode=mode,
self_attn=self_attn,
scaling=scaling,
prediction_channel_indices=prediction_channel_indices,
gated_attn=gated_attn,
loss=loss,
)
self.check_module(task="forecast", params=params)
@parameterized.expand(
list(
itertools.product(
["common_channel", "mix_channel"],
[True, False],
[True, False, "mean", "std"],
[True, False],
["max_pool", "avg_pool"],
)
)
)
def test_classification(self, mode, self_attn, scaling, gated_attn, head_aggregation):
params = self.__class__.params.copy()
params.update(
mode=mode,
self_attn=self_attn,
scaling=scaling,
head_aggregation=head_aggregation,
gated_attn=gated_attn,
)
self.check_module(task="classification", params=params)
@parameterized.expand(
list(
itertools.product(
["common_channel", "mix_channel"],
[True, False],
[True, False, "mean", "std"],
[True, False],
["max_pool", "avg_pool"],
["mse", "nll"],
)
)
)
def test_regression(self, mode, self_attn, scaling, gated_attn, head_aggregation, loss):
params = self.__class__.params.copy()
params.update(
mode=mode,
self_attn=self_attn,
scaling=scaling,
head_aggregation=head_aggregation,
gated_attn=gated_attn,
loss=loss,
)
self.check_module(task="regression", params=params)
@parameterized.expand(
list(
itertools.product(
["common_channel", "mix_channel"],
[True, False],
[True, False, "mean", "std"],
[True, False],
["random", "forecast"],
[True, False],
[True, False],
)
)
)
def test_pretrain(
self,
mode,
self_attn,
scaling,
gated_attn,
mask_type,
masked_loss,
channel_consistent_masking,
):
params = self.__class__.params.copy()
params.update(
mode=mode,
self_attn=self_attn,
scaling=scaling,
gated_attn=gated_attn,
mask_type=mask_type,
masked_loss=masked_loss,
channel_consistent_masking=channel_consistent_masking,
)
self.check_module(task="pretrain", params=params)
def forecast_full_module(self, params=None, output_hidden_states=False, return_dict=None):
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForPrediction(config)
target_val = self.__class__.correct_forecast_output
if config.prediction_channel_indices is not None:
target_val = self.__class__.correct_sel_forecast_output
enc_output = self.__class__.enc_output
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_forecast_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(output, tuple):
output = PatchTSMixerForPredictionOutput(*output)
if config.loss == "mse":
self.assertEqual(output.prediction_outputs.shape, target_val.shape)
self.assertEqual(output.last_hidden_state.shape, enc_output.shape)
if output_hidden_states is True:
self.assertEqual(len(output.hidden_states), params["num_layers"])
else:
            self.assertIsNone(output.hidden_states)
        self.assertTrue(output.loss.item() < 100)
if config.loss == "nll":
samples = mdl.generate(self.__class__.data)
ref_samples = target_val.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1, -1)
self.assertEqual(samples.sequences.shape, ref_samples.shape)
def test_forecast_full(self):
self.check_module(task="forecast", params=self.__class__.params, output_hidden_states=True)
# self.forecast_full_module(self.__class__.params, output_hidden_states = True)
def test_forecast_full_2(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
)
self.forecast_full_module(params, output_hidden_states=True)
def test_forecast_full_2_with_return_dict(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
)
self.forecast_full_module(params, output_hidden_states=True, return_dict=False)
def test_forecast_full_3(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
)
self.forecast_full_module(params, output_hidden_states=True)
def test_forecast_full_5(self):
params = self.__class__.params.copy()
params.update(
self_attn=True,
use_positional_encoding=True,
positional_encoding="sincos",
)
self.forecast_full_module(params, output_hidden_states=True)
def test_forecast_full_4(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
prediction_channel_indices=[0, 2],
)
self.forecast_full_module(params)
def test_forecast_full_distributional(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
prediction_channel_indices=[0, 2],
loss="nll",
distribution_output="normal",
)
self.forecast_full_module(params)
def test_forecast_full_distributional_2(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
prediction_channel_indices=[0, 2],
loss="nll",
# distribution_output = "normal",
)
self.forecast_full_module(params)
def test_forecast_full_distributional_3(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
# prediction_channel_indices=[0, 2],
loss="nll",
distribution_output="normal",
)
self.forecast_full_module(params)
def test_forecast_full_distributional_4(self):
params = self.__class__.params.copy()
params.update(
mode="mix_channel",
# prediction_channel_indices=[0, 2],
loss="nll",
distribution_output="normal",
)
self.forecast_full_module(params)
def test_classification_head(self):
config = PatchTSMixerConfig(**self.__class__.params)
head = PatchTSMixerLinearHead(
config=config,
)
# output = head(self.__class__.enc_output, raw_data = self.__class__.correct_pretrain_output)
output = head(self.__class__.enc_output)
self.assertEqual(output.shape, self.__class__.correct_classification_output.shape)
def test_classification_full(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
)
self.assertEqual(
output.prediction_outputs.shape,
self.__class__.correct_classification_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
def test_classification_full_with_return_dict(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
return_dict=False,
)
if isinstance(output, tuple):
output = PatchTSMixerForTimeSeriesClassificationOutput(*output)
self.assertEqual(
output.prediction_outputs.shape,
self.__class__.correct_classification_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
def test_regression_head(self):
config = PatchTSMixerConfig(**self.__class__.params)
head = PatchTSMixerLinearHead(
config=config,
)
output = head(self.__class__.enc_output)
self.assertEqual(output.shape, self.__class__.correct_regression_output.shape)
def test_regression_full(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
def test_regression_full_with_return_dict(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForRegression(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_regression_output,
return_dict=False,
)
if isinstance(output, tuple):
output = PatchTSMixerForRegressionOutput(*output)
self.assertEqual(
output.prediction_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
def test_regression_full_distribute(self):
params = self.__class__.params.copy()
params.update(loss="nll", distribution_output="normal")
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
if config.loss == "nll":
samples = mdl.generate(self.__class__.data)
ref_samples = self.__class__.correct_regression_output.unsqueeze(1).expand(
-1, params["num_parallel_samples"], -1
)
self.assertEqual(samples.sequences.shape, ref_samples.shape)
def test_regression_full_distribute_2(self):
params = self.__class__.params.copy()
params.update(loss="nll", distribution_output="student_t")
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < 100)
if config.loss == "nll":
samples = mdl.generate(self.__class__.data)
ref_samples = self.__class__.correct_regression_output.unsqueeze(1).expand(
-1, params["num_parallel_samples"], -1
)
self.assertEqual(samples.sequences.shape, ref_samples.shape)