Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
20509ab0
Unverified
Commit
20509ab0
authored
Jul 13, 2022
by
Joao Gante
Committed by
GitHub
Jul 13, 2022
Browse files
TF: unpack_inputs decorator independent from main_input_name (#18110)
parent
fcefa200
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
16 deletions
+30
-16
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+16
-16
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+14
-0
No files found.
src/transformers/modeling_tf_utils.py
View file @
20509ab0
...
@@ -404,9 +404,7 @@ def unpack_inputs(func):
...
@@ -404,9 +404,7 @@ def unpack_inputs(func):
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
# process the inputs and call the wrapped function
# process the inputs and call the wrapped function
main_input_name
=
getattr
(
self
,
"main_input_name"
,
func
.
__code__
.
co_varnames
[
1
])
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
**
fn_args_and_kwargs
)
main_input
=
fn_args_and_kwargs
.
pop
(
main_input_name
,
None
)
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
main_input
,
**
fn_args_and_kwargs
)
return
func
(
self
,
**
unpacked_inputs
)
return
func
(
self
,
**
unpacked_inputs
)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...
@@ -417,7 +415,7 @@ def unpack_inputs(func):
...
@@ -417,7 +415,7 @@ def unpack_inputs(func):
return
run_call_with_unpacked_inputs
return
run_call_with_unpacked_inputs
def
input_processing
(
func
,
config
,
input_ids
,
**
kwargs
):
def
input_processing
(
func
,
config
,
**
kwargs
):
"""
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
...
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
...
@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
has_kwargs
=
bool
(
signature
.
pop
(
"kwargs"
,
None
))
has_kwargs
=
bool
(
signature
.
pop
(
"kwargs"
,
None
))
signature
.
pop
(
"self"
,
None
)
signature
.
pop
(
"self"
,
None
)
parameter_names
=
list
(
signature
.
keys
())
parameter_names
=
list
(
signature
.
keys
())
main_input_name
=
parameter_names
[
0
]
main_input
=
kwargs
.
pop
(
main_input_name
,
None
)
output
=
{}
output
=
{}
allowed_types
=
(
tf
.
Tensor
,
bool
,
int
,
ModelOutput
,
tuple
,
list
,
dict
,
np
.
ndarray
,
KerasTensor
)
allowed_types
=
(
tf
.
Tensor
,
bool
,
int
,
ModelOutput
,
tuple
,
list
,
dict
,
np
.
ndarray
,
KerasTensor
)
...
@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
...
@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
if
isinstance
(
input
_ids
,
(
tuple
,
list
)):
if
isinstance
(
main_
input
,
(
tuple
,
list
)):
for
i
,
input
in
enumerate
(
input
_ids
):
for
i
,
input
in
enumerate
(
main_
input
):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if
type
(
input
)
==
tf
.
Tensor
:
if
type
(
input
)
==
tf
.
Tensor
:
# Tensor names have always the pattern `name:id` then we check only the
# Tensor names have always the pattern `name:id` then we check only the
...
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
...
@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
f
"Data of type
{
type
(
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"Data of type
{
type
(
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter_names
[
i
]
}
."
f
"
{
parameter_names
[
i
]
}
."
)
)
elif
isinstance
(
input
_ids
,
Mapping
):
elif
isinstance
(
main_
input
,
Mapping
):
if
"inputs"
in
input
_ids
:
if
"inputs"
in
main_
input
:
warnings
.
warn
(
warnings
.
warn
(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
" instead."
,
" instead."
,
FutureWarning
,
FutureWarning
,
)
)
output
[
"input_ids"
]
=
input
_ids
.
pop
(
"inputs"
)
output
[
"input_ids"
]
=
main_
input
.
pop
(
"inputs"
)
if
"decoder_cached_states"
in
input
_ids
:
if
"decoder_cached_states"
in
main_
input
:
warnings
.
warn
(
warnings
.
warn
(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
" `past_key_values` instead."
,
" `past_key_values` instead."
,
FutureWarning
,
FutureWarning
,
)
)
output
[
"past_key_values"
]
=
input
_ids
.
pop
(
"decoder_cached_states"
)
output
[
"past_key_values"
]
=
main_
input
.
pop
(
"decoder_cached_states"
)
for
k
,
v
in
dict
(
input
_ids
).
items
():
for
k
,
v
in
dict
(
main_
input
).
items
():
if
isinstance
(
v
,
allowed_types
)
or
v
is
None
:
if
isinstance
(
v
,
allowed_types
)
or
v
is
None
:
output
[
k
]
=
v
output
[
k
]
=
v
elif
k
not
in
parameter_names
and
"args"
not
in
parameter_names
:
elif
k
not
in
parameter_names
and
"args"
not
in
parameter_names
:
...
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
...
@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
else
:
else
:
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
raise
ValueError
(
f
"Data of type
{
type
(
v
)
}
is not allowed only
{
allowed_types
}
is accepted for
{
k
}
."
)
else
:
else
:
if
isinstance
(
input
_ids
,
(
tf
.
Tensor
,
KerasTensor
))
or
input
_ids
is
None
:
if
isinstance
(
main_
input
,
(
tf
.
Tensor
,
KerasTensor
))
or
main_
input
is
None
:
output
[
parameter_names
[
0
]]
=
input
_ids
output
[
main_input_name
]
=
main_
input
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Data of type
{
type
(
input
_ids
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"Data of type
{
type
(
main_
input
)
}
is not allowed only
{
allowed_types
}
is accepted for"
f
"
{
parameter
_name
s
[
0
]
}
."
f
"
{
main_input
_name
}
."
)
)
# Populates any unspecified argument with their default value, according to the signature.
# Populates any unspecified argument with their default value, according to the signature.
...
...
tests/test_modeling_tf_common.py
View file @
20509ab0
...
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
def
__init__
(
self
):
def
__init__
(
self
):
config_kwargs
=
{
"output_attentions"
:
False
,
"output_hidden_states"
:
False
,
"return_dict"
:
False
}
config_kwargs
=
{
"output_attentions"
:
False
,
"output_hidden_states"
:
False
,
"return_dict"
:
False
}
self
.
config
=
PretrainedConfig
(
**
config_kwargs
)
self
.
config
=
PretrainedConfig
(
**
config_kwargs
)
self
.
main_input_name
=
"input_ids"
@
unpack_inputs
@
unpack_inputs
def
call
(
def
call
(
...
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
):
):
return
input_ids
,
past
,
output_attentions
,
output_hidden_states
,
return_dict
return
input_ids
,
past
,
output_attentions
,
output_hidden_states
,
return_dict
@
unpack_inputs
def
foo
(
self
,
pixel_values
,
output_attentions
=
None
,
output_hidden_states
=
None
,
return_dict
=
None
):
return
pixel_values
,
output_attentions
,
output_hidden_states
,
return_dict
dummy_model
=
DummyModel
()
dummy_model
=
DummyModel
()
input_ids
=
tf
.
constant
([
0
,
1
,
2
,
3
])
input_ids
=
tf
.
constant
([
0
,
1
,
2
,
3
])
past
=
tf
.
constant
([
4
,
5
,
6
,
7
])
past
=
tf
.
constant
([
4
,
5
,
6
,
7
])
pixel_values
=
tf
.
constant
([
8
,
9
,
10
,
11
])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output
=
dummy_model
.
call
(
input_ids
=
input_ids
,
past
=
past
)
output
=
dummy_model
.
call
(
input_ids
=
input_ids
,
past
=
past
)
...
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
self
.
assertFalse
(
output
[
3
])
self
.
assertFalse
(
output
[
3
])
self
.
assertFalse
(
output
[
4
])
self
.
assertFalse
(
output
[
4
])
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
# decorated function as its main input.
output
=
dummy_model
.
foo
(
pixel_values
=
pixel_values
)
tf
.
debugging
.
assert_equal
(
output
[
0
],
pixel_values
)
self
.
assertFalse
(
output
[
1
])
self
.
assertFalse
(
output
[
2
])
self
.
assertFalse
(
output
[
3
])
# Tests whether the stable softmax is stable on CPU, with and without XLA
# Tests whether the stable softmax is stable on CPU, with and without XLA
def
test_xla_stable_softmax
(
self
):
def
test_xla_stable_softmax
(
self
):
large_penalty
=
-
1e9
large_penalty
=
-
1e9
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment