Unverified Commit a541d974 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)

parent 0ea53822
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# limitations under the License. # limitations under the License.
import inspect
import warnings import warnings
from functools import partial from functools import partial
from typing import Dict, Optional from typing import Any, Dict, Optional
import numpy as np import numpy as np
...@@ -160,6 +161,24 @@ class FlaxGenerationMixin: ...@@ -160,6 +161,24 @@ class FlaxGenerationMixin:
""" """
return logits return logits
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
def generate( def generate(
self, self,
input_ids: jnp.ndarray, input_ids: jnp.ndarray,
...@@ -262,6 +281,9 @@ class FlaxGenerationMixin: ...@@ -262,6 +281,9 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```""" ```"""
# Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())
# set init values # set init values
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import random import random
import unittest
import numpy as np import numpy as np
...@@ -26,6 +27,7 @@ if is_flax_available(): ...@@ -26,6 +27,7 @@ if is_flax_available():
import jax.numpy as jnp import jax.numpy as jnp
from jax import jit from jax import jit
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
...@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin: ...@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)
# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment