Unverified Commit 7df13432 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Change `torch_dtype` to `str` when `saved_model=True` in `save_pretrained` for TF models (#22740)



* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8eb38f63
......@@ -2313,6 +2313,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
files_timestamps = self._get_files_timestamps(save_directory)
if saved_model:
# If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string.
# (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.)
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
if signatures is None:
if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()):
int64_spec = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment