Unverified Commit 68097dcc authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix Sylvain's nits on the original KerasMetricCallback PR (#18300)



* Fix Sylvain's nits on the original PR

* Update src/transformers/keras_callbacks.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Re-add "optional" to docstring
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 66491331
...@@ -84,8 +84,8 @@ class KerasMetricCallback(Callback): ...@@ -84,8 +84,8 @@ class KerasMetricCallback(Callback):
output_cols: Optional[List[str]] = None, output_cols: Optional[List[str]] = None,
label_cols: Optional[List[str]] = None, label_cols: Optional[List[str]] = None,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
predict_with_generate: Optional[bool] = False, predict_with_generate: bool = False,
use_xla_generation: Optional[bool] = False, use_xla_generation: bool = False,
generate_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None,
): ):
super().__init__() super().__init__()
...@@ -138,7 +138,7 @@ class KerasMetricCallback(Callback): ...@@ -138,7 +138,7 @@ class KerasMetricCallback(Callback):
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!") logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
self.use_xla_generation = use_xla_generation self.use_xla_generation = use_xla_generation
self.generate_kwargs = generate_kwargs or {} self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
self.generation_function = None self.generation_function = None
...@@ -202,11 +202,7 @@ class KerasMetricCallback(Callback): ...@@ -202,11 +202,7 @@ class KerasMetricCallback(Callback):
if self.use_xla_generation and self.generation_function is None: if self.use_xla_generation and self.generation_function is None:
def generation_function(inputs, attention_mask): def generation_function(inputs, attention_mask):
return self.model.generate( return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)
inputs,
attention_mask=attention_mask,
**self.generate_kwargs,
)
self.generation_function = tf.function(generation_function, jit_compile=True) self.generation_function = tf.function(generation_function, jit_compile=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment