Unverified Commit df154258 authored by Matt's avatar Matt Committed by GitHub
Browse files

Set env var to hold Keras at Keras 2 (#29598)

* Set env var to hold Keras at Keras 2

* Add Amy's update

* make fixup

* Use a warning instead
parent b6404866
...@@ -78,6 +78,16 @@ if is_safetensors_available(): ...@@ -78,6 +78,16 @@ if is_safetensors_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from . import PreTrainedTokenizerBase from . import PreTrainedTokenizerBase
logger = logging.get_logger(__name__)
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
logger.warning(
"Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
"This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
)
try: try:
import tf_keras as keras import tf_keras as keras
from tf_keras import backend as K from tf_keras import backend as K
...@@ -93,7 +103,6 @@ except (ModuleNotFoundError, ImportError): ...@@ -93,7 +103,6 @@ except (ModuleNotFoundError, ImportError):
) )
logger = logging.get_logger(__name__)
tf_logger = tf.get_logger() tf_logger = tf.get_logger()
TFModelInputType = Union[ TFModelInputType = Union[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment