Unverified Commit de4159a3 authored by Matt's avatar Matt Committed by GitHub
Browse files

More TF int dtype fixes (#20384)

* Add a test to ensure int dummy inputs are int64

* Move the test into the existing int64 test and update a lot of existing dummies

* Fix remaining dummies

* Fix remaining dummies

* Test for int64 serving sigs as well

* Update core tests to use tf.int64

* Add better messages to the assertions

* Update all serving sigs to int64

* More sneaky hiding tf.int32s

* Add an optional int32 signature in save_pretrained

* make fixup

* Add Amy's suggestions

* Switch all serving sigs back to tf.int32

* Switch all dummies to tf.int32

* Adjust tests to check for tf.int32 instead of tf.int64

* Fix base dummy_inputs dtype

* Start casting to tf.int32 in input_processing

* Change dtype for unpack_inputs test

* Add proper tf.int32 test

* Make the alternate serving signature int64
parent 72b19ca6
......@@ -474,15 +474,15 @@ class TFWhisperPreTrainedModel(TFPreTrainedModel):
self.main_input_name: tf.random.uniform(
[2, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
),
"decoder_input_ids": tf.constant([[2, 3]], dtype=tf.int64),
"decoder_input_ids": tf.constant([[2, 3]], dtype=tf.int32),
}
@tf.function(
input_signature=[
{
"input_features": tf.TensorSpec((None, None, None), tf.float32, name="input_features"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}
]
)
......
......@@ -639,15 +639,15 @@ class TFXGLMPreTrainedModel(TFPreTrainedModel):
input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
dummy_inputs = {
"input_ids": input_ids,
"attention_mask": tf.math.not_equal(input_ids, pad_token),
"attention_mask": tf.cast(input_ids != pad_token, tf.int32),
}
return dummy_inputs
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
}
]
)
......
......@@ -533,13 +533,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]], dtype=tf.int32)
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32)
if self.config.use_lang_emb and self.config.n_langs > 1:
return {
"input_ids": inputs_list,
"attention_mask": attns_list,
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]], dtype=tf.int32),
}
else:
return {"input_ids": inputs_list, "attention_mask": attns_list}
......@@ -1006,12 +1006,12 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
if self.config.use_lang_emb and self.config.n_langs > 1:
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
}
else:
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32),
}
@unpack_inputs
......
......@@ -1486,7 +1486,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32)}
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
......@@ -1573,9 +1573,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}
]
)
......
......@@ -821,7 +821,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
dummy = {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int64)}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
......@@ -1365,7 +1365,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int64)}
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
......
......@@ -1643,7 +1643,7 @@ class TFModelTesterMixin:
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
def test_int64_inputs(self):
def test_int_support(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
prepared_for_class = self._prepare_for_class(
......@@ -1662,6 +1662,26 @@ class TFModelTesterMixin:
}
model = model_class(config)
model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error
int32_prepared_for_class = {
key: tf.cast(tensor, tf.int32) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
for key, tensor in prepared_for_class.items()
}
model(**int32_prepared_for_class) # No assertion, we're just checking this doesn't throw an error
# After testing that the model accepts all int inputs, confirm that its dummies are int32
for key, tensor in model.dummy_inputs.items():
self.assertTrue(isinstance(tensor, tf.Tensor), "Dummy inputs should be tf.Tensor!")
if tensor.dtype.is_integer:
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")
# Also confirm that the serving sig uses int32
if hasattr(model, "serving"):
serving_sig = model.serving.input_signature
for key, tensor_spec in serving_sig[0].items():
if tensor_spec.dtype.is_integer:
self.assertTrue(
tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
)
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
......@@ -2005,9 +2025,9 @@ class UtilsFunctionsTest(unittest.TestCase):
return pixel_values, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int64)
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int64)
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int64)
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
......
......@@ -218,6 +218,11 @@ class TFCoreModelTesterMixin:
model = model_class(config)
num_out = len(model(class_inputs_dict))
for key in class_inputs_dict.keys():
# Check it's a tensor, in case the inputs dict has some bools in it too
if isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment