Unverified Commit d913f4aa authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix to KerasMetricCallback when the model returns unstructured output (#21727)

* Stop doing dict-things to non-dict inputs

* Add a debug check

* Add a debug check

* Remove debug checks, looks good now!

* make fixup
parent 82e61f34
......@@ -231,10 +231,12 @@ class KerasMetricCallback(Callback):
# This converts any dict-subclass to a regular dict
# Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
predictions = dict(predictions)
if self.output_cols is not None:
predictions = {key: predictions[key] for key in self.output_cols}
else:
predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
if self.output_cols is not None:
predictions = {key: predictions[key] for key in self.output_cols}
else:
predictions = {
key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
}
prediction_list.append(predictions)
if not self.use_keras_label:
labels = {key: batch[key].numpy() for key in self.label_cols}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment