chenpangpang / transformers · Commits · 0b51532c
".circleci/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e344e45f2d8fa5d8b1024bc2e8b740b873652dd4"
Commit 0b51532c authored Dec 13, 2019 by Morgan Funtowicz
Reintroducing the batch_encode_plus method
parent 8938b546
Showing 1 changed file with 86 additions and 0 deletions
transformers/tokenization_utils.py  +86 -0
@@ -878,6 +878,92 @@ class PreTrainedTokenizer(object):
                                  return_overflowing_tokens=return_overflowing_tokens,
                                  return_special_tokens_mask=return_special_tokens_mask)

    def batch_encode_plus(self,
                          batch_text_or_text_pairs=None,
                          add_special_tokens=False,
                          max_length=None,
                          stride=0,
                          truncation_strategy='longest_first',
                          return_tensors=None,
                          return_input_lengths=False,
                          return_attention_masks=False,
                          **kwargs):
        """
        Returns a dictionary containing the encoded sequence or sequence pair and additional information:
        the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.

        Args:
            batch_text_or_text_pairs: Batch of sequences or pairs of sequences to be encoded.
                This can be a list of string/string-sequences/int-sequences or a list of pairs of
                string/string-sequences/int-sequences (see details in encode_plus).
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens
                relative to their model.
            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                If there are overflowing tokens, those will be added to the returned dictionary.
            stride: if set to a number along with max_length, the overflowing tokens returned will contain some
                tokens from the main sequence returned. The value of this argument defines the number of
                additional tokens.
            truncation_strategy: string selected in the following options:
                - 'longest_first' (default): iteratively reduce the input sequences until they fit under
                  max_length, removing a token from the longest sequence at each step (when there is a pair
                  of input sequences)
                - 'only_first': only truncate the first sequence
                - 'only_second': only truncate the second sequence
                - 'do_not_truncate': do not truncate (raise an error if the input sequence is longer than max_length)
            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively a TensorFlow tf.constant
                or a PyTorch torch.Tensor instead of a list of python integers.
            **kwargs: passed to the `self.tokenize()` method
        """
        batch_outputs = {}
        for ids_or_pair_ids in batch_text_or_text_pairs:
            if isinstance(ids_or_pair_ids, (list, tuple)):
                assert len(ids_or_pair_ids) == 2
                ids, pair_ids = ids_or_pair_ids
            else:
                ids, pair_ids = ids_or_pair_ids, None
            outputs = self.encode_plus(ids,
                                       pair_ids,
                                       add_special_tokens=add_special_tokens,
                                       max_length=max_length,
                                       stride=stride,
                                       truncation_strategy=truncation_strategy,
                                       return_tensors=None)

            # Append the non-padded length to the output
            if return_input_lengths:
                outputs['input_len'] = len(outputs['input_ids'])

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        # Compute longest sequence size
        max_seq_len = max(map(len, batch_outputs['input_ids']))

        if return_attention_masks:
            # Allow the model to not give any special attention to padded input
            batch_outputs['attention_mask'] = [[0] * len(v) for v in batch_outputs['input_ids']]

        if return_tensors is not None:
            # Do the tensor conversion in batch
            for key, value in batch_outputs.items():
                padded_value = value
                if key != 'input_len':
                    # Padding handle
                    padded_value = [v + [self.pad_token_id if key == 'input_ids' else 1] * (max_seq_len - len(v))
                                    for v in padded_value]

                if return_tensors == 'tf' and is_tf_available():
                    batch_outputs[key] = tf.constant(padded_value)
                elif return_tensors == 'pt' and is_torch_available():
                    batch_outputs[key] = torch.tensor(padded_value)
                elif return_tensors is not None:
                    logger.warning("Unable to convert output to tensors format {}, "
                                   "PyTorch or TensorFlow is not available.".format(return_tensors))

        # encoder_attention_mask requires 1 for real token, 0 for padding, just invert value
        if return_attention_masks:
            if is_tf_available():
                batch_outputs['attention_mask'] = tf.abs(batch_outputs['attention_mask'] - 1)
            else:
                batch_outputs['attention_mask'] = torch.abs(batch_outputs['attention_mask'] - 1)

        return batch_outputs

    def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
                          truncation_strategy='longest_first', pad_to_max_length=False,
    ...
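For reference, a minimal usage sketch of the reintroduced method (not part of this commit's diff). It assumes a concrete tokenizer subclass such as BertTokenizer, the 'bert-base-uncased' checkpoint, and an installed PyTorch backend so that return_tensors='pt' takes effect; adjust the class and checkpoint to your setup.

    # Hypothetical usage sketch -- illustrative only, not taken from the commit.
    # Assumes: transformers with BertTokenizer, the 'bert-base-uncased' checkpoint,
    # and PyTorch installed so the outputs below come back as torch tensors.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    batch = ["Hello world!",
             "Sequences in a batch are padded up to the longest one."]

    encoded = tokenizer.batch_encode_plus(batch,
                                          add_special_tokens=True,
                                          return_tensors='pt',
                                          return_attention_masks=True,
                                          return_input_lengths=True)

    print(encoded['input_ids'].shape)   # (2, max_seq_len), padded with tokenizer.pad_token_id
    print(encoded['attention_mask'])    # 1 for real tokens, 0 for padding
    print(encoded['input_len'])         # non-padded length of each sequence

Note the two-step mask construction in the method: the attention mask starts as 0 for real tokens, is padded with 1, and is then inverted with abs(x - 1), so the returned mask follows the usual convention of 1 on real tokens and 0 on padding.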