Unverified Commit cc3d0e1b authored by fatih's avatar fatih Committed by GitHub
Browse files

[New Model] Add TimeSformer model (#18908)

* init timesformer

* apply fix-copies

* reformat style

* revert back some incoorect style updates

* init timesformer

* apply fix-copies

* reformat style

* revert back some incoorect style updates

* update timseformer doc

* add some functions and classes

* add new config params

* implement multiple classes

* update TimeSformerLayer

* update TimeSformerModel, TimeSformerPreTrainedModel, TimeSformerEncoder

* several fixes

* reformat

* temporary update

* fix some typos

* fix weight converter

* more fixes

* fix a typo

* fix typo

* remove redundant params

* fix for latest hf-hub

* merge fix

* fix some checks

* video classification works with einops

* add paper info to docs

* merge fix

* remove redundant line

* remove redundant docstring

* update config

* fix some typos

* fix converter

* update some test constants

* refactor einops functions

* reformat

* fix a comment

* remove redundat imports

* reformat

* fix a typo

* remove comment

* remove unused imports

* remove redundant doc line

* reformat

* add missing line

* fix docs

* fix timesformer auto feat ext

* add unittests

* reformat

* fix docs

* some fixes and updates

* fix readme

* fix modeling

* fix readme

* update index

* revert _toctree.yml changes

* update timseformer.mdx

* update drop_path_prob to drop_path_rate

* add dosctring for drop_path_rate

* update TimeSformerPatchEmbed naming

* remove to_2tuple

* explicit use of nn.functional

* reformat

* many updates from review comments

* fix a typo

* reformat

* remove assert, better variable name

* make variable names more explicit

* add some adapted from

* more explicit variable names

* remove redundant docstring

* fix initilaization

* move permute inside embedding

* update class names

* remove unused imports

* add test for video classification

* update PretrainedModel with PreTrainedModel

* remove double permute

* update based on sylvain's review

* aply auto fix

* update image_processing_auto for timesformer

* update hub urls

* reformat

* remove duplicate import

* update doc link
parent 3a9476d1
......@@ -391,6 +391,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -391,6 +391,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -426,6 +426,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -341,6 +341,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -365,6 +365,7 @@ conda install -c huggingface transformers
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (来自 Microsoft Research) 伴随论文 [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) 由 Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou 发布。
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (来自 Google/CMU) 伴随论文 [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) 由 Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov 发布。
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。
......
......@@ -377,6 +377,7 @@ conda install -c huggingface transformers
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](https://huggingface.co/docs/transformers/main/model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](https://huggingface.co/docs/transformers/model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](https://huggingface.co/docs/transformers/model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......
......@@ -432,6 +432,8 @@
title: Swin Transformer V2
- local: model_doc/table-transformer
title: Table Transformer
- local: model_doc/timesformer
title: TimeSformer
- local: model_doc/van
title: VAN
- local: model_doc/videomae
......
......@@ -179,6 +179,7 @@ The documentation is organized into five sections:
1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](model_doc/time_series_transformer)** (from HuggingFace).
1. **[TimeSformer](model_doc/timesformer)** (from Facebook) released with the paper [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Gedas Bertasius, Heng Wang, Lorenzo Torresani.
1. **[Trajectory Transformer](model_doc/trajectory_transformers)** (from the University of California at Berkeley) released with the paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine
1. **[Transformer-XL](model_doc/transfo-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[TrOCR](model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
......@@ -339,6 +340,7 @@ Flax), PyTorch, and/or TensorFlow.
| Table Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# TimeSformer
## Overview
The TimeSformer model was proposed in [TimeSformer: Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) by Facebook Research.
This work is a milestone in action-recognition field being the first video transformer. It inspired many transformer based video understanding and classification papers.
The abstract from the paper is the following:
*We present a convolution-free approach to video classification built exclusively on self-attention over space and time. Our method, named "TimeSformer," adapts the standard Transformer architecture to video by enabling spatiotemporal feature learning directly from a sequence of frame-level patches. Our experimental study compares different self-attention schemes and suggests that "divided attention," where temporal attention and spatial attention are separately applied within each block, leads to the best video classification accuracy among the design choices considered. Despite the radically new design, TimeSformer achieves state-of-the-art results on several action recognition benchmarks, including the best reported accuracy on Kinetics-400 and Kinetics-600. Finally, compared to 3D convolutional networks, our model is faster to train, it can achieve dramatically higher test efficiency (at a small drop in accuracy), and it can also be applied to much longer video clips (over one minute long). Code and models are available at: [this https URL](https://github.com/facebookresearch/TimeSformer).*
Tips:
There are many pretrained variants. Select your pretrained model based on the dataset it is trained on. Moreover, the number of input frames per clip changes based on the model size so you should consider this parameter while selecting your pretrained model.
This model was contributed by [fcakyon](https://huggingface.co/fcakyon).
The original code can be found [here](https://github.com/facebookresearch/TimeSformer).
## TimesformerConfig
[[autodoc]] TimesformerConfig
## TimesformerModel
[[autodoc]] TimesformerModel
- forward
## TimesformerForVideoClassification
[[autodoc]] TimesformerForVideoClassification
- forward
\ No newline at end of file
......@@ -379,6 +379,7 @@ _import_structure = {
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig",
],
"models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
"models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TrajectoryTransformerConfig",
......@@ -2104,6 +2105,14 @@ else:
"TimeSeriesTransformerPreTrainedModel",
]
)
_import_structure["models.timesformer"].extend(
[
"TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimesformerForVideoClassification",
"TimesformerModel",
"TimesformerPreTrainedModel",
]
)
_import_structure["models.trajectory_transformer"].extend(
[
"TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -3588,6 +3597,7 @@ if TYPE_CHECKING:
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig,
)
from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrajectoryTransformerConfig,
......@@ -5015,6 +5025,12 @@ if TYPE_CHECKING:
TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel,
)
from .models.timesformer import (
TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimesformerForVideoClassification,
TimesformerModel,
TimesformerPreTrainedModel,
)
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TrajectoryTransformerModel,
......
......@@ -152,6 +152,7 @@ from . import (
tapas,
tapex,
time_series_transformer,
timesformer,
trajectory_transformer,
transfo_xl,
trocr,
......
......@@ -148,6 +148,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("table-transformer", "TableTransformerConfig"),
("tapas", "TapasConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
......@@ -290,6 +291,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("table-transformer", "TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("timesformer", "TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
......@@ -455,6 +457,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapas", "TAPAS"),
("tapex", "TAPEX"),
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
......
......@@ -77,6 +77,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("swin", "ViTFeatureExtractor"),
("swinv2", "ViTFeatureExtractor"),
("table-transformer", "DetrFeatureExtractor"),
("timesformer", "VideoMAEFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("videomae", "VideoMAEFeatureExtractor"),
("vilt", "ViltFeatureExtractor"),
......
......@@ -74,6 +74,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("swin", "ViTImageProcessor"),
("swinv2", "ViTImageProcessor"),
("table-transformer", "DetrImageProcessor"),
("timesformer", "VideoMAEImageProcessor"),
("van", "ConvNextImageProcessor"),
("videomae", "VideoMAEImageProcessor"),
("vilt", "ViltImageProcessor"),
......
......@@ -144,6 +144,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("table-transformer", "TableTransformerModel"),
("tapas", "TapasModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("unispeech", "UniSpeechModel"),
......@@ -429,6 +430,7 @@ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("timesformer", "TimesformerForVideoClassification"),
("videomae", "VideoMAEForVideoClassification"),
]
)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_timesformer"] = [
"TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TimesformerModel",
"TimesformerForVideoClassification",
"TimesformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_timesformer import (
TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimesformerForVideoClassification,
TimesformerModel,
TimesformerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TimeSformer model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/timesformer": "https://huggingface.co/facebook/timesformer/resolve/main/config.json",
}
class TimesformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`TimesformerModel`]. It is used to instantiate a
TimeSformer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the TimeSformer
[facebook/timesformer](https://huggingface.co/facebook/timesformer-base-finetuned-k600) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_frames (`int`, *optional*, defaults to 8):
The number of frames in each video.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
attention_type (`str`, *optional*, defaults to `"divided_space_time"`):
The attention type to use. Must be one of `"divided_space_time"`, `"space_only"`, `"joint_space_time"`.
drop_path_rate (`float`, *optional*, defaults to 0):
The dropout ratio for stochastic depth.
Example:
```python
>>> from transformers import TimesformerConfig, TimesformerModel
>>> # Initializing a TimeSformer timesformer-base style configuration
>>> configuration = TimesformerConfig()
>>> # Randomly initializing a model from the configuration
>>> model = TimesformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "timesformer"
def __init__(
self,
image_size=224,
patch_size=16,
num_channels=3,
num_frames=8,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-6,
qkv_bias=True,
attention_type="divided_space_time",
drop_path_rate=0,
**kwargs
):
super().__init__(**kwargs)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_frames = num_frames
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.attention_type = attention_type
self.drop_path_rate = drop_path_rate
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert TimeSformer checkpoints from the original repository: https://github.com/MCG-NJU/TimeSformer"""
import argparse
import json
import numpy as np
import torch
import gdown
from huggingface_hub import hf_hub_download
from transformers import TimesformerConfig, TimesformerForVideoClassification, VideoMAEFeatureExtractor
def get_timesformer_config(model_name):
config = TimesformerConfig()
if "large" in model_name:
config.num_frames = 96
if "hr" in model_name:
config.num_frames = 16
config.image_size = 448
repo_id = "huggingface/label-files"
if "k400" in model_name:
config.num_labels = 400
filename = "kinetics400-id2label.json"
elif "k600" in model_name:
config.num_labels = 600
filename = "kinetics600-id2label.json"
elif "ssv2" in model_name:
config.num_labels = 174
filename = "something-something-v2-id2label.json"
else:
raise ValueError("Model name should either contain 'k400', 'k600' or 'ssv2'.")
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def rename_key(name):
if "encoder." in name:
name = name.replace("encoder.", "")
if "cls_token" in name:
name = name.replace("cls_token", "timesformer.embeddings.cls_token")
if "pos_embed" in name:
name = name.replace("pos_embed", "timesformer.embeddings.position_embeddings")
if "time_embed" in name:
name = name.replace("time_embed", "timesformer.embeddings.time_embeddings")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "timesformer.embeddings.patch_embeddings.projection")
if "patch_embed.norm" in name:
name = name.replace("patch_embed.norm", "timesformer.embeddings.norm")
if "blocks" in name:
name = name.replace("blocks", "timesformer.encoder.layer")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name and "bias" not in name and "temporal" not in name:
name = name.replace("attn", "attention.self")
if "attn" in name and "temporal" not in name:
name = name.replace("attn", "attention.attention")
if "temporal_norm1" in name:
name = name.replace("temporal_norm1", "temporal_layernorm")
if "temporal_attn.proj" in name:
name = name.replace("temporal_attn", "temporal_attention.output.dense")
if "temporal_fc" in name:
name = name.replace("temporal_fc", "temporal_dense")
if "norm1" in name and "temporal" not in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "norm.weight" in name and "fc" not in name and "temporal" not in name:
name = name.replace("norm.weight", "timesformer.layernorm.weight")
if "norm.bias" in name and "fc" not in name and "temporal" not in name:
name = name.replace("norm.bias", "timesformer.layernorm.bias")
if "head" in name:
name = name.replace("head", "classifier")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if key.startswith("model."):
key = key.replace("model.", "")
if "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[1])
prefix = "timesformer.encoder.layer."
if "temporal" in key:
postfix = ".temporal_attention.attention.qkv."
else:
postfix = ".attention.attention.qkv."
if "weight" in key:
orig_state_dict[f"{prefix}{layer_num}{postfix}weight"] = val
else:
orig_state_dict[f"{prefix}{layer_num}{postfix}bias"] = val
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
# We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
def prepare_video():
file = hf_hub_download(
repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
)
video = np.load(file)
return list(video)
def convert_timesformer_checkpoint(checkpoint_url, pytorch_dump_folder_path, model_name, push_to_hub):
config = get_timesformer_config(model_name)
model = TimesformerForVideoClassification(config)
# download original checkpoint, hosted on Google Drive
output = "pytorch_model.bin"
gdown.cached_download(checkpoint_url, output, quiet=False)
files = torch.load(output, map_location="cpu")
if "model" in files:
state_dict = files["model"]
elif "module" in files:
state_dict = files["module"]
else:
state_dict = files["model_state"]
new_state_dict = convert_state_dict(state_dict, config)
model.load_state_dict(new_state_dict)
model.eval()
# verify model on basic input
feature_extractor = VideoMAEFeatureExtractor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
video = prepare_video()
inputs = feature_extractor(video[:8], return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
model_names = [
# Kinetics-400 checkpoints (hr = high resolution input of 448px instead of 224px)
"timesformer-base-finetuned-k400",
"timesformer-large-finetuned-k400",
"timesformer-hr-finetuned-k400",
# Kinetics-600 checkpoints (hr = high resolution input of 448px instead of 224px)
"timesformer-base-finetuned-k600",
"timesformer-large-finetuned-k600",
"timesformer-hr-finetuned-k600",
# Something-Something-v2 checkpoints (hr = high resolution input of 448px instead of 224px)
"timesformer-base-finetuned-ssv2",
"timesformer-large-finetuned-ssv2",
"timesformer-hr-finetuned-ssv2",
]
# NOTE: logits were tested with image_mean and image_std equal to [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5]
if model_name == "timesformer-base-finetuned-k400":
expected_shape = torch.Size([1, 400])
expected_slice = torch.tensor([-0.3016, -0.7713, -0.4205])
elif model_name == "timesformer-base-finetuned-k600":
expected_shape = torch.Size([1, 600])
expected_slice = torch.tensor([-0.7267, -0.7466, 3.2404])
elif model_name == "timesformer-base-finetuned-ssv2":
expected_shape = torch.Size([1, 174])
expected_slice = torch.tensor([-0.9059, 0.6433, -3.1457])
elif model_name == "timesformer-large-finetuned-k400":
expected_shape = torch.Size([1, 400])
expected_slice = torch.tensor([0, 0, 0])
elif model_name == "timesformer-large-finetuned-k600":
expected_shape = torch.Size([1, 600])
expected_slice = torch.tensor([0, 0, 0])
elif model_name == "timesformer-large-finetuned-ssv2":
expected_shape = torch.Size([1, 174])
expected_slice = torch.tensor([0, 0, 0])
elif model_name == "timesformer-hr-finetuned-k400":
expected_shape = torch.Size([1, 400])
expected_slice = torch.tensor([-0.9617, -3.7311, -3.7708])
elif model_name == "timesformer-hr-finetuned-k600":
expected_shape = torch.Size([1, 600])
expected_slice = torch.tensor([2.5273, 0.7127, 1.8848])
elif model_name == "timesformer-hr-finetuned-ssv2":
expected_shape = torch.Size([1, 174])
expected_slice = torch.tensor([-3.6756, -0.7513, 0.7180])
else:
raise ValueError(f"Model name not supported. Should be one of {model_names}")
# verify logits
assert logits.shape == expected_shape
assert torch.allclose(logits[0, :3], expected_slice, atol=1e-4)
print("Logits ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)
model.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing to the hub...")
model.push_to_hub(f"fcakyon/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://drive.google.com/u/1/uc?id=17yvuYp9L4mn-HpIcK5Zo6K3UoOy1kA5l&export=download",
type=str,
help=(
"URL of the original PyTorch checkpoint (on Google Drive) you'd like to convert. Should be a direct"
" download link."
),
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--model_name", default="timesformer-base-finetuned-k400", type=str, help="Name of the model.")
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_timesformer_checkpoint(
args.checkpoint_url, args.pytorch_dump_folder_path, args.model_name, args.push_to_hub
)
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch TimeSformer model."""
import collections
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_timesformer import TimesformerConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "TimesformerConfig"
_CHECKPOINT_FOR_DOC = "facebook/timesformer"
TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/timesformer-base-finetuned-k400",
# See all TimeSformer models at https://huggingface.co/models?filter=timesformer
]
# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L155
class TimesformerPatchEmbeddings(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, config):
super().__init__()
image_size = config.image_size
patch_size = config.patch_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.projection = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
embeddings = self.projection(pixel_values)
patch_width = embeddings.size(-1)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, num_frames, patch_width
class TimesformerEmbeddings(nn.Module):
"""
Construct the patch and position embeddings.
"""
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
num_frames = config.num_frames
drop_rate = config.hidden_dropout_prob
attention_type = config.attention_type
self.attention_type = attention_type
self.patch_embeddings = TimesformerPatchEmbeddings(config)
self.num_patches = self.patch_embeddings.num_patches
# Positional Embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
if attention_type != "space_only":
self.time_embeddings = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
self.time_drop = nn.Dropout(p=drop_rate)
def forward(self, pixel_values):
batch_size = pixel_values.shape[0]
# create patch embeddings
embeddings, num_frames, patch_width = self.patch_embeddings(pixel_values)
cls_tokens = self.cls_token.expand(embeddings.size(0), -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# resizing the positional embeddings in case they don't match the input at inference
if embeddings.size(1) != self.position_embeddings.size(1):
position_embeddings = self.position_embeddings
cls_pos_embed = position_embeddings[0, 0, :].unsqueeze(0).unsqueeze(1)
other_pos_embed = position_embeddings[0, 1:, :].unsqueeze(0).transpose(1, 2)
patch_num = int(other_pos_embed.size(2) ** 0.5)
patch_height = embeddings.size(1) // patch_width
other_pos_embed = other_pos_embed.reshape(1, embeddings.size(2), patch_num, patch_num)
new_pos_embed = nn.functional.interpolate(
other_pos_embed, size=(patch_height, patch_width), mode="nearest"
)
new_pos_embed = new_pos_embed.flatten(2)
new_pos_embed = new_pos_embed.transpose(1, 2)
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
embeddings = embeddings + new_pos_embed
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.pos_drop(embeddings)
# Time Embeddings
if self.attention_type != "space_only":
cls_tokens = embeddings[:batch_size, 0, :].unsqueeze(1)
embeddings = embeddings[:, 1:]
_, patch_height, patch_width = embeddings.shape
embeddings = (
embeddings.reshape(batch_size, num_frames, patch_height, patch_width)
.permute(0, 2, 1, 3)
.reshape(batch_size * patch_height, num_frames, patch_width)
)
# Resizing time embeddings in case they don't match
if num_frames != self.time_embeddings.size(1):
time_embeddings = self.time_embeddings.transpose(1, 2)
new_time_embeddings = nn.functional.interpolate(time_embeddings, size=(num_frames), mode="nearest")
new_time_embeddings = new_time_embeddings.transpose(1, 2)
embeddings = embeddings + new_time_embeddings
else:
embeddings = embeddings + self.time_embeddings
embeddings = self.time_drop(embeddings)
embeddings = embeddings.view(batch_size, patch_height, num_frames, patch_width).reshape(
batch_size, patch_height * num_frames, patch_width
)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
return embeddings
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->TimeSformer
class TimeSformerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L57
class TimesformerSelfAttention(nn.Module):
def __init__(self, config: TimesformerConfig):
super().__init__()
num_heads = config.num_attention_heads
qkv_bias = config.qkv_bias
attention_dropout_prob = config.attention_probs_dropout_prob
self.num_heads = num_heads
head_dim = config.hidden_size // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attention_dropout_prob)
def forward(self, hidden_states, output_attentions: bool = False):
batch_size, hidden_size, num_channels = hidden_states.shape
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, hidden_size, 3, self.num_heads, num_channels // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
query, key, value = qkv[0], qkv[1], qkv[2]
attention_probs = (query @ key.transpose(-2, -1)) * self.scale
attention_probs = attention_probs.softmax(dim=-1)
attention_probs = self.attn_drop(attention_probs)
context_layer = (attention_probs @ value).transpose(1, 2).reshape(batch_size, hidden_size, num_channels)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class TimesformerSelfOutput(nn.Module):
"""
The residual connection is defined in TimesformerLayer instead of here (as is the case with other models), due to
the layernorm applied before each block.
"""
def __init__(self, config: TimesformerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class TimeSformerAttention(nn.Module):
def __init__(self, config: TimesformerConfig) -> None:
super().__init__()
self.attention = TimesformerSelfAttention(config)
self.output = TimesformerSelfOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, output_attentions)
attention_output = self.output(self_outputs[0])
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L39
class TimesformerIntermediate(nn.Module):
def __init__(self, config: TimesformerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class TimesformerOutput(nn.Module):
def __init__(self, config: TimesformerConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L89
class TimesformerLayer(nn.Module):
def __init__(self, config: TimesformerConfig, layer_index: int) -> None:
super().__init__()
attention_type = config.attention_type
drop_path_rates = [
x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
] # stochastic depth decay rule
drop_path_rate = drop_path_rates[layer_index]
self.drop_path = TimeSformerDropPath(config.drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.attention = TimeSformerAttention(config)
self.intermediate = TimesformerIntermediate(config)
self.output = TimesformerOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.config = config
self.attention_type = attention_type
if attention_type not in ["divided_space_time", "space_only", "joint_space_time"]:
raise ValueError("Unknown attention type: {}".format(attention_type))
# Temporal Attention Parameters
if self.attention_type == "divided_space_time":
self.temporal_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.temporal_attention = TimeSformerAttention(config)
self.temporal_dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False):
num_frames = self.config.num_frames
num_patch_width = self.config.image_size // self.config.patch_size
batch_size = hidden_states.shape[0]
num_spatial_tokens = (hidden_states.size(1) - 1) // num_frames
num_patch_height = num_spatial_tokens // num_patch_width
if self.attention_type in ["space_only", "joint_space_time"]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
hidden_states = hidden_states + self.drop_path(attention_output)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output)
layer_output = hidden_states + self.drop_path(layer_output)
outputs = (layer_output,) + outputs
return outputs
elif self.attention_type == "divided_space_time":
# Temporal
temporal_embedding = hidden_states[:, 1:, :]
temporal_embedding = temporal_embedding.reshape(
batch_size, num_patch_height, num_patch_width, num_frames, temporal_embedding.shape[2]
).reshape(batch_size * num_patch_height * num_patch_width, num_frames, temporal_embedding.shape[2])
temporal_attention_outputs = self.temporal_attention(
self.temporal_layernorm(temporal_embedding),
)
attention_output = temporal_attention_outputs[0]
residual_temporal = self.drop_path(attention_output)
residual_temporal = residual_temporal.reshape(
batch_size, num_patch_height, num_patch_width, num_frames, residual_temporal.shape[2]
).reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_temporal.shape[2])
residual_temporal = self.temporal_dense(residual_temporal)
temporal_embedding = hidden_states[:, 1:, :] + residual_temporal
# Spatial
init_cls_token = hidden_states[:, 0, :].unsqueeze(1)
cls_token = init_cls_token.repeat(1, num_frames, 1)
cls_token = cls_token.reshape(batch_size * num_frames, 1, cls_token.shape[2])
spatial_embedding = temporal_embedding
spatial_embedding = (
spatial_embedding.reshape(
batch_size, num_patch_height, num_patch_width, num_frames, spatial_embedding.shape[2]
)
.permute(0, 3, 1, 2, 4)
.reshape(batch_size * num_frames, num_patch_height * num_patch_width, spatial_embedding.shape[2])
)
spatial_embedding = torch.cat((cls_token, spatial_embedding), 1)
spatial_attention_outputs = self.attention(
self.layernorm_before(spatial_embedding), output_attentions=output_attentions
)
attention_output = spatial_attention_outputs[0]
outputs = spatial_attention_outputs[1:] # add self attentions if we output attention weights
residual_spatial = self.drop_path(attention_output)
# Taking care of CLS token
cls_token = residual_spatial[:, 0, :]
cls_token = cls_token.reshape(batch_size, num_frames, cls_token.shape[1])
cls_token = torch.mean(cls_token, 1, True) # averaging for every frame
residual_spatial = residual_spatial[:, 1:, :]
residual_spatial = (
residual_spatial.reshape(
batch_size, num_frames, num_patch_height, num_patch_width, residual_spatial.shape[2]
)
.permute(0, 2, 3, 1, 4)
.reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_spatial.shape[2])
)
residual = residual_spatial
hidden_states = temporal_embedding
# Mlp
hidden_states = torch.cat((init_cls_token, hidden_states), 1) + torch.cat((cls_token, residual), 1)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output)
layer_output = hidden_states + self.drop_path(layer_output)
outputs = (layer_output,) + outputs
return outputs
class TimesformerEncoder(nn.Module):
def __init__(self, config: TimesformerConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([TimesformerLayer(config, ind) for ind in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class TimesformerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = TimesformerConfig
base_model_prefix = "timesformer"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
elif isinstance(module, TimesformerEmbeddings):
nn.init.trunc_normal_(module.cls_token, std=self.config.initializer_range)
nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range)
module.patch_embeddings.apply(self._init_weights)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, TimesformerEncoder):
module.gradient_checkpointing = value
TIMESFORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`TimesformerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
TIMESFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`VideoMAEFeatureExtractor`]. See
[`VideoMAEFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare TimeSformer Model transformer outputting raw hidden-states without any specific head on top.",
TIMESFORMER_START_DOCSTRING,
)
class TimesformerModel(TimesformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = TimesformerEmbeddings(config)
self.encoder = TimesformerEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import numpy as np
>>> from transformers import TimeSformerFeatureExtractor, TimesformerModel
>>> from huggingface_hub import hf_hub_download
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
... indices = np.linspace(start_idx, end_idx, num=clip_len)
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
... return indices
>>> # video clip consists of 300 frames (10 seconds at 30 FPS)
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> # sample 8 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
>>> model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")
>>> # prepare video for the model
>>> inputs = feature_extractor(list(video), return_tensors="pt")
>>> # forward pass
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 1568, 768]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
embedding_output = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
if self.layernorm is not None:
sequence_output = self.layernorm(sequence_output)
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""TimeSformer Model transformer with a video classification head on top (a linear layer on top of the final hidden state
of the [CLS] token) e.g. for ImageNet.""",
TIMESFORMER_START_DOCSTRING,
)
class TimesformerForVideoClassification(TimesformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.timesformer = TimesformerModel(config)
# Classifier head
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from decord import VideoReader, cpu
>>> import torch
>>> import numpy as np
>>> from transformers import VideoMAEFeatureExtractor, TimesformerForVideoClassification
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
... indices = np.linspace(start_idx, end_idx, num=clip_len)
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
... return indices
>>> # video clip consists of 300 frames (10 seconds at 30 FPS)
>>> file_path = hf_hub_download(
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
... )
>>> videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))
>>> # sample 8 frames
>>> videoreader.seek(0)
>>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
>>> model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400")
>>> inputs = feature_extractor(list(video), return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
... logits = outputs.logits
>>> # model predicts one of the 400 Kinetics-400 classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
eating spaghetti
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.timesformer(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0][:, 0]
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -5379,6 +5379,30 @@ class TimeSeriesTransformerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TimesformerForVideoClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TimesformerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TimesformerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment