Unverified commit f04257fd authored by Matt, committed by GitHub

Add test to ensure models can take int64 inputs (#17210)

* Add test to ensure models can take int64 inputs

* is_integer is an attribute, not a method

* Fix test when some inputs aren't tensors

* Add casts to blenderbot and blenderbot-small

* Add casts to the other failing models
parent 5294fa12
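
The recurring fix in the seq2seq models below is the same one-liner: `tf.where(cond, x, y)` requires `x` and `y` to share a dtype, and `tf.fill(shape, -100)` defaults to int32, so int64 `labels` used to trigger a dtype error. A minimal standalone sketch of the failure and the cast, with made-up values; `tf.shape` stands in for the repo's `shape_list` helper to keep it self-contained:

```python
import tensorflow as tf

labels = tf.constant([[1, 2, 0]], dtype=tf.int64)  # int64 labels, as the new test feeds in
pad_token_id = 0

# Before the fix: tf.fill defaults to int32, and tf.where refuses to mix dtypes.
# tf.where(labels == pad_token_id, tf.fill(tf.shape(labels), -100), labels)  # InvalidArgumentError

# After the fix: cast the fill tensor to whatever dtype the labels arrived in.
masked = tf.where(
    labels == pad_token_id,
    tf.cast(tf.fill(tf.shape(labels), -100), labels.dtype),
    labels,
)
print(masked.numpy())  # [[   1    2 -100]]
```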
@@ -1287,7 +1287,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
...
@@ -1265,7 +1265,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
...
@@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
...
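
The `get_masks` change (here, and again in the XLM copy of the function further down) keeps both comparison operands on one dtype: TensorFlow binary ops such as `<` refuse to mix int32 and int64. A small sketch of what the fixed lines compute, with illustrative lengths:

```python
import tensorflow as tf

slen = 5
lengths = tf.constant([3, 5], dtype=tf.int64)  # per-sequence lengths, possibly int64 now

# tf.range must be built with lengths' dtype, or the `<` comparison fails to typecheck.
alen = tf.range(slen, dtype=lengths.dtype)     # [0, 1, 2, 3, 4]
mask = alen < tf.expand_dims(lengths, axis=1)  # broadcasts to shape (2, 5), dtype bool
print(mask.numpy())
# [[ True  True  True False False]
#  [ True  True  True  True  True]]
```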
@@ -1300,7 +1300,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
...
@@ -1317,7 +1317,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
         if labels is not None:
             labels = tf.where(
                 labels == self.config.pad_token_id,
-                tf.fill(shape_list(labels), -100),
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                 labels,
             )
             use_cache = False
...
@@ -1726,7 +1726,10 @@ class ProductIndexMap(IndexMap):
             raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
         super(ProductIndexMap, self).__init__(
-            indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+            indices=(
+                inner_index.indices
+                + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+            ),
             num_segments=inner_index.num_segments * outer_index.num_segments,
             batch_dims=inner_index.batch_dims,
         )
@@ -1785,7 +1788,7 @@ def flatten(index, name="segmented_flatten"):
     for _ in range(index.batch_dims, index.indices.shape.rank):
         offset = tf.expand_dims(offset, -1)
-    indices = offset + index.indices
+    indices = tf.cast(offset, index.indices.dtype) + index.indices
     return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
...
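
Both TAPAS changes follow the same rule: an int32 offset or segment count cannot be combined with int64 indices without an explicit cast. A toy sketch of the `flatten` arithmetic, with made-up segment ids:

```python
import tensorflow as tf

indices = tf.constant([[0, 1], [1, 0]], dtype=tf.int64)  # per-batch segment ids, possibly int64
num_segments = 2
batch_size = 2

# Offset each batch so the flattened segment ids stay globally unique.
offset = tf.range(batch_size) * num_segments              # int32 by default
flat = tf.cast(offset[:, None], indices.dtype) + indices  # cast before mixing dtypes
print(tf.reshape(flat, [-1]).numpy())  # [0 1 3 2]
```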
@@ -111,7 +111,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
     @staticmethod
     def _gather_logprob(logprob, target):
         lp_size = shape_list(logprob)
-        r = tf.range(lp_size[0])
+        r = tf.range(lp_size[0], dtype=target.dtype)
         idx = tf.stack([r, target], 1)
         return tf.gather_nd(logprob, idx)
...
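
`tf.stack` likewise insists on a single dtype across its inputs, so the row range has to follow `target`'s dtype before being paired with it for `tf.gather_nd`. A standalone sketch with illustrative values:

```python
import tensorflow as tf

logprob = tf.math.log(tf.constant([[0.7, 0.3], [0.2, 0.8]]))
target = tf.constant([0, 1], dtype=tf.int64)  # one target column per row, now possibly int64

# tf.stack requires a single dtype, so the row range must match target's dtype.
r = tf.range(logprob.shape[0], dtype=target.dtype)
idx = tf.stack([r, target], axis=1)        # [[0, 0], [1, 1]]
print(tf.gather_nd(logprob, idx).numpy())  # [log(0.7), log(0.8)]
```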
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
         mask = padding_mask
     else:
         # assert lengths.max().item() <= slen
-        alen = tf.range(slen)
-        mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+        alen = tf.range(slen, dtype=lengths.dtype)
+        mask = alen < tf.expand_dims(lengths, axis=1)
     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
...
@@ -1372,6 +1372,26 @@ class TFModelTesterMixin:
         val_loss2 = history2.history["val_loss"][0]
         self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
 
+    def test_int64_inputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            prepared_for_class = self._prepare_for_class(
+                inputs_dict.copy(),
+                model_class,
+                return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
+            )
+            if not any(
+                [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
+            ):
+                return  # No integer inputs means no need for this test
+            prepared_for_class = {
+                key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
+                for key, tensor in prepared_for_class.items()
+            }
+            model = model_class(config)
+            model(**prepared_for_class)  # No assertion, we're just checking this doesn't throw an error
+
     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
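
In user-facing terms, the new test upcasts every integer input to int64 and calls each model once, asserting nothing beyond "no exception". A hypothetical equivalent outside the test suite (checkpoint name chosen only for illustration; this downloads weights):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("int64 inputs should just work", return_tensors="tf")
# Upcast every integer tensor to int64, mirroring what the test does.
inputs = {k: tf.cast(v, tf.int64) if v.dtype.is_integer else v for k, v in inputs.items()}
outputs = model(**inputs)  # the point of the test: this call must not raise a dtype error
```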