Fix KerasMetricCallback prediction with generate() and inference of column names (#15351)
* Fix prediction with generate() and the inference of column names Should now have very few differences with the PyTorch implementation * Minor edit to parent class * Update src/transformers/keras_callbacks.py Co-authored-by:Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Explaining the dict conversion * Putting main_input_name back * Fixes to main_input_name Co-authored-by:
Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Showing
Please register or sign in to comment