Unverified commit 2199382d authored by Yih-Dar, committed by GitHub

Use random_attention_mask for TF tests (#16517)

* use random_attention_mask for TF tests

* Fix for TFCLIP test (for now).
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 823dbf8a
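Note: `random_attention_mask` differs from `ids_tensor(..., vocab_size=2)` in that it guarantees each example attends to at least one token. A minimal sketch of the assumed behavior (mirroring the PT/Flax common-test helpers, not the repository's exact code):

```python
import tensorflow as tf


def ids_tensor(shape, vocab_size):
    # old helper: purely random ints in [0, vocab_size); a 0/1 mask built this
    # way can end up all zeros for a given example
    return tf.random.uniform(shape, minval=0, maxval=vocab_size, dtype=tf.int32)


def random_attention_mask(shape):
    # new helper (assumed behavior): random 0/1 mask with the last position
    # forced to 1, so every example attends to at least one token
    mask = ids_tensor(shape, vocab_size=2)
    return tf.concat([mask[:, :-1], tf.ones_like(mask[:, -1:])], axis=-1)


print(random_attention_mask([2, 5]))  # last column is always 1
```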
@@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
@@ -92,7 +92,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -96,7 +96,7 @@ class TFAlbertModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
@@ -96,7 +96,7 @@ class TFBertModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -301,6 +301,12 @@ class TFCLIPTextModelTester:
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
+# make sure the first token has attention mask `1` to ensure that, after combining the causal mask, there
+# is still at least one token being attended to for each batch.
+# TODO: Change `random_attention_mask` in PT/TF/Flax common test file, after a discussion with the team.
+input_mask = tf.concat(
+    [tf.ones_like(input_mask[:, :1], dtype=input_mask.dtype), input_mask[:, 1:]], axis=-1
+)
config = self.get_config()
......
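Why the TFCLIP tester forces position 0 to `1`: with a causal mask, the first query position can only attend to the first key position, so a random mask that zeroes out the first token would leave that row with no valid key at all. A small illustration with hypothetical values (not code from this diff):

```python
import tensorflow as tf

seq_len = 4
# padding mask that (randomly) zeroed out the first token
padding_mask = tf.constant([[0.0, 1.0, 1.0, 1.0]])
# lower-triangular causal mask: query i may attend to keys 0..i
causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

# combined[b, i, j] == 1 only if query i may attend to key j in example b
combined = causal_mask * padding_mask[:, tf.newaxis, :]
print(combined[0, 0].numpy())  # [0. 0. 0. 0.] -> the first query attends to nothing
```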
@@ -20,7 +20,7 @@ from transformers import ConvBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -94,7 +94,7 @@ class TFConvBertModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import CTRLConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -69,7 +69,7 @@ class TFCTRLModelTester(object):
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import DebertaConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -92,7 +92,7 @@ class TFDebertaModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import DebertaV2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -95,7 +95,7 @@ class TFDebertaV2ModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import DistilBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -70,7 +70,7 @@ class TFDistilBertModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
......
@@ -19,7 +19,7 @@ from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -94,9 +94,8 @@ class TFDPRModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor(
-[self.batch_size, self.seq_length], vocab_size=2
-) # follow test_modeling_tf_ctrl.py
+# follow test_modeling_tf_ctrl.py
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import ElectraConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
@@ -71,7 +71,7 @@ class TFElectraModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -19,7 +19,7 @@ from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -75,7 +75,7 @@ class TFFlaubertModelTester:
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)
input_lengths = None
if self.use_input_lengths:
......
@@ -20,7 +20,7 @@ from transformers import FunnelConfig, is_tf_available
from transformers.testing_utils import require_tf
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -111,7 +111,7 @@ class TFFunnelModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -19,7 +19,7 @@ from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
@@ -74,7 +74,7 @@ class TFGPT2ModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import AutoTokenizer, GPTJConfig, is_tf_available
from transformers.testing_utils import require_tf, slow, tooslow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
@@ -70,7 +70,7 @@ class TFGPTJModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -21,7 +21,7 @@ from transformers import LayoutLMConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -107,7 +107,7 @@ class TFLayoutLMModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -79,7 +79,7 @@ class TFLongformerModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -23,7 +23,7 @@ from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -124,7 +124,7 @@ class TFLxmertModelTester(object):
input_mask = None
if self.use_lang_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
......
@@ -20,7 +20,7 @@ from transformers import MobileBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -114,7 +114,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
......
@@ -20,7 +20,7 @@ from transformers import MPNetConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -90,7 +90,7 @@ class TFMPNetModelTester:
input_mask = None
if self.use_input_mask:
-input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
......