Unverified Commit d7dc774a authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `TFGroupViT` CI (#19461)



* Fix TFGroupViT CI
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a293a0e8
...@@ -33,6 +33,7 @@ RUN echo torch=$VERSION ...@@ -33,6 +33,7 @@ RUN echo torch=$VERSION
RUN [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA RUN [ "$PYTORCH" != "pre" ] && python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA || python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
RUN python3 -m pip install --no-cache-dir -U tensorflow RUN python3 -m pip install --no-cache-dir -U tensorflow
RUN python3 -m pip install --no-cache-dir -U tensorflow_probability
RUN python3 -m pip uninstall -y flax jax RUN python3 -m pip uninstall -y flax jax
# Use installed torch version for `torch-scatter` to avoid having to deal with PYTORCH='pre'. # Use installed torch version for `torch-scatter` to avoid having to deal with PYTORCH='pre'.
......
...@@ -26,7 +26,13 @@ import numpy as np ...@@ -26,7 +26,13 @@ import numpy as np
import requests import requests
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tensorflow_probability,
require_tf,
require_vision,
slow,
)
from transformers.utils import is_tf_available, is_vision_available from transformers.utils import is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
...@@ -155,6 +161,16 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -155,6 +161,16 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
"""
During saving, TensorFlow will also run with `training=True` which trigger `gumbel_softmax` that requires
`tensorflow-probability`.
"""
@require_tensorflow_probability
@slow
def test_saved_model_creation(self):
super().test_saved_model_creation()
@unittest.skip(reason="GroupViT does not use inputs_embeds") @unittest.skip(reason="GroupViT does not use inputs_embeds")
def test_graph_mode_with_inputs_embeds(self): def test_graph_mode_with_inputs_embeds(self):
pass pass
...@@ -295,6 +311,10 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -295,6 +311,10 @@ class TFGroupViTVisionModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFGroupViTVisionModel.from_pretrained(model_name) model = TFGroupViTVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@unittest.skip(
"TFGroupViTVisionModel does not convert `hidden_states` and `attentions` to tensors as they are all of"
" different dimensions, and we get `Got a non-Tensor value` error when saving the model."
)
@slow @slow
def test_saved_model_creation_extended(self): def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -578,6 +598,10 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -578,6 +598,10 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
@require_tensorflow_probability
def test_keras_fit(self):
super().test_keras_fit()
@is_pt_tf_cross_test @is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self): def test_pt_tf_model_equivalence(self):
# `GroupViT` computes some indices using argmax, uses them as # `GroupViT` computes some indices using argmax, uses them as
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment