Unverified Commit 6beae766 authored by Matt, committed by GitHub

Fix KerasMetricCallback prediction with generate() and inference of column names (#15351)



* Fix prediction with generate() and the inference of column names
Should now have very few differences with the PyTorch implementation

* Minor edit to parent class

* Update src/transformers/keras_callbacks.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Explaining the dict conversion

* Putting main_input_name back

* Fixes to main_input_name
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent da5ef25d
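
For reference, a minimal sketch of how the fixed callback is typically wired up. The `metric_fn`, `train_dataset`, `validation_dataset`, and compiled seq2seq `model` names here are assumptions, not part of this commit:

```python
from transformers.keras_callbacks import KerasMetricCallback

# Hypothetical setup: `metric_fn`, `train_dataset`, `validation_dataset`, and a
# compiled seq2seq `model` are assumed to be defined elsewhere.
# predict_with_generate=True routes evaluation through model.generate(),
# which is the code path this commit fixes.
metric_callback = KerasMetricCallback(
    metric_fn=metric_fn,
    eval_dataset=validation_dataset,
    predict_with_generate=True,
)
model.fit(train_dataset, epochs=3, callbacks=[metric_callback])
```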
@@ -56,8 +56,6 @@ class KerasMetricCallback(Callback):
             metric names to numerical values.
         eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
             Validation data to be used to generate predictions for the `metric_fn`.
-        metric_fn_kwargs (`dict`, *optional*):
-            Additional keyword arguments to be passed to the metric_fn.
         output_cols (`List[str]`, *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
         label_cols (`List[str]`, *optional*):
@@ -74,7 +72,6 @@ class KerasMetricCallback(Callback):
         self,
         metric_fn: Callable,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
-        metric_fn_kwargs: Optional[dict] = None,
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
         batch_size: Optional[int] = None,
@@ -94,12 +91,6 @@ class KerasMetricCallback(Callback):
         self.eval_dataset = eval_dataset
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
-        self.metric_fn_kwargs = metric_fn_kwargs or dict()
-        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
-            self.main_input_name = self.model.encoder.main_input_name
-        else:
-            self.main_input_name = self.model.main_input_name
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn
@@ -123,32 +114,75 @@ class KerasMetricCallback(Callback):
             self.label_cols = ["labels"]
             self.use_keras_label = False
             logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
+        elif "start_positions" in input_spec and "end_positions" in input_spec:
+            self.label_cols = ["start_positions", "end_positions"]
+            self.use_keras_label = False
+            logging.warning(
+                "No label_cols specified for KerasMetricCallback, assuming you want the "
+                "start_positions and end_positions keys."
+            )
         else:
             raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
-        if parse(tf.__version__).minor < parse("2.7"):
+        if parse(tf.__version__) < parse("2.7"):
             logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
     @staticmethod
-    def _concatenate_batches(batches):
-        # Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
-        return [sample for batch in batches for sample in batch]
+    def _concatenate_batches(batches, padding_index=-100):
+        # If all batches are unidimensional or the same length, do a simple concatenation
+        if batches[0].ndim == 1 or all([batch.shape[1] == batches[0].shape[1] for batch in batches]):
+            return np.concatenate(batches, axis=0)
+
+        # Welp, they're not the same length. Let's do some padding
+        max_len = max([batch.shape[1] for batch in batches])
+        num_samples = sum([batch.shape[0] for batch in batches])
+        output = np.full_like(
+            batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
+        )
+        # i keeps track of which part of the concatenated array we're writing the next batch to
+        i = 0
+        for batch in batches:
+            output[i : i + len(batch), : batch.shape[1]] = batch
+            i += len(batch)
+        return output
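
As a rough illustration of the new padding path, here is the same logic run standalone on two batches of invented token IDs with different sequence lengths:

```python
import numpy as np

# Two hypothetical batches of generated token IDs: generate() can return
# different sequence lengths per batch, so simple concatenation would fail.
batch_a = np.array([[101, 7592, 102], [101, 2088, 102]])  # shape (2, 3)
batch_b = np.array([[101, 7592, 2088, 102]])              # shape (1, 4)
batches = [batch_a, batch_b]

# Pad everything to the longest batch, filling gaps with padding_index (-100)
max_len = max(batch.shape[1] for batch in batches)      # 4
num_samples = sum(batch.shape[0] for batch in batches)  # 3
output = np.full((num_samples, max_len), fill_value=-100, dtype=batch_a.dtype)
i = 0
for batch in batches:
    output[i : i + len(batch), : batch.shape[1]] = batch
    i += len(batch)

print(output)
# [[ 101 7592  102 -100]
#  [ 101 2088  102 -100]
#  [ 101 7592 2088  102]]
```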
     def _postprocess_predictions_or_labels(self, inputs):
         if isinstance(inputs[0], dict):
             outputs = dict()
             for key in inputs[0].keys():
-                outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
+                outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
+            # If it's a dict with only one key, just return the array
+            if len(outputs) == 1:
+                outputs = list(outputs.values())[0]
         elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
             outputs = []
             for input_list in zip(*inputs):
                 outputs.append(self._concatenate_batches(input_list))
+            # If it's a list with only one element, just return the array
             if len(outputs) == 1:
                 outputs = outputs[0]
         elif isinstance(inputs[0], np.ndarray):
             outputs = self._concatenate_batches(inputs)
+        elif isinstance(inputs[0], tf.Tensor):
+            outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
         else:
             raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
         return outputs
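
A quick sketch (with invented arrays) of the new single-key shortcut: a dict output with only one key now collapses to the bare concatenated array instead of staying wrapped in a dict:

```python
import numpy as np

# Hypothetical per-batch model outputs, each a dict with a single "logits" key
batches = [{"logits": np.ones((2, 5))}, {"logits": np.zeros((3, 5))}]

outputs = {key: np.concatenate([batch[key] for batch in batches]) for key in batches[0]}
if len(outputs) == 1:
    outputs = list(outputs.values())[0]
print(type(outputs), outputs.shape)  # <class 'numpy.ndarray'> (5, 5)
```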
     def on_epoch_end(self, epoch, logs=None):
+        if hasattr(self.model, "config"):
+            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+        else:
+            ignore_keys = []
+        main_input_name = None
+        if self.predict_with_generate:
+            # This dense conditional recognizes the case where we have an encoder-decoder model, but
+            # avoids getting tangled up when we just have a model with a layer called 'encoder'
+            if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
+                if self.model.encoder.main_input_name != self.model.main_input_name:
+                    main_input_name = self.model.encoder.main_input_name
+            else:
+                main_input_name = getattr(self.model, "main_input_name", "input_ids")
+
         prediction_list = []
         label_list = []
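
The `main_input_name` lookup is now computed per epoch rather than cached in `__init__`. A hypothetical standalone illustration (the dummy classes are invented for this sketch): for a speech-style encoder-decoder whose encoder consumes something other than `input_ids`, the encoder's input name wins.

```python
# Invented stand-ins for a model whose encoder takes a different main input
class DummyEncoder:
    main_input_name = "input_values"

class DummySpeechModel:
    main_input_name = "input_ids"
    encoder = DummyEncoder()

model = DummySpeechModel()
main_input_name = None
if hasattr(model, "encoder") and hasattr(model.encoder, "main_input_name"):
    if model.encoder.main_input_name != model.main_input_name:
        main_input_name = model.encoder.main_input_name
else:
    main_input_name = getattr(model, "main_input_name", "input_ids")
print(main_input_name)  # input_values
```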
@@ -160,7 +194,7 @@ class KerasMetricCallback(Callback):
             labels = None
             if self.predict_with_generate:
                 if isinstance(batch, dict):
-                    generation_inputs = batch[self.main_input_name]
+                    generation_inputs = batch[main_input_name]
                     attention_mask = batch.get("attention_mask", None)
                 else:
                     generation_inputs = batch
@@ -169,9 +203,14 @@ class KerasMetricCallback(Callback):
                 predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
             else:
                 predictions = self.model.predict(batch)
-            predictions = dict(predictions)
-            if self.output_cols is not None:
-                predictions = {key: predictions[key] for key in self.output_cols}
+            if isinstance(predictions, dict):
+                # This converts any dict-subclass to a regular dict
+                # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
+                predictions = dict(predictions)
+                if self.output_cols is not None:
+                    predictions = {key: predictions[key] for key in self.output_cols}
+                else:
+                    predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
             prediction_list.append(predictions)
             if not self.use_keras_label:
                 labels = {key: batch[key].numpy() for key in self.label_cols}
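
This is the "inference of column names" half of the fix: when `output_cols` is not given, the callback keeps every output key except the loss and anything the config marks as ignorable. A small sketch with invented keys:

```python
# Hypothetical model outputs; the ignore_keys value mimics what
# model.config.keys_to_ignore_at_inference might contain
predictions = {"logits": ..., "past_key_values": ..., "loss": ...}
ignore_keys = ["past_key_values"]

# Without output_cols, drop the loss and the ignorable keys
predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
print(list(predictions))  # ['logits']
```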
@@ -185,10 +224,10 @@ class KerasMetricCallback(Callback):
                 raise TypeError(f"Confused by labels of type {type(labels)}")
             label_list.append(labels)

-        prediction_list = self._postprocess_predictions_or_labels(prediction_list)
-        label_list = self._postprocess_predictions_or_labels(label_list)
-        metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
+        all_preds = self._postprocess_predictions_or_labels(prediction_list)
+        all_labels = self._postprocess_predictions_or_labels(label_list)
+        metric_output = self.metric_fn((all_preds, all_labels))
         if not isinstance(metric_output, dict):
             raise TypeError(
                 f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
...
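
With `metric_fn_kwargs` gone, `metric_fn` now receives a single `(predictions, labels)` tuple, matching the `EvalPrediction`-style convention on the PyTorch side that the commit message refers to. A minimal sketch of a compatible metric function, assuming logit-style predictions and `-100` as the label padding index:

```python
import numpy as np

def accuracy_metric_fn(eval_predictions):
    # The callback passes one (predictions, labels) tuple and expects a dict
    # mapping metric names to scalar values in return
    predictions, labels = eval_predictions
    preds = np.argmax(predictions, axis=-1)
    # -100 marks padded positions, so exclude them from the score
    mask = labels != -100
    return {"accuracy": float((preds[mask] == labels[mask]).mean())}
```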