Commit a7ca6d73
authored Dec 04, 2019 by LysandreJik

Padding side is tokenizer-dependant

parent cca75e78
Showing 4 changed files with 58 additions and 35 deletions
transformers/data/processors/squad.py             +5   -6
transformers/tests/tokenization_tests_commons.py  +15  -6
transformers/tokenization_utils.py                +37  -23
transformers/tokenization_xlnet.py                +1   -0
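Before the per-file diffs, a minimal sketch of the behaviour this commit introduces: the padding direction becomes a tokenizer-level attribute (padding_side, "right" by default and "left" for XLNet) combined with a boolean pad_to_max_length flag, replacing the per-call padding_strategy argument. The checkpoint names below are illustrative assumptions, not part of the commit.

from transformers import BertTokenizer, XLNetTokenizer

# padding_side now lives on the tokenizer rather than being passed per call.
bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")    # padding_side == "right"
xlnet_tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")   # padding_side == "left"

text = "Hello world"

# Pads with bert_tok.pad_token_id appended on the right, up to length 12.
right_padded = bert_tok.encode(text, max_length=12, pad_to_max_length=True)

# Pads with xlnet_tok.pad_token_id prepended on the left, up to length 12.
left_padded = xlnet_tok.encode(text, max_length=12, pad_to_max_length=True)

assert len(right_padded) == len(left_padded) == 12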
transformers/data/processors/squad.py

@@ -73,8 +73,7 @@ def _is_whitespace(c):
     return False

 def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
-                                        doc_stride, max_query_length, is_training,
-                                        sequence_a_is_doc=False):
+                                        doc_stride, max_query_length, is_training):
     """Loads a data file into a list of `InputBatch`s."""

     # Defining helper methods
@@ -127,13 +126,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     while len(spans) * doc_stride < len(all_doc_tokens):

         encoded_dict = tokenizer.encode_plus(
-            truncated_query if not sequence_a_is_doc else span_doc_tokens,
-            span_doc_tokens if not sequence_a_is_doc else truncated_query,
+            truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+            span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
             max_length=max_seq_length,
             return_overflowing_tokens=True,
-            padding_strategy='right',
+            pad_to_max_length=True,
             stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-            truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
+            truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
         )

         paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
                             max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
transformers/tests/tokenization_tests_commons.py
@@ -344,17 +344,19 @@ class CommonTestCases:
             padding_idx = tokenizer.pad_token_id

             # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+            tokenizer.padding_side = "right"
             encoded_sequence = tokenizer.encode(sequence)
             sequence_length = len(encoded_sequence)
-            padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='right')
+            padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
             padded_sequence_length = len(padded_sequence)
             assert sequence_length + padding_size == padded_sequence_length
             assert encoded_sequence + [padding_idx] * padding_size == padded_sequence

             # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+            tokenizer.padding_side = "left"
             encoded_sequence = tokenizer.encode(sequence)
             sequence_length = len(encoded_sequence)
-            padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='left')
+            padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
             padded_sequence_length = len(padded_sequence)
             assert sequence_length + padding_size == padded_sequence_length
             assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
@@ -362,10 +364,15 @@ class CommonTestCases:
             # RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
             encoded_sequence = tokenizer.encode(sequence)
             sequence_length = len(encoded_sequence)
-            padded_sequence_right = tokenizer.encode(sequence, padding_strategy='right')
+            tokenizer.padding_side = "right"
+            padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
             padded_sequence_right_length = len(padded_sequence_right)
-            padded_sequence_left = tokenizer.encode(sequence, padding_strategy='left')
+            tokenizer.padding_side = "left"
+            padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
             padded_sequence_left_length = len(padded_sequence_left)
             assert sequence_length == padded_sequence_right_length
             assert encoded_sequence == padded_sequence_right
             assert sequence_length == padded_sequence_left_length
@@ -387,7 +394,8 @@ class CommonTestCases:
             sequence_length = len(input_ids)

             # Test right padding
-            padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='right', return_special_tokens_mask=True)
+            tokenizer.padding_side = "right"
+            padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
             padded_input_ids = padded_sequence['input_ids']
             padded_token_type_ids = padded_sequence['token_type_ids']
             padded_attention_mask = padded_sequence['attention_mask']
@@ -401,7 +409,8 @@ class CommonTestCases:
assert
special_tokens_mask
+
[
1
]
*
padding_size
==
padded_special_tokens_mask
# Test left padding
padded_sequence
=
tokenizer
.
encode_plus
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
padding_strategy
=
'left'
,
return_special_tokens_mask
=
True
)
tokenizer
.
padding_side
=
"left"
padded_sequence
=
tokenizer
.
encode_plus
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
pad_to_max_length
=
True
,
return_special_tokens_mask
=
True
)
padded_input_ids
=
padded_sequence
[
'input_ids'
]
padded_token_type_ids
=
padded_sequence
[
'token_type_ids'
]
padded_attention_mask
=
padded_sequence
[
'attention_mask'
]
...
...
transformers/tokenization_utils.py
@@ -77,6 +77,8 @@ class PreTrainedTokenizer(object):
                                  "pad_token", "cls_token", "mask_token",
                                  "additional_special_tokens"]

+    padding_side = "right"
+
     @property
     def bos_token(self):
         """ Beginning of sentence token (string). Log an error if used while not having been set. """
@@ -223,6 +225,9 @@ class PreTrainedTokenizer(object):
         self.max_len = max_len if max_len is not None else int(1e12)

+        # Padding side is right by default and over-riden in subclsses. If specified in the kwargs, it is changed.
+        self.padding_side = kwargs.pop('padding_side', self.padding_side)
+
         # Added tokens
         self.added_tokens_encoder = {}
         self.added_tokens_decoder = {}
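Since the new kwargs.pop('padding_side', ...) line above lets the class default be overridden at construction time, here is a hedged usage sketch. The checkpoint name is an illustrative assumption, and it relies on from_pretrained forwarding unused keyword arguments to the tokenizer constructor, which in turn forwards them to this __init__.

from transformers import XLNetTokenizer

# XLNetTokenizer sets padding_side = "left" at the class level (see the
# tokenization_xlnet.py hunk below); the constructor kwarg overrides it per instance.
tok = XLNetTokenizer.from_pretrained("xlnet-base-cased", padding_side="right")
assert tok.padding_side == "right"

# The attribute can also simply be reassigned on an instance, as the tests above do:
tok.padding_side = "left"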
@@ -702,7 +707,7 @@ class PreTrainedTokenizer(object):
max_length
=
None
,
stride
=
0
,
truncation_strategy
=
'longest_first'
,
pad
ding_strategy
=
Non
e
,
pad
_to_max_length
=
Fals
e
,
return_tensors
=
None
,
**
kwargs
):
"""
...
...
@@ -729,12 +734,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
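A small self-contained sketch of the rule the updated pad_to_max_length docstring describes (illustrative pseudologic, not the library implementation; model_max_len stands in for the tokenizer's max_len attribute, and the token ids are made up):

def target_padding_length(ids, max_length=None, pad_to_max_length=False, model_max_len=512):
    """Length the returned sequence is padded to under the documented rule."""
    if not pad_to_max_length:
        return len(ids)                      # defaults to False: no padding
    if max_length is not None:
        return max(len(ids), max_length)     # pad up to the requested max length
    return max(len(ids), model_max_len)      # otherwise pad up to the model's max length

ids = [101, 7592, 2088, 102]                 # illustrative token ids
pad_id = 0
target = target_padding_length(ids, max_length=8, pad_to_max_length=True)
right_padded = ids + [pad_id] * (target - len(ids))    # padding_side == "right"
left_padded = [pad_id] * (target - len(ids)) + ids     # padding_side == "left"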
@@ -745,7 +750,7 @@ class PreTrainedTokenizer(object):
add_special_tokens
=
add_special_tokens
,
stride
=
stride
,
truncation_strategy
=
truncation_strategy
,
pad
ding_strategy
=
padding_strategy
,
pad
_to_max_length
=
pad_to_max_length
,
return_tensors
=
return_tensors
,
**
kwargs
)
...
...
@@ -758,7 +763,7 @@ class PreTrainedTokenizer(object):
max_length
=
None
,
stride
=
0
,
truncation_strategy
=
'longest_first'
,
pad
ding_strategy
=
Non
e
,
pad
_to_max_length
=
Fals
e
,
return_tensors
=
None
,
return_token_type_ids
=
True
,
return_attention_mask
=
True
,
...
...
@@ -788,12 +793,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
@@ -841,7 +846,7 @@ class PreTrainedTokenizer(object):
return
self
.
prepare_for_model
(
first_ids
,
pair_ids
=
second_ids
,
max_length
=
max_length
,
pad
ding_strategy
=
padding_strategy
,
pad
_to_max_length
=
pad_to_max_length
,
add_special_tokens
=
add_special_tokens
,
stride
=
stride
,
truncation_strategy
=
truncation_strategy
,
...
...
@@ -853,7 +858,7 @@ class PreTrainedTokenizer(object):
     def prepare_for_model(self, ids, pair_ids=None, max_length=None,
                           add_special_tokens=True, stride=0,
                           truncation_strategy='longest_first',
-                          padding_strategy=None,
+                          pad_to_max_length=False,
                           return_tensors=None,
                           return_token_type_ids=True,
                           return_attention_mask=True,
@@ -881,12 +886,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
@@ -955,10 +960,19 @@ class PreTrainedTokenizer(object):
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors"
.
format
(
len
(
ids
),
self
.
max_len
))
if
padding_strategy
is
not
None
and
max_length
and
len
(
encoded_inputs
[
"input_ids"
])
<
max_length
:
difference
=
max_length
-
len
(
encoded_inputs
[
"input_ids"
])
needs_to_be_padded
=
pad_to_max_length
and
(
max_length
and
len
(
encoded_inputs
[
"input_ids"
])
<
max_length
or
max_length
is
None
and
len
(
encoded_inputs
[
"input_ids"
])
<
self
.
max_len
and
self
.
max_len
<=
10000
)
if
pad_to_max_length
and
max_length
is
None
and
self
.
max_len
>
10000
:
logger
.
warning
(
"Sequence can't be padded as the maximum "
)
if
needs_to_be_padded
:
difference
=
(
max_length
if
max_length
is
not
None
else
self
.
max_len
)
-
len
(
encoded_inputs
[
"input_ids"
])
if
padding_s
trategy
==
'right'
:
if
self
.
padding_s
ide
==
'right'
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
+
[
0
]
*
difference
if
return_token_type_ids
:
...
...
@@ -967,7 +981,7 @@ class PreTrainedTokenizer(object):
                     encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                 encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
-            elif padding_strategy == 'left':
+            elif self.padding_side == 'left':
                 if return_attention_mask:
                     encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
                 if return_token_type_ids:
@@ -977,7 +991,7 @@ class PreTrainedTokenizer(object):
                 encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
             else:
-                raise ValueError("Invalid padding strategy:" + str(padding_strategy))
+                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
         elif return_attention_mask:
             encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
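To make the branchy padding logic above concrete, a hedged usage sketch of what encode_plus returns for each padding side after this change. The checkpoint name and exact mask values are illustrative assumptions; they assume "Hello world" encodes to four ids including the special tokens.

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")

tok.padding_side = "right"
enc = tok.encode_plus("Hello world", max_length=8, pad_to_max_length=True)
# enc["attention_mask"] -> [1, 1, 1, 1, 0, 0, 0, 0]   real tokens first, then padding

tok.padding_side = "left"
enc = tok.encode_plus("Hello world", max_length=8, pad_to_max_length=True)
# enc["attention_mask"] -> [0, 0, 0, 0, 1, 1, 1, 1]   padding first, then real tokens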
transformers/tokenization_xlnet.py
@@ -60,6 +60,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    padding_side = "left"

     def __init__(self, vocab_file, do_lower_case=False, remove_space=True, keep_accents=False,