Unverified Commit 6e016634 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: smaller TF serving test (#18840)

parent 563a8d58
......@@ -75,7 +75,7 @@ if is_tf_available():
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModel,
TFAutoModelForSeq2SeqLM,
TFAutoModelForCausalLM,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
......@@ -2180,8 +2180,8 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
def test_generate_tf_function_export(self):
test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
max_length = 8
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
max_length = 2
class DummyModel(tf.Module):
def __init__(self, model):
......@@ -2204,8 +2204,8 @@ class UtilsFunctionsTest(unittest.TestCase):
)
return {"sequences": outputs["sequences"]}
dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
dummy_input_ids = [[2, 0], [102, 103]]
dummy_attention_masks = [[1, 0], [1, 1]]
dummy_model = DummyModel(model=test_model)
with tempfile.TemporaryDirectory() as tmp_dir:
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment