chenpangpang / transformers · Commits

Unverified commit 4cf38148, authored Nov 21, 2022 by Joao Gante, committed by GitHub on Nov 21, 2022.
Generate: `model_kwargs` can also be an input to `prepare_inputs_for_generation` (#20353)
Parent: d21c97cc

Showing 4 changed files with 15 additions and 11 deletions:
- src/transformers/generation/flax_utils.py (+3 −3)
- src/transformers/generation/tf_utils.py (+3 −3)
- src/transformers/generation/utils.py (+3 −3)
- tests/generation/test_utils.py (+6 −2)
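What the change does: before running generation, `generate()` validates each entry of `model_kwargs` against the parameters of `prepare_inputs_for_generation`; if that method exposes a catch-all argument, the check is relaxed to also accept anything the model's forward pass takes. Previously only a parameter literally named `kwargs` triggered the relaxation; after this commit, one named `model_kwargs` does too. Below is a minimal self-contained sketch of the updated check; the `ToyModel` class and its simplified error message are hypothetical, for illustration only, not the library's exact code:

```python
import inspect


class ToyModel:
    # Hypothetical stand-in mirroring the validation pattern in GenerationMixin.
    def forward(self, input_ids, attention_mask=None):
        pass

    # A catch-all named `model_kwargs` now triggers the relaxed check,
    # exactly like one named `kwargs` did before this commit.
    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        pass

    def _validate_model_kwargs(self, model_kwargs):
        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        if "kwargs" in model_args or "model_kwargs" in model_args:
            # The catch-all may forward extras to the forward pass,
            # so accept the forward pass's arguments as well.
            model_args |= set(inspect.signature(self.forward).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
        if unused_model_args:
            raise ValueError(f"Unused model kwargs: {unused_model_args}")


model = ToyModel()
model._validate_model_kwargs({"attention_mask": [[1, 1]]})  # passes: forward() accepts it
# model._validate_model_kwargs({"typo_arg": 1})             # would raise ValueError
```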
src/transformers/generation/flax_utils.py

@@ -194,9 +194,9 @@ class FlaxGenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.__call__).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
src/transformers/generation/tf_utils.py

@@ -1445,9 +1445,9 @@ class TFGenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.call).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
src/transformers/generation/utils.py

@@ -981,9 +981,9 @@ class GenerationMixin:
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.forward).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
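The same three-line change is replicated across the three framework backends; the only difference between the hunks is where the forward-pass signature is taken from: `self.__call__` in Flax, `self.call` in TensorFlow, and `self.forward` in PyTorch, matching each framework's module-call convention.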
tests/generation/test_utils.py

@@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue(max_score_diff < 1e-5)

     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")

         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "foo"):
             fake_model_kwargs = {"foo": "bar"}
             model.generate(input_ids, **fake_model_kwargs)
+
+        # However, valid model_kwargs are accepted
+        valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
+        model.generate(input_ids, **valid_model_kwargs)
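For context, here is how the validated behavior surfaces through the public API, mirroring the updated test above. This is a hedged sketch, assuming `torch` and `transformers` are installed; the checkpoint name is the tiny test model used in the diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# A typo'd generate argument is flagged during model_kwargs validation.
try:
    model.generate(input_ids, foo="bar")
except ValueError as exc:
    print(exc)  # the message names the unused kwarg, "foo"

# Valid model kwargs (accepted by the forward pass) go through unchanged.
model.generate(input_ids, attention_mask=torch.zeros_like(input_ids))
```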