Unverified Commit df8faba4 authored by Dimitre Oliveira's avatar Dimitre Oliveira Committed by GitHub
Browse files

Enabling custom TF signature draft (#19249)



* Custom TF signature draft

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Adding tf signature tests

* Fixing signature check and adding asserts

* fixing model load path

* Adjusting signature tests

* Formatting file
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarDimitre Oliveira <dimitreoliveira@Dimitres-MacBook-Air.local>
parent 10100979
......@@ -2097,6 +2097,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
saved_model=False,
version=1,
push_to_hub=False,
signatures=None,
max_shard_size: Union[int, str] = "10GB",
create_pr: bool = False,
**kwargs
......@@ -2118,6 +2119,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
signatures (`dict` or `tf.function`, *optional*):
Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
......@@ -2148,8 +2151,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
files_timestamps = self._get_files_timestamps(save_directory)
if saved_model:
if signatures is None:
signatures = self.serving
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
self.save(saved_model_dir, include_optimizer=False, signatures=self.serving)
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
logger.info(f"Saved model created in {saved_model_dir}")
# Save configuration file
......
......@@ -2216,6 +2216,46 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Short custom TF signature function.
# `input_signature` is specific to BERT.
@tf.function(
input_signature=[
[
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
]
]
)
def serving_fn(input):
return model(input)
# Using default signature (default behavior) overrides 'serving_default'
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
# Providing custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
# Providing multiple custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
saved_model=True,
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
@require_tf
@is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment