Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
20509ab0
Unverified
Commit
20509ab0
authored
Jul 13, 2022
by
Joao Gante
Committed by
GitHub
Jul 13, 2022
Browse files
TF: unpack_inputs decorator independent from main_input_name (#18110)
parent
fcefa200
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
16 deletions
+30
-16
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+16
-16
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+14
-0
No files found.
src/transformers/modeling_tf_utils.py
View file @
20509ab0
...
...
@@ -404,9 +404,7 @@ def unpack_inputs(func):
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
# process the inputs and call the wrapped function
main_input_name
=
getattr
(
self
,
"main_input_name"
,
func
.
__code__
.
co_varnames
[
1
])
main_input
=
fn_args_and_kwargs
.
pop
(
main_input_name
,
None
)
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
main_input
,
**
fn_args_and_kwargs
)
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
**
fn_args_and_kwargs
)
return
func
(
self
,
**
unpacked_inputs
)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...
...
@@ -417,7 +415,7 @@ def unpack_inputs(func):
return
run_call_with_unpacked_inputs
def
input_processing
(
func
,
config
,
input_ids
,
**
kwargs
):
def
input_processing
(
func
,
config
,
**
kwargs
):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
...
...
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
has_kwargs
=
bool
(
signature
.
pop
(
"kwargs"
,
None
))
signature
.
pop
(
"self"
,
None
)
parameter_names
=
list
(
signature
.
keys
())
main_input_name
=
parameter_names
[
0
]
main_input
=
kwargs
.
pop
(
main_input_name
,
None
)
output
=
{}
allowed_types
=
(
tf
.
Tensor
,
bool
,
int
,
ModelOutput
,
tuple
,
list
,
dict
,
np
.
ndarray
,
KerasTensor
)
...
...
@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
if
isinstance
(
input
_ids
,
(
tuple
,
list
)):
for
i
,
input
in
enumerate
(
input
_ids
):
if
isinstance
(
main_
input
,
(
tuple
,
list
)):
for
i
,
input
in
enumerate
(
main_
input
):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if
type
(
input
)
==
tf
.
Tensor
:
# Tensor names have always the pattern `name:id` then we check only the
...
...
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
f
"Data of type
{
type
(
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter_names
[
i
]
}
."
)
elif
isinstance
(
input
_ids
,
Mapping
):
if
"inputs"
in
input
_ids
:
elif
isinstance
(
main_
input
,
Mapping
):
if
"inputs"
in
main_
input
:
warnings
.
warn
(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
" instead."
,
FutureWarning
,
)
output
[
"input_ids"
]
=
input
_ids
.
pop
(
"inputs"
)
output
[
"input_ids"
]
=
main_
input
.
pop
(
"inputs"
)
if
"decoder_cached_states"
in
input
_ids
:
if
"decoder_cached_states"
in
main_
input
:
warnings
.
warn
(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
" `past_key_values` instead."
,
FutureWarning
,
)
output
[
"past_key_values"
]
=
input
_ids
.
pop
(
"decoder_cached_states"
)
output
[
"past_key_values"
]
=
main_
input
.
pop
(
"decoder_cached_states"
)
for
k
,
v
in
dict
(
input
_ids
).
items
():
for
k
,
v
in
dict
(
main_
input
).
items
():
if
isinstance
(
v
,
allowed_types
)
or
v
is
None
:
output
[
k
]
=
v
elif
k
not
in
parameter_names
and
"args"
not
in
parameter_names
:
...
...
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
else
:
if
isinstance
(
input
_ids
,
(
tf
.
Tensor
,
KerasTensor
))
or
input
_ids
is
None
:
output
[
parameter_names
[
0
]]
=
input
_ids
if
isinstance
(
main_
input
,
(
tf
.
Tensor
,
KerasTensor
))
or
main_
input
is
None
:
output
[
main_input_name
]
=
main_
input
else
:
raise
ValueError
(
f
"Data of type
{
type
(
input
_ids
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter
_name
s
[
0
]
}
."
f
"Data of type
{
type
(
main_
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
main_input
_name
}
."
)
# Populates any unspecified argument with their default value, according to the signature.
...
...
tests/test_modeling_tf_common.py
View file @
20509ab0
...
...
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
def
__init__
(
self
):
config_kwargs
=
{
"output_attentions"
:
False
,
"output_hidden_states"
:
False
,
"return_dict"
:
False
}
self
.
config
=
PretrainedConfig
(
**
config_kwargs
)
self
.
main_input_name
=
"input_ids"
@
unpack_inputs
def
call
(
...
...
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
):
return
input_ids
,
past
,
output_attentions
,
output_hidden_states
,
return_dict
@
unpack_inputs
def
foo
(
self
,
pixel_values
,
output_attentions
=
None
,
output_hidden_states
=
None
,
return_dict
=
None
):
return
pixel_values
,
output_attentions
,
output_hidden_states
,
return_dict
dummy_model
=
DummyModel
()
input_ids
=
tf
.
constant
([
0
,
1
,
2
,
3
])
past
=
tf
.
constant
([
4
,
5
,
6
,
7
])
pixel_values
=
tf
.
constant
([
8
,
9
,
10
,
11
])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output
=
dummy_model
.
call
(
input_ids
=
input_ids
,
past
=
past
)
...
...
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
self
.
assertFalse
(
output
[
3
])
self
.
assertFalse
(
output
[
4
])
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
# decorated function as its main input.
output
=
dummy_model
.
foo
(
pixel_values
=
pixel_values
)
tf
.
debugging
.
assert_equal
(
output
[
0
],
pixel_values
)
self
.
assertFalse
(
output
[
1
])
self
.
assertFalse
(
output
[
2
])
self
.
assertFalse
(
output
[
3
])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def
test_xla_stable_softmax
(
self
):
large_penalty
=
-
1e9
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment