Unverified Commit 071df6eb authored by Matt, committed by GitHub

Call _set_save_spec() when creating TF models (#19321)



* Add a build_from_serving_sig_and_dummies method and replace all calls like model(model.dummy_inputs) with it.

* make fixup

* Remove the overridden save() as this is no longer necessary

* Also call _set_save_spec(), the last missing piece

* Ensure we set the save spec when loading from config too

* Turn this whole thing into a one-line PR

* Turn this whole thing into a one-line PR

* Turn this whole thing into a one-line PR
Co-authored-by: Your Name <you@example.com>
parent c875a96e
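
For context, the change boils down to registering the model's `serving` signature as the Keras save spec at construction time. The sketch below is a rough illustration of that pattern, not the transformers implementation: it assumes a TensorFlow 2.x release whose built-in Keras still exposes the private `_set_save_spec()` helper, and `ToyTFModel`, its single Dense layer, and the `input_ids` signature are made-up stand-ins. Because the registered spec uses `None` dimensions, a later `model.save()` records dynamic input shapes rather than the fixed shapes of whatever dummy inputs the model was first traced with, which is what the one-line diff below accomplishes for TFPreTrainedModel.

import tensorflow as tf


class ToyTFModel(tf.keras.Model):
    # Hypothetical stand-in for a transformers TF model, not the real TFPreTrainedModel.
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    @tf.function(
        input_signature=[
            {"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids")}
        ]
    )
    def serving(self, inputs):
        # Serving entry point whose signature leaves batch and sequence length dynamic.
        return self.call(inputs)

    def call(self, inputs):
        return self.dense(tf.cast(inputs["input_ids"], tf.float32))


model = ToyTFModel()
# The idea behind this commit: register the flexible serving signature as the save spec
# right after construction, before the model is ever traced with fixed-shape dummy
# inputs, so SavedModel export keeps (None, None) input shapes.
model._set_save_spec(model.serving.input_signature[0])
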
@@ -1049,6 +1049,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
+        self._set_save_spec(self.serving.input_signature[0])
 
     def get_config(self):
         return self.config.to_dict()
@@ -1097,29 +1099,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         raise NotImplementedError
 
-    def save(
-        self,
-        filepath,
-        overwrite=True,
-        include_optimizer=True,
-        save_format=None,
-        signatures=None,
-        options=None,
-        save_traces=True,
-    ):
-        # Very simple wrapper that ensures we set the correct serving signature when saving
-        if signatures is None and hasattr(self, "serving"):
-            signatures = self.serving
-        super().save(
-            filepath,
-            overwrite=overwrite,
-            include_optimizer=include_optimizer,
-            save_format=save_format,
-            signatures=signatures,
-            options=options,
-            save_traces=save_traces,
-        )
-
     def get_input_embeddings(self) -> tf.keras.layers.Layer:
         """
         Returns the model's input embeddings layer.