Unverified commit 18df4407, authored Apr 29, 2022 by Sylvain Gugger, committed by GitHub Apr 29, 2022
Parent: b8dffd1f

Replace dict/BatchEncoding instance checks by Mapping (#17014)

* Replace dict/BatchEncoding instance checks by Mapping
* Typo
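For context: `BatchEncoding` derives from `collections.UserDict`, which implements `collections.abc.MutableMapping`, so the single `isinstance(x, Mapping)` test still accepts both a plain `dict` and a `BatchEncoding`, and additionally accepts any other mapping type that the old two-class check rejected. A minimal sketch of the difference (`FeatureDict` is a hypothetical stand-in, not part of the diff):

from collections import UserDict
from collections.abc import Mapping

class FeatureDict(Mapping):
    # Hypothetical third-party mapping type, e.g. a lazily decoded dataset row.
    def __init__(self, data):
        self._data = dict(data)
    def __getitem__(self, key):
        return self._data[key]
    def __iter__(self):
        return iter(self._data)
    def __len__(self):
        return len(self._data)

features = FeatureDict({"input_ids": [101, 2023, 102]})
print(isinstance(features, dict))       # False: the old check missed this type
print(isinstance(features, Mapping))    # True: the new check accepts it
print(isinstance({}, Mapping))          # True: plain dicts still pass
print(isinstance(UserDict(), Mapping))  # True: so does BatchEncoding's base class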
Showing 9 changed files with 30 additions and 26 deletions (+30 −26).
src/transformers/data/data_collator.py                    +14 −13
src/transformers/modeling_tf_utils.py                      +2  −2
src/transformers/models/hubert/modeling_tf_hubert.py       +2  −2
src/transformers/models/luke/tokenization_luke.py          +2  −1
src/transformers/models/mluke/tokenization_mluke.py        +2  −1
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py   +2  −2
src/transformers/testing_utils.py                          +2  −3
src/transformers/tokenization_utils_base.py                +2  −1
src/transformers/trainer_pt_utils.py                       +2  −1
src/transformers/data/data_collator.py
@@ -14,11 +14,12 @@
 import random
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

 from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
@@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin):
 def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import torch

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
     import tensorflow as tf

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
 def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import tensorflow as tf

         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import numpy as np

         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}

     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
     return_tensors: str = "pt"

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _torch_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _tf_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _numpy_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
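A hedged usage sketch (not part of the diff, assuming transformers and torch are installed): the default collator's behavior for plain dicts is unchanged; the new check only broadens what counts as dict-like.

from transformers.data.data_collator import torch_default_data_collator

features = [
    {"input_ids": [101, 2023, 102], "label": 0},
    {"input_ids": [101, 2003, 102], "label": 1},
]
batch = torch_default_data_collator(features)
# Equal-length features are stacked into tensors and "label" becomes "labels".
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"])           # tensor([0, 1])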
src/transformers/modeling_tf_utils.py
@@ -21,6 +21,7 @@ import os
 import pickle
 import re
 import warnings
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Union

 import h5py
@@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation_tf_utils import TFGenerationMixin
 from .tf_utils import shape_list
-from .tokenization_utils_base import BatchEncoding
 from .utils import (
     DUMMY_INPUTS,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
@@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_ids, (dict, BatchEncoding)):
+    elif isinstance(input_ids, Mapping):
         if "inputs" in input_ids:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
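The practical effect for `input_processing` is that a TF model can now be called with any mapping of inputs, not only a `dict` or `BatchEncoding`. A hedged sketch (the checkpoint name is illustrative):

import tensorflow as tf
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-uncased")
# A plain dict worked before; any other Mapping subtype now takes the same branch.
outputs = model({"input_ids": tf.constant([[101, 2023, 102]])})
print(outputs.last_hidden_state.shape)  # (1, 3, 768)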
src/transformers/models/hubert/modeling_tf_hubert.py
@@ -15,6 +15,7 @@
 """ TensorFlow Hubert model."""
 import inspect
 import warnings
+from collections.abc import Mapping
 from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
@@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
src/transformers/models/luke/tokenization_luke.py
@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
src/transformers/models/mluke/tokenization_mluke.py
@@ -18,6 +18,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -16,6 +16,7 @@
 import inspect
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
src/transformers/testing_utils.py
@@ -22,6 +22,7 @@ import shutil
 import sys
 import tempfile
 import unittest
+from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
@@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3):
     """
     import numpy as np

-    from transformers.tokenization_utils import BatchEncoding
-
     if isinstance(obj, list):
         return [nested_simplify(item, decimals) for item in obj]
     elif isinstance(obj, np.ndarray):
         return nested_simplify(obj.tolist())
-    elif isinstance(obj, (dict, BatchEncoding)):
+    elif isinstance(obj, Mapping):
         return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
     elif isinstance(obj, (str, int, np.int64)):
         return obj
src/transformers/tokenization_utils_base.py
@@ -24,6 +24,7 @@ import os
 import re
 import warnings
 from collections import OrderedDict, UserDict
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
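The list-of-dicts to dict-of-lists conversion above is what lets `tokenizer.pad` serve directly as a `collate_fn`. A hedged usage sketch (the checkpoint name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Each element is a BatchEncoding (a Mapping), as a Dataset __getitem__ typically yields.
examples = [tokenizer("short"), tokenizer("a somewhat longer example sentence")]
batch = tokenizer.pad(examples, return_tensors="pt")
print(batch["input_ids"].shape)  # (2, length of the longest example)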
src/transformers/trainer_pt_utils.py
@@ -22,6 +22,7 @@ import math
 import os
 import sys
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler
@@ -111,7 +112,7 @@ def find_batch_size(tensors):
             result = find_batch_size(t)
             if result is not None:
                 return result
-    elif isinstance(tensors, (dict, BatchEncoding)):
+    elif isinstance(tensors, Mapping):
         for key, value in tensors.items():
             result = find_batch_size(value)
             if result is not None:
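With the broadened check, `find_batch_size` recurses into arbitrary mappings. A minimal hedged sketch (assuming transformers and torch are installed):

import torch
from transformers.trainer_pt_utils import find_batch_size

batch = {
    "input_ids": torch.zeros(8, 128, dtype=torch.long),
    "attention_mask": torch.ones(8, 128, dtype=torch.long),
}
# The first tensor dimension found while walking the nested structure is the batch size.
print(find_batch_size(batch))  # 8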