Unverified Commit 3b5a56e5 authored by Matěj Kripner's avatar Matěj Kripner Committed by GitHub
Browse files

Fix `KerasMetricCallback`: pass `generate_kwargs` even if `use_xla_generation` is False (#24333)

* Fix `KerasMetricCallback`: always pass `generate_kwargs`.

* Reformat code using Black.
parent 0b259a3b
......@@ -224,7 +224,9 @@ class KerasMetricCallback(Callback):
if self.use_xla_generation:
predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
predictions = self.model.generate(
generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
)
else:
predictions = self.model.predict_on_batch(batch)
if isinstance(predictions, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment