Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a541d974
Unverified
Commit
a541d974
authored
Aug 18, 2022
by
Joao Gante
Committed by
GitHub
Aug 18, 2022
Browse files
Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)
parent
0ea53822
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
1 deletion
+44
-1
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+23
-1
tests/generation/test_generation_flax_utils.py
tests/generation/test_generation_flax_utils.py
+21
-0
No files found.
src/transformers/generation_flax_utils.py
View file @
a541d974
...
...
@@ -15,9 +15,10 @@
# limitations under the License.
import
inspect
import
warnings
from
functools
import
partial
from
typing
import
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
...
...
@@ -160,6 +161,24 @@ class FlaxGenerationMixin:
"""
return
logits
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args
=
[]
model_args
=
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if
"kwargs"
in
model_args
:
model_args
|=
set
(
inspect
.
signature
(
self
.
__call__
).
parameters
)
for
key
,
value
in
model_kwargs
.
items
():
if
value
is
not
None
and
key
not
in
model_args
:
unused_model_args
.
append
(
key
)
if
unused_model_args
:
raise
ValueError
(
f
"The following `model_kwargs` are not used by the model:
{
unused_model_args
}
(note: typos in the"
" generate arguments will also show up in this list)"
)
def
generate
(
self
,
input_ids
:
jnp
.
ndarray
,
...
...
@@ -262,6 +281,9 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
config
.
bos_token_id
pad_token_id
=
pad_token_id
if
pad_token_id
is
not
None
else
self
.
config
.
pad_token_id
...
...
tests/generation/test_generation_flax_utils.py
View file @
a541d974
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
random
import
unittest
import
numpy
as
np
...
...
@@ -26,6 +27,7 @@ if is_flax_available():
import
jax.numpy
as
jnp
from
jax
import
jit
from
transformers
import
AutoTokenizer
,
FlaxAutoModelForCausalLM
from
transformers.modeling_flax_pytorch_utils
import
load_flax_weights_in_pytorch_model
os
.
environ
[
"XLA_PYTHON_CLIENT_MEM_FRACTION"
]
=
"0.12"
# assumed parallelism: 8
...
...
@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs
=
jit_generate
(
input_ids
,
attention_mask
=
attention_mask
).
sequences
self
.
assertListEqual
(
generation_outputs
.
tolist
(),
jit_generation_outputs
.
tolist
())
@
require_flax
class
FlaxGenerationIntegrationTests
(
unittest
.
TestCase
):
def
test_validate_generation_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-bert"
)
model
=
FlaxAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-bert-flax-only"
)
encoder_input_str
=
"Hello world"
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
"np"
).
input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with
self
.
assertRaisesRegex
(
ValueError
,
"do_samples"
):
model
.
generate
(
input_ids
,
do_samples
=
True
)
# arbitrary arguments that will not be used anywhere are also not accepted
with
self
.
assertRaisesRegex
(
ValueError
,
"foo"
):
fake_model_kwargs
=
{
"foo"
:
"bar"
}
model
.
generate
(
input_ids
,
**
fake_model_kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment