chenpangpang / transformers · Commits

Commit aa4a0f8e (unverified)
Authored Jan 02, 2024 by Daniel Bustamante Ospina, committed by GitHub on Jan 02, 2024
Remove fast tokenization warning in Data Collators (#28213)
Parent: 5be46dfc
Showing 1 changed file with 41 additions and 8 deletions:

src/transformers/data/data_collator.py (+41, -8)
@@ -49,6 +49,28 @@ class DataCollatorMixin:
             raise ValueError(f"Framework '{return_tensors}' not recognized!")
 
 
+def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
+    """
+    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
+    """
+
+    # To avoid errors when using Feature extractors
+    if not hasattr(tokenizer, "deprecation_warnings"):
+        return tokenizer.pad(*pad_args, **pad_kwargs)
+
+    # Save the state of the warning, then disable it
+    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
+    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+    try:
+        padded = tokenizer.pad(*pad_args, **pad_kwargs)
+    finally:
+        # Restore the state of the warning.
+        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
+
+    return padded
+
+
 def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
     """
     Very simple data collator that simply collates batches of dict-like objects and performs special handling for
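The new helper temporarily flips the tokenizer's internal deprecation_warnings flag around the pad call and restores whatever state it found, so the one-time fast-tokenizer padding hint is never queued; the hasattr guard keeps it safe for feature extractors, which expose pad but no deprecation_warnings registry. A minimal sketch of the resulting behavior — the checkpoint name and the PyTorch return type are illustrative assumptions, not part of this diff:

    from transformers import AutoTokenizer
    from transformers.data.data_collator import pad_without_fast_tokenizer_warning

    # Any fast tokenizer works; "bert-base-uncased" is just an example checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    features = [tokenizer("a short text"), tokenizer("a noticeably longer piece of text")]

    # tokenizer.pad(features, ...) would log the "Asking-to-pad-a-fast-tokenizer"
    # hint once; the wrapper pre-sets that flag and restores it in a finally block.
    batch = pad_without_fast_tokenizer_warning(tokenizer, features, padding=True, return_tensors="pt")
    print(batch["input_ids"].shape)  # e.g. torch.Size([2, 9])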
@@ -246,7 +268,8 @@ class DataCollatorWithPadding:
     return_tensors: str = "pt"
 
     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
-        batch = self.tokenizer.pad(
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
             features,
             padding=self.padding,
             max_length=self.max_length,
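With this call site rewired, a DataCollatorWithPadding no longer emits the hint on every training step. A hedged usage sketch, where the checkpoint and batch contents are assumptions:

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
    collator = DataCollatorWithPadding(tokenizer, padding="longest")

    features = [tokenizer("first example"), tokenizer("a second, longer example")]
    loader = DataLoader(features, batch_size=2, collate_fn=collator)
    for batch in loader:
        # Padded to the longest sequence in the batch, with no fast-tokenizer warning.
        print(batch["input_ids"].shape)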
@@ -307,7 +330,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
 
         no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
 
-        batch = self.tokenizer.pad(
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
             no_labels_features,
             padding=self.padding,
             max_length=self.max_length,
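A side note on the no_labels_features context line above: labels are stripped before padding because the per-token label lists are ragged and must be filled with label_pad_token_id (-100 by default) rather than the tokenizer's pad token, then re-attached after padding. An illustrative sketch with invented token ids:

    # Illustrative only: the ids and labels below are invented.
    label_name = "labels"
    features = [
        {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
        {"input_ids": [101, 7592, 2088, 999, 102], "labels": [0, 1, 2, 1, 0]},
    ]
    no_labels_features = [{k: v for k, v in f.items() if k != label_name} for f in features]
    print(no_labels_features)  # only tokenizer keys remain; labels are padded separately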
@@ -343,7 +367,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
 
         label_name = "label" if "label" in features[0].keys() else "labels"
         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
-        batch = self.tokenizer.pad(
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
             features,
             padding=self.padding,
             max_length=self.max_length,
@@ -372,7 +397,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
     def numpy_call(self, features):
         label_name = "label" if "label" in features[0].keys() else "labels"
         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
-        batch = self.tokenizer.pad(
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
             features,
             padding=self.padding,
             max_length=self.max_length,
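All three framework paths of DataCollatorForTokenClassification (torch, tf, numpy) receive the same one-line substitution. A minimal end-to-end sketch, with an assumed checkpoint and invented ids:

    from transformers import AutoTokenizer, DataCollatorForTokenClassification

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
    collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-100)

    features = [
        {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},  # invented ids/labels
        {"input_ids": [101, 7592, 2088, 999, 102], "labels": [0, 1, 2, 1, 0]},
    ]
    batch = collator(features)  # defaults to the torch path; no padding warning
    print(batch["labels"])  # shorter row right-padded with -100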
@@ -583,7 +609,8 @@ class DataCollatorForSeq2Seq:
             else:
                 feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
 
-        features = self.tokenizer.pad(
+        features = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
             features,
             padding=self.padding,
             max_length=self.max_length,
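DataCollatorForSeq2Seq follows the same pattern: labels are padded by hand first (the np.concatenate branch above handles left-padding tokenizers), then the inputs go through the wrapper. A hedged sketch; the t5-small checkpoint and the toy ids are assumptions:

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
    collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, return_tensors="np")

    features = [
        {"input_ids": [100, 200, 1], "labels": [7, 8, 1]},  # invented ids
        {"input_ids": [100, 1], "labels": [7, 8, 9, 10, 1]},
    ]
    batch = collator(features)
    print(batch["labels"])  # labels padded with -100; inputs padded via the wrapper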
@@ -692,7 +719,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
 
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], Mapping):
-            batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
+            batch = pad_without_fast_tokenizer_warning(
+                self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
+            )
         else:
             batch = {
                 "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
@@ -729,7 +758,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], Mapping):
-            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
+            batch = pad_without_fast_tokenizer_warning(
+                self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
+            )
         else:
             batch = {
                 "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
@@ -784,7 +815,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
-            batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
+            batch = pad_without_fast_tokenizer_warning(
+                self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
+            )
         else:
             batch = {
                 "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
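In the three DataCollatorForLanguageModeling paths the call is also reflowed over three lines, which is why these hunks add two lines each rather than one. A usage sketch; the checkpoint and masking probability are assumptions:

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)

    # BatchEncoding is a Mapping, so isinstance(examples[0], Mapping) is True and
    # torch_call takes the pad_without_fast_tokenizer_warning branch.
    examples = [tokenizer("masked language modeling"), tokenizer("collates dict-like examples")]
    batch = collator(examples)
    print(batch["input_ids"].shape, batch["labels"].shape)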