chenpangpang / transformers · Commits

Commit 18df4407 (unverified), authored Apr 29, 2022 by Sylvain Gugger, committed by GitHub on Apr 29, 2022

Replace dict/BatchEncoding instance checks by Mapping (#17014)

* Replace dict/BatchEncoding instance checks by Mapping
* Typo

parent b8dffd1f
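The change works because `BatchEncoding` subclasses `collections.UserDict`, which inherits from `collections.abc.MutableMapping`, so a single `isinstance(x, Mapping)` check covers plain dicts, `BatchEncoding`, and any other mapping type without importing `BatchEncoding`. A minimal sketch, using a `UserDict` stand-in rather than the real class:

from collections import UserDict
from collections.abc import Mapping

class BatchEncodingStandIn(UserDict):
    """Stand-in for transformers.BatchEncoding, which also subclasses UserDict."""

enc = BatchEncodingStandIn({"input_ids": [[101, 2023, 102]]})

# Old check: every accepted type must be enumerated explicitly.
print(isinstance(enc, (dict, BatchEncodingStandIn)))  # True
# New check: one abstract base class covers dict, UserDict subclasses, etc.
print(isinstance(enc, Mapping))                       # True
print(isinstance({"input_ids": []}, Mapping))         # True for plain dicts too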
Showing 9 changed files with 30 additions and 26 deletions (+30 −26)
src/transformers/data/data_collator.py                     +14 −13
src/transformers/modeling_tf_utils.py                       +2 −2
src/transformers/models/hubert/modeling_tf_hubert.py        +2 −2
src/transformers/models/luke/tokenization_luke.py           +2 −1
src/transformers/models/mluke/tokenization_mluke.py         +2 −1
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py    +2 −2
src/transformers/testing_utils.py                           +2 −3
src/transformers/tokenization_utils_base.py                 +2 −1
src/transformers/trainer_pt_utils.py                        +2 −1
src/transformers/data/data_collator.py

@@ -14,11 +14,12 @@
 import random
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
 
 from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy

@@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin):
 def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import torch
 
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}

@@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
     import tensorflow as tf
 
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}

@@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
 def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
 
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}

@@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import tensorflow as tf
 
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {

@@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {

@@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import numpy as np
 
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {

@@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""
 
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples

@@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}
 
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples

@@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}
 
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples

@@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
     return_tensors: str = "pt"
 
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _torch_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
 
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _tf_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
 
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _numpy_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
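One practical consequence of the change above: the default collators now accept any mapping-like feature, not only `dict` and `BatchEncoding`. A hedged sketch — the `LazyFeatures` class is hypothetical, invented for illustration, and this assumes `torch` plus a transformers version containing this commit:

from collections.abc import Mapping

from transformers.data.data_collator import torch_default_data_collator

class LazyFeatures(Mapping):
    """Hypothetical custom container; the old (dict, BatchEncoding) check rejected it."""
    def __init__(self, data):
        self._data = data
    def __getitem__(self, key):
        return self._data[key]
    def __iter__(self):
        return iter(self._data)
    def __len__(self):
        return len(self._data)

features = [LazyFeatures({"input_ids": [1, 2, 3], "label": i}) for i in range(2)]
batch = torch_default_data_collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"])           # tensor([0, 1])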
src/transformers/modeling_tf_utils.py

@@ -21,6 +21,7 @@ import os
 import pickle
 import re
 import warnings
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Union
 
 import h5py

@@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation_tf_utils import TFGenerationMixin
 from .tf_utils import shape_list
-from .tokenization_utils_base import BatchEncoding
 from .utils import (
     DUMMY_INPUTS,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,

@@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_ids, (dict, BatchEncoding)):
+    elif isinstance(input_ids, Mapping):
         if "inputs" in input_ids:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
src/transformers/models/hubert/modeling_tf_hubert.py

@@ -15,6 +15,7 @@
 """ TensorFlow Hubert model."""
 import inspect
 import warnings
+from collections.abc import Mapping
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np

@@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,

@@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
src/transformers/models/luke/tokenization_luke.py

@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np

@@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has be passed for padding
src/transformers/models/mluke/tokenization_mluke.py

@@ -18,6 +18,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has be passed for padding
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

@@ -16,6 +16,7 @@
 import inspect
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

@@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,

@@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs):
             raise ValueError(
                 f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
             )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
src/transformers/testing_utils.py

@@ -22,6 +22,7 @@ import shutil
 import sys
 import tempfile
 import unittest
+from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path

@@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3):
     """
     import numpy as np
 
-    from transformers.tokenization_utils import BatchEncoding
-
     if isinstance(obj, list):
         return [nested_simplify(item, decimals) for item in obj]
     elif isinstance(obj, np.ndarray):
         return nested_simplify(obj.tolist())
-    elif isinstance(obj, (dict, BatchEncoding)):
+    elif isinstance(obj, Mapping):
         return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
     elif isinstance(obj, (str, int, np.int64)):
         return obj
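Because `nested_simplify` now dispatches on `Mapping`, it also handles mapping types that are not `dict` subclasses. A small hedged example (assumes a transformers version that includes this commit); `types.MappingProxyType` is registered as a `Mapping` but is not a `dict`, so the old `(dict, BatchEncoding)` check would not have matched it:

from types import MappingProxyType

from transformers.testing_utils import nested_simplify

frozen = MappingProxyType({"scores": [0.12345, 0.67891]})
# Simplified like any other mapping: keys recursed, floats rounded to `decimals`.
print(nested_simplify(frozen, decimals=3))  # {'scores': [0.123, 0.679]}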
src/transformers/tokenization_utils_base.py

@@ -24,6 +24,7 @@ import os
 import re
 import warnings
 from collections import OrderedDict, UserDict
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has be passed for padding
src/transformers/trainer_pt_utils.py

@@ -22,6 +22,7 @@ import math
 import os
 import sys
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler

@@ -111,7 +112,7 @@ def find_batch_size(tensors):
         result = find_batch_size(t)
         if result is not None:
             return result
-    elif isinstance(tensors, (dict, BatchEncoding)):
+    elif isinstance(tensors, Mapping):
         for key, value in tensors.items():
             result = find_batch_size(value)
             if result is not None:
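The same idea applies in `find_batch_size`: any `Mapping` is now traversed recursively, so `trainer_pt_utils` no longer needs to import `BatchEncoding` (which subclasses `UserDict`) to recognize tokenizer output. A minimal sketch with a `UserDict` stand-in, assuming `torch` and a transformers version with this commit:

from collections import UserDict

import torch

from transformers.trainer_pt_utils import find_batch_size

class FakeEncoding(UserDict):
    """Stand-in for BatchEncoding; a UserDict is a Mapping but not a dict."""

enc = FakeEncoding({"input_ids": torch.zeros(4, 16, dtype=torch.long)})
print(find_batch_size(enc))  # 4 — first dimension of the first tensor found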