Unverified commit 2199382d, authored by Yih-Dar and committed by GitHub

Use random_attention_mask for TF tests (#16517)



* use random_attention_mask for TF tests

* Fix for TFCLIP test (for now).
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 823dbf8a
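For context, here is a minimal, self-contained sketch (not part of the commit) of the two test helpers involved, with simplified bodies and illustrative 2-D shapes; the real helpers live in the shared test_modeling_tf_common module. A plain `ids_tensor(..., vocab_size=2)` mask can come out all zeros for a given example, whereas `random_attention_mask` pins one position to 1 so every example attends to at least one token, which is why the model testers below switch to it.

```python
# Self-contained sketch of the helpers touched by this commit (simplified bodies,
# illustrative [batch_size, seq_length] shapes; not the repo's exact code).
import random

import tensorflow as tf


def ids_tensor(shape, vocab_size, rng=None, dtype=tf.int32):
    # Random integer tensor with values in [0, vocab_size).
    rng = rng or random.Random()
    values = [rng.randint(0, vocab_size - 1) for _ in range(shape[0] * shape[1])]
    return tf.constant(values, shape=shape, dtype=dtype)


def random_attention_mask(shape, rng=None, dtype=tf.int32):
    # Random 0/1 mask that is guaranteed to attend to at least one token per row.
    mask = ids_tensor(shape, vocab_size=2, rng=rng, dtype=dtype)
    # Pin the last position to 1 so no row is all zeros.
    return tf.concat([mask[:, :-1], tf.ones_like(mask[:, -1:], dtype=dtype)], axis=-1)


# ids_tensor([2, 7], vocab_size=2) may produce an all-zero row (no attended token);
# random_attention_mask([2, 7]) never does.
print(random_attention_mask([2, 7]).numpy())
```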
@@ -20,7 +20,7 @@ from transformers import OpenAIGPTConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -70,7 +70,7 @@ class TFOpenAIGPTModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = None
         if self.use_token_type_ids:
@@ -20,7 +20,7 @@ from transformers import RemBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -95,7 +95,7 @@ class TFRemBertModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = None
         if self.use_token_type_ids:
@@ -20,7 +20,7 @@ from transformers import RobertaConfig, is_tf_available
 from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -72,7 +72,7 @@ class TFRobertaModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = None
         if self.use_token_type_ids:
@@ -20,7 +20,7 @@ from transformers import RoFormerConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -95,7 +95,7 @@ class TFRoFormerModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = None
         if self.use_token_type_ids:
@@ -20,7 +20,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 from transformers.utils import cached_property
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -58,7 +58,7 @@ class TFT5ModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_labels = None
         if self.use_labels:
@@ -38,7 +38,7 @@ from transformers.testing_utils import require_tensorflow_probability, require_t
 from transformers.utils import cached_property
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -158,7 +158,7 @@ class TFTapasModelTester:
         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = []
         for type_vocab_size in self.type_vocab_sizes:
@@ -1440,7 +1440,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
 def random_attention_mask(shape, rng=None, name=None, dtype=None):
     attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
     # make sure that at least one token is attended to for each batch
-    attn_mask = tf.concat([tf.constant(value=1, shape=(shape[0], 1), dtype=dtype), attn_mask[:, 1:]], axis=1)
+    attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
     return attn_mask
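Both the old and the new line guarantee a non-empty mask; the difference is which position gets pinned. The old code forced the first column to 1, while the new code keeps the random prefix and forces the last column to 1, which the commit message notes as a temporary accommodation for the TFCLIP test. A small, hypothetical before/after illustration, assuming `attn_mask` is an integer tensor of shape [batch_size, seq_length]:

```python
# Illustrative comparison of the old and new pinning (values are made up).
import tensorflow as tf

attn_mask = tf.constant([[0, 1, 0, 0], [0, 0, 0, 0]], dtype=tf.int32)

# Old behaviour: pin the FIRST position of every row to 1.
old = tf.concat(
    [tf.constant(value=1, shape=(attn_mask.shape[0], 1), dtype=tf.int32), attn_mask[:, 1:]],
    axis=1,
)

# New behaviour: keep the random prefix and pin the LAST position of every row to 1.
new = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=tf.int32)], axis=-1)

print(old.numpy())  # [[1 1 0 0] [1 0 0 0]]
print(new.numpy())  # [[0 1 0 1] [0 0 0 1]]
```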
@@ -20,7 +20,7 @@ from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -75,7 +75,7 @@ class TFXLMModelTester:
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)
         input_lengths = None
         if self.use_input_lengths:
@@ -22,7 +22,7 @@ from transformers import XLNetConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 if is_tf_available():
@@ -75,7 +75,7 @@ class TFXLNetModelTester:
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)
         input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
         perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)
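One detail from the XLM and XLNet hunks above: those testers build a float mask, so they pass dtype=tf.float32 to random_attention_mask, which forwards it to ids_tensor. A rough, self-contained sketch of the resulting behaviour, with tf.random.uniform standing in for the repo's ids_tensor helper and illustrative sizes:

```python
import tensorflow as tf

batch_size, seq_length = 2, 7

# Stand-in for ids_tensor([batch_size, seq_length], vocab_size=2, dtype=tf.float32).
random_bits = tf.cast(
    tf.random.uniform([batch_size, seq_length], maxval=2, dtype=tf.int32), tf.float32
)
# Same construction as the new random_attention_mask body, with dtype=tf.float32:
# pin the last column to 1.0 so every row attends to at least one token.
input_mask = tf.concat(
    [random_bits[:, :-1], tf.ones_like(random_bits[:, -1:], dtype=tf.float32)], axis=-1
)
assert input_mask.dtype == tf.float32
print(input_mask.numpy())
```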