Unverified Commit f4e17271 authored by Matt's avatar Matt Committed by GitHub
Browse files

Allows `KerasMetricCallback` to use XLA generation (#18265)

* Allows `KerasMetricCallback` to use XLA generation

* make fixup

* Slightly reword docstring
parent bbb62f29
......@@ -65,6 +65,15 @@ class KerasMetricCallback(Callback):
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether we should use `model.generate()` to get outputs for the model.
use_xla_generation (`bool`, *optional*, defaults to `False`):
If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`.
generate_kwargs (`dict`, *optional*):
Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
is `False`.
"""
......@@ -76,6 +85,8 @@ class KerasMetricCallback(Callback):
label_cols: Optional[List[str]] = None,
batch_size: Optional[int] = None,
predict_with_generate: Optional[bool] = False,
use_xla_generation: Optional[bool] = False,
generate_kwargs: Optional[dict] = None,
):
super().__init__()
self.metric_fn = metric_fn
......@@ -126,6 +137,11 @@ class KerasMetricCallback(Callback):
if parse(tf.__version__) < parse("2.7"):
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
self.use_xla_generation = use_xla_generation
self.generate_kwargs = generate_kwargs or {}
self.generation_function = None
@staticmethod
def _concatenate_batches(batches, padding_index=-100):
# If all batches are unidimensional or same length, do a simple concatenation
......@@ -183,6 +199,17 @@ class KerasMetricCallback(Callback):
else:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if self.use_xla_generation and self.generation_function is None:
def generation_function(inputs, attention_mask):
return self.model.generate(
inputs,
attention_mask=attention_mask,
**self.generate_kwargs,
)
self.generation_function = tf.function(generation_function, jit_compile=True)
prediction_list = []
label_list = []
......@@ -199,8 +226,10 @@ class KerasMetricCallback(Callback):
else:
generation_inputs = batch
attention_mask = None
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
if self.use_xla_generation:
predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.predict_on_batch(batch)
if isinstance(predictions, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment