Unverified Commit 049e7917 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Add Data2Vec for Vision in TF (#17008)



* add utilities till TFData2VecVisionLayer.

* chore: pass window_size to attention layer.

* feat: add TFData2VecVisionRelativePositionBias.

* feat: initial implementation ready for tf data2vec.

* fix: relative position bias index, table to be fixed.

* chore: implementation added, tests remaining.

* add: tests, other PR files.

* fix: code quality.

* fix: import structure in init.

* chore: run make fix-copies.

* chore: address PR feedback (round I).

* chore: styling nit.

* fix: tests due to removal of to_2tuple().

* chore: rebase with upstream main and move the test.

* Update src/transformers/models/auto/modeling_tf_auto.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/auto/modeling_tf_auto.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix: layer call.

* chore: remove from_pt=True and rerun test.

* chore: remove cast and tf.divide.

* chore: minor edits to the test script.

* Update src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

* fix: expand() on TF tensors with broadcast_to().

* fix: test import.
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent d76d2a2a
......@@ -191,7 +191,7 @@ Flax), PyTorch, and/or TensorFlow.
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| Data2VecAudio | ❌ | ❌ | ✅ | ❌ | ❌ |
| Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ |
| Data2VecVision | ❌ | ❌ | ✅ | | ❌ |
| Data2VecVision | ❌ | ❌ | ✅ | | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | ✅ | ✅ | ✅ | ❌ |
| Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -38,9 +38,11 @@ Tips:
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten)
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.
The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).
## Data2VecTextConfig
......@@ -130,3 +132,13 @@ The original code can be found [here](https://github.com/pytorch/fairseq/tree/ma
[[autodoc]] Data2VecVisionForSemanticSegmentation
- forward
## TFData2VecVisionModel
[[autodoc]] TFData2VecVisionModel
- call
## TFData2VecVisionForImageClassification
[[autodoc]] TFData2VecVisionForImageClassification
- call
\ No newline at end of file
......@@ -1878,6 +1878,13 @@ if is_tf_available():
"TFCTRLPreTrainedModel",
]
)
_import_structure["models.data2vec"].extend(
[
"TFData2VecVisionForImageClassification",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
)
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -4029,6 +4036,11 @@ if TYPE_CHECKING:
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .models.data2vec import (
TFData2VecVisionForImageClassification,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
......
......@@ -37,6 +37,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
......@@ -163,6 +164,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Image-classsification
("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
]
)
......
......@@ -18,6 +18,8 @@
from typing import TYPE_CHECKING
from transformers.utils.import_utils import is_tf_available
from ...utils import _LazyModule, is_torch_available
......@@ -68,6 +70,13 @@ if is_torch_available():
"Data2VecVisionPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_data2vec_vision"] = [
"TFData2VecVisionForImageClassification",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_data2vec_audio import DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig
from .configuration_data2vec_text import (
......@@ -110,6 +119,12 @@ if TYPE_CHECKING:
Data2VecVisionModel,
Data2VecVisionPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_data2vec_vision import (
TFData2VecVisionForImageClassification,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
else:
import sys
......
......@@ -742,6 +742,27 @@ class TFCTRLPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFData2VecVisionForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFData2VecVisionModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFData2VecVisionPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment