Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a541d974
Unverified
Commit
a541d974
authored
Aug 18, 2022
by
Joao Gante
Committed by
GitHub
Aug 18, 2022
Browse files
Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)
parent
0ea53822
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
1 deletion
+44
-1
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+23
-1
tests/generation/test_generation_flax_utils.py
tests/generation/test_generation_flax_utils.py
+21
-0
No files found.
src/transformers/generation_flax_utils.py
View file @
a541d974
...
...
@@ -15,9 +15,10 @@
# limitations under the License.
import
inspect
import
warnings
from
functools
import
partial
from
typing
import
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
...
...
@@ -160,6 +161,24 @@ class FlaxGenerationMixin:
"""
return
logits
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args
=
[]
model_args
=
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if
"kwargs"
in
model_args
:
model_args
|=
set
(
inspect
.
signature
(
self
.
__call__
).
parameters
)
for
key
,
value
in
model_kwargs
.
items
():
if
value
is
not
None
and
key
not
in
model_args
:
unused_model_args
.
append
(
key
)
if
unused_model_args
:
raise
ValueError
(
f
"The following `model_kwargs` are not used by the model:
{
unused_model_args
}
(note: typos in the"
" generate arguments will also show up in this list)"
)
def
generate
(
self
,
input_ids
:
jnp
.
ndarray
,
...
...
@@ -262,6 +281,9 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
bos_token_id
=
bos_token_id
if
bos_token_id
is
not
None
else
self
.
config
.
bos_token_id
pad_token_id
=
pad_token_id
if
pad_token_id
is
not
None
else
self
.
config
.
pad_token_id
...
...
tests/generation/test_generation_flax_utils.py
View file @
a541d974
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
random
import
unittest
import
numpy
as
np
...
...
@@ -26,6 +27,7 @@ if is_flax_available():
import
jax.numpy
as
jnp
from
jax
import
jit
from
transformers
import
AutoTokenizer
,
FlaxAutoModelForCausalLM
from
transformers.modeling_flax_pytorch_utils
import
load_flax_weights_in_pytorch_model
os
.
environ
[
"XLA_PYTHON_CLIENT_MEM_FRACTION"
]
=
"0.12"
# assumed parallelism: 8
...
...
@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs
=
jit_generate
(
input_ids
,
attention_mask
=
attention_mask
).
sequences
self
.
assertListEqual
(
generation_outputs
.
tolist
(),
jit_generation_outputs
.
tolist
())
@
require_flax
class
FlaxGenerationIntegrationTests
(
unittest
.
TestCase
):
def
test_validate_generation_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-bert"
)
model
=
FlaxAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-bert-flax-only"
)
encoder_input_str
=
"Hello world"
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
"np"
).
input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with
self
.
assertRaisesRegex
(
ValueError
,
"do_samples"
):
model
.
generate
(
input_ids
,
do_samples
=
True
)
# arbitrary arguments that will not be used anywhere are also not accepted
with
self
.
assertRaisesRegex
(
ValueError
,
"foo"
):
fake_model_kwargs
=
{
"foo"
:
"bar"
}
model
.
generate
(
input_ids
,
**
fake_model_kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment