Unverified commit 0dc7b3a7, authored by Aritra Roy Gosthipaty and committed by GitHub

[TensorFlow] Adding GroupViT (#18020)



* chore: initial commit

* chore: adding util methods

yet to work on the nn.functional.interpolate port with align_corners=True

* chore: refactor the utils

* used tf.compat.v1.image.resize to match the behavior of the F.interpolate function
* added type hints to the method signatures
* added references to the gists that show the one-to-one alignment of the torch and tf outputs
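
A minimal sketch of the alignment the bullets above describe (an illustration, not the PR's code): the v2 `tf.image.resize` has no `align_corners` flag, so the `tf.compat.v1` API is used to reproduce `torch.nn.functional.interpolate(..., align_corners=True)`.

```python
# Sketch only: checking tf.compat.v1.image.resize against F.interpolate.
# The tensor layouts differ (torch NCHW vs. TF NHWC), hence the transposes.
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F

x = np.random.rand(1, 3, 8, 8).astype(np.float32)  # NCHW, torch layout

pt_out = F.interpolate(torch.from_numpy(x), size=(24, 24), mode="bilinear", align_corners=True)

tf_out = tf.compat.v1.image.resize(
    tf.transpose(tf.constant(x), (0, 2, 3, 1)),  # NCHW -> NHWC
    size=(24, 24),
    method="bilinear",
    align_corners=True,
)
tf_out = tf.transpose(tf_out, (0, 3, 1, 2))  # NHWC -> NCHW

np.testing.assert_allclose(pt_out.numpy(), tf_out.numpy(), atol=1e-5)
```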

* chore: adding the layers

* chore: porting all the layers from torch to tf

This is the initial draft; nothing is tested yet.

* chore: aligning the layers with reference to tf clip

* chore: aligning the modules

* added demarcation comments
* added "copied from" and "adapted from" comments

* chore: aligning with CLIP

* chore: wrangling the layers to keep it tf compatible

* chore: aligning the names of the layers for porting

* chore: style changes

* chore: adding docs and inits

* chore: adding tfp dependencies

the code is taken from TAPAS

* chore: initial commit for testing

* chore: aligning the vision embeddings with the ViT implementation

* chore: changing model prefix

* chore: fixing the name of the model and the layer normalization test case

* chore: all tests pass except the slow ones

* chore: fix style and integration test

* chore: moving comments below decorators

* chore: make fixup and fix-copies changes

* chore: adding the Vision and Text Model to check_repo

* chore: modifying the prefix name to align it with the torch implementation

* chore: fix typo in configuration

* chore: changing the name of the model variable

* chore: adding segmentation flag

* chore: gante's review

* chore: style refactor

* chore: amy review

* chore: adding shape_list to parts that have been copied from other snippets
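
For context, `shape_list` is the helper that ships with transformers; a simplified sketch of what it does and why the copied snippets need it — it returns static dimensions where known and dynamic `tf.shape()` values otherwise, so shape arithmetic works both eagerly and under `tf.function`:

```python
# Simplified version of the transformers shape_list helper: use the static
# dimension where it is known, and fall back to the dynamic tf.shape()
# value where the traced graph would otherwise report None.
from typing import List, Union

import tensorflow as tf


def shape_list(tensor: tf.Tensor) -> List[Union[int, tf.Tensor]]:
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
```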

* chore: init batchnorm with torch defaults
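
A minimal sketch (an assumption about the intent, not the PR's code) of what "torch defaults" means for a Keras `BatchNormalization`: torch's `BatchNorm2d` uses `eps=1e-5` and `momentum=0.1`, and Keras defines momentum as the decay of the running statistics, so the equivalent Keras value is `1 - 0.1 = 0.9`.

```python
import tensorflow as tf

# Keras defaults are epsilon=1e-3 and momentum=0.99; these arguments
# instead mirror torch.nn.BatchNorm2d(eps=1e-5, momentum=0.1).
batch_norm = tf.keras.layers.BatchNormalization(
    epsilon=1e-5,  # torch eps
    momentum=0.9,  # 1 - torch momentum
)
```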

* chore: adding shape_list to pass the tests

* test fix: setting the seed to 0

* set seed

* chore: changing the straight-through trick to fix negative dimensions
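
For reference, a hedged sketch of the straight-through trick being adjusted here (hard one-hot assignment in the forward pass, softmax gradients in the backward pass); using the dynamic `tf.shape(...)[-1]` as the one-hot depth avoids the negative (unknown) dimensions that a static `logits.shape[-1]` can yield during graph tracing:

```python
import tensorflow as tf


def hard_softmax(logits: tf.Tensor) -> tf.Tensor:
    """Forward: hard one-hot assignment. Backward: gradient of the softmax."""
    soft = tf.nn.softmax(logits, axis=-1)
    depth = tf.shape(logits)[-1]  # dynamic, never a negative placeholder
    hard = tf.one_hot(tf.argmax(soft, axis=-1), depth=depth, dtype=soft.dtype)
    # stop_gradient makes the (hard - soft) correction invisible to autodiff.
    return tf.stop_gradient(hard - soft) + soft
```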

* chore: adding a dimension to the loss

* chore: adding reviewers and contributors names to the docs

* chore: added changes after review

* chore: code quality fixup

* chore: fixing the segmentation snippet

* chore: adding  to the layer calls

* chore: changing int32 to int64 for the serving inputs
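
A hedged sketch of what the int32-to-int64 change refers to: the `tf.TensorSpec` dtypes declared in the model's SavedModel serving signature. The class and input names here are illustrative stand-ins, not the PR's exact code:

```python
import tensorflow as tf


class TFDualEncoderSketch(tf.keras.Model):
    """Toy stand-in for the model class; only the serving signature matters."""

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
                "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            }
        ]
    )
    def serving(self, inputs):
        # Token ids arrive as int64; no caller-side cast is needed when
        # exporting with tf.saved_model.save.
        return self.call(inputs)

    def call(self, inputs):
        return inputs  # placeholder body
```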

* chore: review changes

* chore: style changes

* chore: remove from_pt=True

* fix: repo consistency
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent bb6fa06f
@@ -248,7 +248,7 @@ Flax), PyTorch, and/or TensorFlow.
| GPT NeoX | ❌ | ✅ | ✅ | ❌ | ❌ |
| GPT NeoX Japanese | ✅ | ❌ | ✅ | ❌ | ❌ |
| GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ |
-| GroupViT | ❌ | ❌ | ✅ | ❌ | ❌ |
+| GroupViT | ❌ | ❌ | ✅ | ✅ | ❌ |
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ |
...
@@ -26,7 +26,7 @@ Tips:
- You may specify `output_segmentation=True` in the forward of `GroupViTModel` to get the segmentation logits of input texts.
- The quickest way to get started with GroupViT is by checking the [example notebooks](https://github.com/xvjiarui/GroupViT/blob/main/demo/GroupViT_hf_inference_notebook.ipynb) (which showcase zero-shot segmentation inference). One can also check out the [HuggingFace Spaces demo](https://huggingface.co/spaces/xvjiarui/GroupViT) to play with GroupViT.
-This model was contributed by [xvjiarui](https://huggingface.co/xvjiarui).
+This model was contributed by [xvjiarui](https://huggingface.co/xvjiarui). The TensorFlow version was contributed by [ariG23498](https://huggingface.co/ariG23498) with the help of [Yih-Dar SHIEH](https://huggingface.co/ydshieh), [Amy Roberts](https://huggingface.co/amyeroberts), and [Joao Gante](https://huggingface.co/joaogante).
The original code can be found [here](https://github.com/NVlabs/GroupViT).
@@ -59,3 +59,20 @@ The original code can be found [here](https://github.com/NVlabs/GroupViT).
[[autodoc]] GroupViTVisionModel
    - forward
+
+## TFGroupViTModel
+
+[[autodoc]] TFGroupViTModel
+    - call
+    - get_text_features
+    - get_image_features
+
+## TFGroupViTTextModel
+
+[[autodoc]] TFGroupViTTextModel
+    - call
+
+## TFGroupViTVisionModel
+
+[[autodoc]] TFGroupViTVisionModel
+    - call
\ No newline at end of file
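
Since the tips in the hunk above mention `output_segmentation=True`, here is a hedged usage sketch; the checkpoint and class names come from this diff, while the argmax post-processing is just one plausible way to read the logits:

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, GroupViTModel

processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of a cat", "a photo of a remote control"]

inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs, output_segmentation=True)

# segmentation_logits: one low-resolution logit map per input text.
seg_logits = outputs.segmentation_logits  # (batch, num_texts, height, width)
mask = seg_logits.argmax(dim=1)  # per-pixel index of the best-matching text
```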
@@ -2417,6 +2417,15 @@ else:
            "TFGPTJPreTrainedModel",
        ]
    )
+    _import_structure["models.groupvit"].extend(
+        [
+            "TF_GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFGroupViTModel",
+            "TFGroupViTPreTrainedModel",
+            "TFGroupViTTextModel",
+            "TFGroupViTVisionModel",
+        ]
+    )
    _import_structure["models.hubert"].extend(
        [
            "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4986,6 +4995,13 @@ if TYPE_CHECKING:
        TFGPTJModel,
        TFGPTJPreTrainedModel,
    )
+    from .models.groupvit import (
+        TF_GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFGroupViTModel,
+        TFGroupViTPreTrainedModel,
+        TFGroupViTTextModel,
+        TFGroupViTVisionModel,
+    )
    from .models.hubert import (
        TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFHubertForCTC,
...
@@ -50,6 +50,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
+        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
...
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

_import_structure = {
@@ -44,6 +44,20 @@ else:
        "GroupViTVisionModel",
    ]

+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_groupvit"] = [
+        "TF_GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFGroupViTModel",
+        "TFGroupViTPreTrainedModel",
+        "TFGroupViTTextModel",
+        "TFGroupViTVisionModel",
+    ]
+
if TYPE_CHECKING:
    from .configuration_groupvit import (
        GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -67,6 +81,20 @@ if TYPE_CHECKING:
        GroupViTVisionModel,
    )

+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_groupvit import (
+            TF_GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFGroupViTModel,
+            TFGroupViTPreTrainedModel,
+            TFGroupViTTextModel,
+            TFGroupViTVisionModel,
+        )
+
else:
    import sys
...
@@ -162,7 +162,7 @@ class GroupViTVisionConfig(PretrainedConfig):
            The number of layers in each encoder block.
        num_group_tokens (`List[int]`, *optional*, defaults to [64, 8, 0]):
            The number of group tokens for each stage.
-        num_output_groups (`List[int]`, *optional*, defaults to [64, 8, 0]):
+        num_output_groups (`List[int]`, *optional*, defaults to [64, 8, 8]):
            The number of output groups for each stage, 0 means no group.
        num_attention_heads (`int`, *optional*, defaults to 6):
            Number of attention heads for each attention layer in the Transformer encoder.
...
@@ -1300,7 +1300,7 @@ class GroupViTVisionModel(GroupViTPreTrainedModel):
        >>> import requests
        >>> from transformers import AutoProcessor, GroupViTVisionModel

-        >>> processor = AutoPProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+        >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
        >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
...
This diff is collapsed.
@@ -1309,6 +1309,37 @@ class TFGPTJPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["tf"])

+TF_GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFGroupViTModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFGroupViTPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFGroupViTTextModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFGroupViTVisionModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
@@ -17,6 +17,7 @@
import inspect
import os
+import random
import tempfile
import unittest

@@ -24,7 +25,7 @@ import numpy as np
import requests

from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -95,7 +96,8 @@ class GroupViTVisionModelTester:
        self.seq_length = num_patches

    def prepare_config_and_inputs(self):
-        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+        rng = random.Random(0)
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size], rng=rng)
        config = self.get_config()

        return config, pixel_values
@@ -161,6 +163,18 @@ class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
    def test_inputs_embeds(self):
        pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self):
+        import tensorflow as tf
+
+        seed = 338
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        tf.random.set_seed(seed)
+        return super().test_pt_tf_model_equivalence()
+
    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -368,7 +382,8 @@ class GroupViTTextModelTester:
        self.scope = scope

    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        rng = random.Random(0)
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, rng=rng)

        input_mask = None
        if self.use_input_mask:
@@ -532,6 +547,18 @@ class GroupViTModelTest(ModelTesterMixin, unittest.TestCase):
    def test_model_common_attributes(self):
        pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self):
+        import tensorflow as tf
+
+        seed = 163
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        tf.random.set_seed(seed)
+        return super().test_pt_tf_model_equivalence()
+
    # override as the `logit_scale` parameter initilization is different for GROUPVIT
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
This diff is collapsed.
@@ -757,7 +757,7 @@ class TFModelTesterMixin:
                name="pixel_values",
                dtype="float32",
            )
-        elif model_class.__name__ in ["TFCLIPModel"]:
+        elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel"]:
            inputs = {
                "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
                "pixel_values": tf.keras.Input(
...
@@ -163,6 +163,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "GroupViTVisionModel",
    "TFCLIPTextModel",
    "TFCLIPVisionModel",
+    "TFGroupViTTextModel",
+    "TFGroupViTVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "FlaxWav2Vec2ForCTC",
...
@@ -39,6 +39,8 @@ src/transformers/models/electra/modeling_tf_electra.py
src/transformers/models/glpn/modeling_glpn.py
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gptj/modeling_gptj.py
+src/transformers/models/groupvit/modeling_groupvit.py
+src/transformers/models/groupvit/modeling_tf_groupvit.py
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/layoutlm/modeling_layoutlm.py
src/transformers/models/layoutlm/modeling_tf_layoutlm.py
...