Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
20509ab0
Unverified
Commit
20509ab0
authored
Jul 13, 2022
by
Joao Gante
Committed by
GitHub
Jul 13, 2022
Browse files
TF: unpack_inputs decorator independent from main_input_name (#18110)
parent
fcefa200
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
16 deletions
+30
-16
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+16
-16
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+14
-0
No files found.
src/transformers/modeling_tf_utils.py
View file @
20509ab0
...
...
@@ -404,9 +404,7 @@ def unpack_inputs(func):
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
# process the inputs and call the wrapped function
main_input_name
=
getattr
(
self
,
"main_input_name"
,
func
.
__code__
.
co_varnames
[
1
])
main_input
=
fn_args_and_kwargs
.
pop
(
main_input_name
,
None
)
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
main_input
,
**
fn_args_and_kwargs
)
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
**
fn_args_and_kwargs
)
return
func
(
self
,
**
unpacked_inputs
)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...
...
@@ -417,7 +415,7 @@ def unpack_inputs(func):
return
run_call_with_unpacked_inputs
def
input_processing
(
func
,
config
,
input_ids
,
**
kwargs
):
def
input_processing
(
func
,
config
,
**
kwargs
):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
has to be named according to the parameter names, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
...
...
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
has_kwargs
=
bool
(
signature
.
pop
(
"kwargs"
,
None
))
signature
.
pop
(
"self"
,
None
)
parameter_names
=
list
(
signature
.
keys
())
main_input_name
=
parameter_names
[
0
]
main_input
=
kwargs
.
pop
(
main_input_name
,
None
)
output
=
{}
allowed_types
=
(
tf
.
Tensor
,
bool
,
int
,
ModelOutput
,
tuple
,
list
,
dict
,
np
.
ndarray
,
KerasTensor
)
...
...
@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
if
isinstance
(
input
_ids
,
(
tuple
,
list
)):
for
i
,
input
in
enumerate
(
input
_ids
):
if
isinstance
(
main_
input
,
(
tuple
,
list
)):
for
i
,
input
in
enumerate
(
main_
input
):
# EagerTensors don't allow access to the `.name` property, so we check for a real Tensor
if
type
(
input
)
==
tf
.
Tensor
:
# Tensor names always follow the pattern `name:id`, so we check only the
...
...
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
f
"Data of type
{
type
(
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter_names
[
i
]
}
."
)
elif
isinstance
(
input
_ids
,
Mapping
):
if
"inputs"
in
input
_ids
:
elif
isinstance
(
main_
input
,
Mapping
):
if
"inputs"
in
main_
input
:
warnings
.
warn
(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
" instead."
,
FutureWarning
,
)
output
[
"input_ids"
]
=
input
_ids
.
pop
(
"inputs"
)
output
[
"input_ids"
]
=
main_
input
.
pop
(
"inputs"
)
if
"decoder_cached_states"
in
input
_ids
:
if
"decoder_cached_states"
in
main_
input
:
warnings
.
warn
(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
" `past_key_values` instead."
,
FutureWarning
,
)
output
[
"past_key_values"
]
=
input
_ids
.
pop
(
"decoder_cached_states"
)
output
[
"past_key_values"
]
=
main_
input
.
pop
(
"decoder_cached_states"
)
for
k
,
v
in
dict
(
input
_ids
).
items
():
for
k
,
v
in
dict
(
main_
input
).
items
():
if
isinstance
(
v
,
allowed_types
)
or
v
is
None
:
output
[
k
]
=
v
elif
k
not
in
parameter_names
and
"args"
not
in
parameter_names
:
...
...
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
else
:
if
isinstance
(
input
_ids
,
(
tf
.
Tensor
,
KerasTensor
))
or
input
_ids
is
None
:
output
[
parameter_names
[
0
]]
=
input
_ids
if
isinstance
(
main_
input
,
(
tf
.
Tensor
,
KerasTensor
))
or
main_
input
is
None
:
output
[
main_input_name
]
=
main_
input
else
:
raise
ValueError
(
f
"Data of type
{
type
(
input
_ids
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter
_name
s
[
0
]
}
."
f
"Data of type
{
type
(
main_
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
main_input
_name
}
."
)
# Populates any unspecified argument with their default value, according to the signature.
...
...
tests/test_modeling_tf_common.py
View file @
20509ab0
...
...
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
def
__init__
(
self
):
config_kwargs
=
{
"output_attentions"
:
False
,
"output_hidden_states"
:
False
,
"return_dict"
:
False
}
self
.
config
=
PretrainedConfig
(
**
config_kwargs
)
self
.
main_input_name
=
"input_ids"
@
unpack_inputs
def
call
(
...
...
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
):
return
input_ids
,
past
,
output_attentions
,
output_hidden_states
,
return_dict
@
unpack_inputs
def
foo
(
self
,
pixel_values
,
output_attentions
=
None
,
output_hidden_states
=
None
,
return_dict
=
None
):
return
pixel_values
,
output_attentions
,
output_hidden_states
,
return_dict
dummy_model
=
DummyModel
()
input_ids
=
tf
.
constant
([
0
,
1
,
2
,
3
])
past
=
tf
.
constant
([
4
,
5
,
6
,
7
])
pixel_values
=
tf
.
constant
([
8
,
9
,
10
,
11
])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output
=
dummy_model
.
call
(
input_ids
=
input_ids
,
past
=
past
)
...
...
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
self
.
assertFalse
(
output
[
3
])
self
.
assertFalse
(
output
[
4
])
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
# decorated function as its main input.
output
=
dummy_model
.
foo
(
pixel_values
=
pixel_values
)
tf
.
debugging
.
assert_equal
(
output
[
0
],
pixel_values
)
self
.
assertFalse
(
output
[
1
])
self
.
assertFalse
(
output
[
2
])
self
.
assertFalse
(
output
[
3
])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def
test_xla_stable_softmax
(
self
):
large_penalty
=
-
1e9
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment