Unverified Commit 8d6de0b9 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: future proof our keras imports (#20317)

* future proof our tf code

* parse tf versions
parent b2c863a3
......@@ -216,7 +216,13 @@ def load_pytorch_state_dict_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load a pytorch state_dict in a TF 2.0 model."""
from tensorflow.python.keras import backend as K
import tensorflow as tf
from packaging.version import parse
if parse(tf.__version__) >= parse("2.11.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K
if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs
......
......@@ -30,13 +30,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
from packaging.version import parse
from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator
......@@ -68,6 +64,18 @@ from .utils import (
)
if parse(tf.__version__) >= parse("2.11.0"):
from keras import backend as K
from keras.engine import data_adapter
from keras.engine.keras_tensor import KerasTensor
from keras.saving.legacy import hdf5_format
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
if is_safetensors_available():
from safetensors import safe_open
from safetensors.tensorflow import load_file as safe_load_file
......@@ -2310,7 +2318,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
)
param_dset[:] = layer.numpy()
layers.append(layer_name.encode("utf8"))
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
hdf5_format.save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
if push_to_hub:
self._upload_modified_files(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment