Unverified commit 7a496100, authored by Matt, committed by GitHub

Wrap Keras methods to support BatchEncoding (#28734)

* Shim the Keras methods to support BatchEncoding

* Extract everything to a convert_batch_encoding function

* Convert BatchFeature too (thanks Amy)

* tf.keras -> keras
parent 721e2d94
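For context on what the change enables: tokenizer and feature-extractor outputs are dict-like BatchEncoding/BatchFeature objects, which Keras's input handling does not always accept as-is, so callers typically had to convert them with dict(...) before training. With the wrappers added below, those objects can be passed straight to the Keras entry points. A minimal sketch of the resulting workflow (the checkpoint name, example texts, and labels are illustrative, not taken from this commit):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# Tokenizer output is a BatchEncoding, not a plain dict.
batch = tokenizer(["a good movie", "a bad movie"], padding=True, return_tensors="tf")
labels = tf.constant([1, 0])

# No loss argument: the model falls back to its internal loss computation.
model.compile(optimizer="adam")

# The wrapped fit() converts the BatchEncoding to a plain dict before calling
# keras.Model.fit, so a manual dict(batch) is no longer required here.
model.fit(batch, labels, epochs=1)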
@@ -41,6 +41,7 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import GenerationConfig, TFGenerationMixin
 from .tf_utils import (
+    convert_batch_encoding,
     expand_1d,
     load_attributes_from_hdf5_group,
     save_attributes_to_hdf5_group,
@@ -1155,6 +1156,36 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
     def get_config(self):
         return self.config.to_dict()
 
+    @functools.wraps(keras.Model.fit)
+    def fit(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().fit(*args, **kwargs)
+
+    @functools.wraps(keras.Model.train_on_batch)
+    def train_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().train_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.test_on_batch)
+    def test_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().test_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.predict_on_batch)
+    def predict_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().predict_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.predict)
+    def predict(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().predict(*args, **kwargs)
+
+    @functools.wraps(keras.Model.evaluate)
+    def evaluate(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().evaluate(*args, **kwargs)
+
     @classmethod
     def from_config(cls, config, **kwargs):
         if isinstance(config, PretrainedConfig):
...
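A side note on the functools.wraps decorators above: wrapping each override with the corresponding keras.Model method copies the upstream name, docstring, and signature onto the thin shim, so help() and IDE tooltips keep showing the Keras documentation instead of a bare (*args, **kwargs). A minimal sketch of that effect outside transformers (ToyModel is an illustrative stand-in, and tf.keras is assumed as the keras alias):

import functools
import inspect

from tensorflow import keras


class ToyModel(keras.Model):
    @functools.wraps(keras.Model.fit)
    def fit(self, *args, **kwargs):
        # A real shim would preprocess the inputs here (transformers converts
        # BatchEncoding/BatchFeature to plain dicts), then defer to Keras.
        return super().fit(*args, **kwargs)


# The wrapper reports the original Keras signature and docstring.
print(inspect.signature(ToyModel.fit))
print(ToyModel.fit.__doc__.strip().splitlines()[0])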
@@ -17,6 +17,8 @@ from typing import List, Optional, Union
 import numpy as np
 import tensorflow as tf
 
+from .feature_extraction_utils import BatchFeature
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging
@@ -253,3 +255,13 @@ def expand_1d(data):
         return t
 
     return tf.nest.map_structure(_expand_single_1d_tensor, data)
+
+
+def convert_batch_encoding(*args, **kwargs):
+    # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands
+    if args and isinstance(args[0], (BatchEncoding, BatchFeature)):
+        args = list(args)
+        args[0] = dict(args[0])
+    elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)):
+        kwargs["x"] = dict(kwargs["x"])
+    return args, kwargs
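A quick way to exercise the helper on its own (a hedged sketch: it assumes BatchEncoding is importable from transformers and convert_batch_encoding from transformers.tf_utils, which holds once a transformers version containing this commit is installed):

import tensorflow as tf
from transformers import BatchEncoding
from transformers.tf_utils import convert_batch_encoding

# Build a BatchEncoding by hand; normally it comes from a tokenizer call.
enc = BatchEncoding(
    {"input_ids": tf.constant([[101, 102]]), "attention_mask": tf.constant([[1, 1]])}
)

# Positional form: the first argument is converted to a plain dict.
args, kwargs = convert_batch_encoding(enc, tf.constant([0]))
assert type(args[0]) is dict

# Keyword form: the `x` argument is converted the same way.
args, kwargs = convert_batch_encoding(x=enc, y=tf.constant([0]))
assert type(kwargs["x"]) is dict

# Plain dicts (and everything else) pass through untouched.
args, kwargs = convert_batch_encoding({"input_ids": tf.constant([[101]])})
assert type(args[0]) is dict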