Unverified Commit 2322eb8e authored by Matt's avatar Matt Committed by GitHub
Browse files

Update serving signatures and make sure we actually use them (#19034)

* Override save() to use the serving signature as the default

* Replace int32 with int64 in all our serving signatures

* Remember one very important line so as not to break every test at once

* Dtype fix for TFLED

* dtype fix for shift_tokens_right in general

* Dtype fixes in mBART and RAG

* Fix dtypes for test_unpack_inputs

* More dtype fixes

* Yet more mBART + RAG dtype fixes

* Yet more mBART + RAG dtype fixes

* Add a check that the model actually has a serving method
parent 9b80a0bc
......@@ -425,8 +425,8 @@ class TFOPTPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)
......
......@@ -65,11 +65,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
......
......@@ -1301,17 +1301,18 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
pad_token_id = self.generator.config.pad_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
shifted_input_ids = tf.cast(input_ids, tf.int32)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, :-1]], -1)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.cast(start_token_id, input_ids.dtype))
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.cast(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, shifted_input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
......@@ -1324,7 +1325,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
n_docs = n_docs if n_docs is not None else self.config.n_docs
# shift tokens left (from original Pytorch's version)
target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
target = tf.concat(
[target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
axis=1,
)
rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)
......@@ -1571,7 +1575,10 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
):
# shift tokens left
target = tf.concat([target[:, 1:], tf.fill([target.shape[0], 1], self.config.generator.pad_token_id)], axis=1)
target = tf.concat(
[target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
axis=1,
)
# bos_token_id is None for T5
bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
......@@ -1580,7 +1587,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
use_bos = bos_token_id is not None and equal_bos_token_id_all
def _mask_pads(ll, smooth_obj):
pad_mask = tf.equal(target, self.config.generator.pad_token_id)
pad_mask = tf.equal(target, tf.cast(self.config.generator.pad_token_id, target.dtype))
if tf.reduce_any(pad_mask):
ll = tf.where(pad_mask, 0.0, ll)
smooth_obj = tf.where(pad_mask, 0.0, smooth_obj)
......@@ -1611,7 +1618,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
def torch_gather(param, id_tensor):
# 2d-gather torch equivalent: https://stackoverflow.com/questions/52129909/tensorflow-equivalent-of-torch-gather
def gather2d(target, id_tensor):
idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1)
idx = tf.stack([tf.range(tf.shape(id_tensor)[0], dtype=id_tensor.dtype), id_tensor[:, 0]], axis=-1)
result = tf.gather_nd(target, idx)
return tf.expand_dims(result, axis=-1)
......
......@@ -1435,9 +1435,9 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)
......
......@@ -798,8 +798,8 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)
......
......@@ -1211,9 +1211,9 @@ class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLos
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)
......
......@@ -67,11 +67,15 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill(
(shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
......@@ -591,9 +595,9 @@ class TFSpeech2TextPreTrainedModel(TFPreTrainedModel):
input_signature=[
{
"input_features": tf.TensorSpec((None, None, None), tf.float32, name="input_features"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)
......
......@@ -872,10 +872,10 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"),
}
]
)
......
......@@ -865,9 +865,9 @@ class TFTapasPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)
......
......@@ -686,7 +686,7 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
}
]
)
......
......@@ -1345,8 +1345,8 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int64, name="token_type_ids"),
}
]
)
......
......@@ -636,8 +636,8 @@ class TFXGLMPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"),
}
]
)
......
......@@ -1563,9 +1563,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None, None), tf.int64, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int64, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int64, name="token_type_ids"),
}
]
)
......
......@@ -1685,16 +1685,21 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype))
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
shifted_input_ids == -100,
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
......
......@@ -1887,9 +1887,9 @@ class UtilsFunctionsTest(unittest.TestCase):
return pixel_values, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3])
past = tf.constant([4, 5, 6, 7])
pixel_values = tf.constant([8, 9, 10, 11])
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int64)
past = tf.constant([4, 5, 6, 7], dtype=tf.int64)
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int64)
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past=past)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment