Unverified Commit 8d6de0b9 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: future proof our keras imports (#20317)

* future proof our tf code

* parse tf versions
parent b2c863a3
...@@ -216,6 +216,12 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -216,6 +216,12 @@ def load_pytorch_state_dict_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
): ):
"""Load a pytorch state_dict in a TF 2.0 model.""" """Load a pytorch state_dict in a TF 2.0 model."""
import tensorflow as tf
from packaging.version import parse
if parse(tf.__version__) >= parse("2.11.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
if tf_inputs is None: if tf_inputs is None:
......
...@@ -30,13 +30,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union ...@@ -30,13 +30,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import h5py import h5py
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend as K from packaging.version import parse
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator from . import DataCollatorWithPadding, DefaultDataCollator
...@@ -68,6 +64,18 @@ from .utils import ( ...@@ -68,6 +64,18 @@ from .utils import (
) )
if parse(tf.__version__) >= parse("2.11.0"):
from keras import backend as K
from keras.engine import data_adapter
from keras.engine.keras_tensor import KerasTensor
from keras.saving.legacy import hdf5_format
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
if is_safetensors_available(): if is_safetensors_available():
from safetensors import safe_open from safetensors import safe_open
from safetensors.tensorflow import load_file as safe_load_file from safetensors.tensorflow import load_file as safe_load_file
...@@ -2310,7 +2318,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2310,7 +2318,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
) )
param_dset[:] = layer.numpy() param_dset[:] = layer.numpy()
layers.append(layer_name.encode("utf8")) layers.append(layer_name.encode("utf8"))
save_attributes_to_hdf5_group(shard_file, "layer_names", layers) hdf5_format.save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment