"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0611eab5e3ddcf155b7ac1fcc92f59aaa71ee17b"
Commit a7ca6d73 authored by LysandreJik

Padding side is tokenizer-dependent

parent cca75e78
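
In short: the padding direction is no longer passed per call (the old padding_strategy argument); it is a property of the tokenizer itself. PreTrainedTokenizer now carries a padding_side attribute ("right" by default, "left" for XLNet), and callers only opt into padding with pad_to_max_length=True. A minimal usage sketch, assuming a transformers install that includes this commit and the usual pretrained checkpoints:

    from transformers import BertTokenizer, XLNetTokenizer

    # BERT inherits the PreTrainedTokenizer default: padding_side = "right"
    bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
    right_padded = bert_tok.encode("Hello world", max_length=10, pad_to_max_length=True)
    # pad_token_id values are appended at the end of the sequence

    # XLNetTokenizer overrides the class attribute with padding_side = "left"
    xlnet_tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    left_padded = xlnet_tok.encode("Hello world", max_length=10, pad_to_max_length=True)
    # pad_token_id values are prepended at the start of the sequence

    # The side can still be changed on a given instance when needed
    xlnet_tok.padding_side = "right"
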
@@ -73,8 +73,7 @@ def _is_whitespace(c):
     return False


 def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
-                                        doc_stride, max_query_length, is_training,
-                                        sequence_a_is_doc=False):
+                                        doc_stride, max_query_length, is_training):
     """Loads a data file into a list of `InputBatch`s."""
     # Defining helper methods
@@ -127,13 +126,13 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         while len(spans) * doc_stride < len(all_doc_tokens):

             encoded_dict = tokenizer.encode_plus(
-                truncated_query if not sequence_a_is_doc else span_doc_tokens,
-                span_doc_tokens if not sequence_a_is_doc else truncated_query,
+                truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+                span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
                 max_length=max_seq_length,
                 return_overflowing_tokens=True,
-                padding_strategy='right',
+                pad_to_max_length=True,
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
+                truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
             )

             paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
...
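
The removed sequence_a_is_doc flag existed so that XLNet-style models could put the document before the question; the same decision is now derived from the tokenizer, so call sites only change by dropping the flag. A hypothetical call site (argument values are illustrative, not taken from the original code):

    # squad_convert_examples_to_features no longer takes sequence_a_is_doc; an
    # XLNet tokenizer (padding_side == "left") automatically gets the
    # document-first ordering and 'only_first' truncation shown above.
    features = squad_convert_examples_to_features(
        examples=examples,          # assumed: SQuAD examples loaded elsewhere
        tokenizer=xlnet_tokenizer,  # assumed: an XLNetTokenizer instance
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=True,
    )
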
@@ -344,17 +344,19 @@ class CommonTestCases:
         padding_idx = tokenizer.pad_token_id

         # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+        tokenizer.padding_side = "right"
         encoded_sequence = tokenizer.encode(sequence)
         sequence_length = len(encoded_sequence)
-        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='right')
+        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
         padded_sequence_length = len(padded_sequence)
         assert sequence_length + padding_size == padded_sequence_length
         assert encoded_sequence + [padding_idx] * padding_size == padded_sequence

         # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+        tokenizer.padding_side = "left"
         encoded_sequence = tokenizer.encode(sequence)
         sequence_length = len(encoded_sequence)
-        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, padding_strategy='left')
+        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
         padded_sequence_length = len(padded_sequence)
         assert sequence_length + padding_size == padded_sequence_length
         assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
@@ -362,10 +364,15 @@ class CommonTestCases:
         # RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
         encoded_sequence = tokenizer.encode(sequence)
         sequence_length = len(encoded_sequence)
-        padded_sequence_right = tokenizer.encode(sequence, padding_strategy='right')
+        tokenizer.padding_side = "right"
+        padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_right_length = len(padded_sequence_right)
-        padded_sequence_left = tokenizer.encode(sequence, padding_strategy='left')
+        tokenizer.padding_side = "left"
+        padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_left_length = len(padded_sequence_left)
         assert sequence_length == padded_sequence_right_length
         assert encoded_sequence == padded_sequence_right
         assert sequence_length == padded_sequence_left_length
@@ -387,7 +394,8 @@ class CommonTestCases:
         sequence_length = len(input_ids)

         # Test right padding
-        padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='right', return_special_tokens_mask=True)
+        tokenizer.padding_side = "right"
+        padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
         padded_input_ids = padded_sequence['input_ids']
         padded_token_type_ids = padded_sequence['token_type_ids']
         padded_attention_mask = padded_sequence['attention_mask']
@@ -401,7 +409,8 @@ class CommonTestCases:
         assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask

         # Test left padding
-        padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, padding_strategy='left', return_special_tokens_mask=True)
+        tokenizer.padding_side = "left"
+        padded_sequence = tokenizer.encode_plus(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True, return_special_tokens_mask=True)
         padded_input_ids = padded_sequence['input_ids']
         padded_token_type_ids = padded_sequence['token_type_ids']
         padded_attention_mask = padded_sequence['attention_mask']
...
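
The test hunks above pin down what the encoded output looks like once padding is driven by the tokenizer: the padded positions carry pad_token_id in input_ids, 0 in attention_mask and 1 in special_tokens_mask, on whichever side padding_side selects. A small sketch of the expected structure (tokenizer is assumed to be any tokenizer with a pad token; exact ids depend on its vocabulary):

    tokenizer.padding_side = "left"
    base_length = len(tokenizer.encode("This is a sequence"))
    out = tokenizer.encode_plus(
        "This is a sequence",
        max_length=base_length + 3,   # force exactly 3 positions of padding
        pad_to_max_length=True,
        return_special_tokens_mask=True,
    )

    pad = tokenizer.pad_token_id
    assert out["input_ids"][:3] == [pad, pad, pad]      # pads are prepended
    assert out["attention_mask"][:3] == [0, 0, 0]       # pads are masked out
    assert out["special_tokens_mask"][:3] == [1, 1, 1]  # pads count as special tokens
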
@@ -77,6 +77,8 @@ class PreTrainedTokenizer(object):
                                  "pad_token", "cls_token", "mask_token",
                                  "additional_special_tokens"]

+    padding_side = "right"
+
     @property
     def bos_token(self):
         """ Beginning of sentence token (string). Log an error if used while not having been set. """
@@ -223,6 +225,9 @@ class PreTrainedTokenizer(object):
         self.max_len = max_len if max_len is not None else int(1e12)

+        # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
+        self.padding_side = kwargs.pop('padding_side', self.padding_side)
+
         # Added tokens
         self.added_tokens_encoder = {}
         self.added_tokens_decoder = {}
@@ -702,7 +707,7 @@ class PreTrainedTokenizer(object):
                max_length=None,
                stride=0,
                truncation_strategy='longest_first',
-               padding_strategy=None,
+               pad_to_max_length=False,
                return_tensors=None,
                **kwargs):
         """
@@ -729,12 +734,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
@@ -745,7 +750,7 @@ class PreTrainedTokenizer(object):
                                 add_special_tokens=add_special_tokens,
                                 stride=stride,
                                 truncation_strategy=truncation_strategy,
-                                padding_strategy=padding_strategy,
+                                pad_to_max_length=pad_to_max_length,
                                 return_tensors=return_tensors,
                                 **kwargs)
@@ -758,7 +763,7 @@ class PreTrainedTokenizer(object):
                     max_length=None,
                     stride=0,
                     truncation_strategy='longest_first',
-                    padding_strategy=None,
+                    pad_to_max_length=False,
                     return_tensors=None,
                     return_token_type_ids=True,
                     return_attention_mask=True,
@@ -788,12 +793,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
@@ -841,7 +846,7 @@ class PreTrainedTokenizer(object):
         return self.prepare_for_model(first_ids,
                                       pair_ids=second_ids,
                                       max_length=max_length,
-                                      padding_strategy=padding_strategy,
+                                      pad_to_max_length=pad_to_max_length,
                                       add_special_tokens=add_special_tokens,
                                       stride=stride,
                                       truncation_strategy=truncation_strategy,
@@ -853,7 +858,7 @@ class PreTrainedTokenizer(object):
     def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
                           truncation_strategy='longest_first',
-                          padding_strategy=None,
+                          pad_to_max_length=False,
                           return_tensors=None,
                           return_token_type_ids=True,
                           return_attention_mask=True,
@@ -881,12 +886,12 @@ class PreTrainedTokenizer(object):
                 - 'only_first': Only truncate the first sequence
                 - 'only_second': Only truncate the second sequence
                 - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
-            padding_strategy: if set to a strategy, the returned sequences will be padded according to the model's
-                padding index, up to their max length. If no max length is specified, no padding is done.
-                The strategies are handled by the following strings:
+            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
+                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
+                The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
-                Defaults to None: no padding.
+                Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
@@ -955,10 +960,19 @@ class PreTrainedTokenizer(object):
                            "for this model ({} > {}). Running this sequence through the model will result in "
                            "indexing errors".format(len(ids), self.max_len))

-        if padding_strategy is not None and max_length and len(encoded_inputs["input_ids"]) < max_length:
-            difference = max_length - len(encoded_inputs["input_ids"])
+        needs_to_be_padded = pad_to_max_length and (
+            max_length and len(encoded_inputs["input_ids"]) < max_length
+            or
+            max_length is None and len(encoded_inputs["input_ids"]) < self.max_len and self.max_len <= 10000
+        )
+
+        if pad_to_max_length and max_length is None and self.max_len > 10000:
+            logger.warning("Sequence can't be padded as the maximum ")
+
+        if needs_to_be_padded:
+            difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])

-            if padding_strategy == 'right':
+            if self.padding_side == 'right':
                 if return_attention_mask:
                     encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
                 if return_token_type_ids:
@@ -967,7 +981,7 @@ class PreTrainedTokenizer(object):
                     encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                 encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
-            elif padding_strategy == 'left':
+            elif self.padding_side == 'left':
                 if return_attention_mask:
                     encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
                 if return_token_type_ids:
@@ -977,7 +991,7 @@ class PreTrainedTokenizer(object):
                 encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
             else:
-                raise ValueError("Invalid padding strategy:" + str(padding_strategy))
+                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
         elif return_attention_mask:
             encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
...
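
Condensed, the new decision in prepare_for_model is: pad when pad_to_max_length is set and the encoded input is shorter than either the explicit max_length or, failing that, the model's own max_len (unless max_len is still the huge int(1e12) "unset" default), then grow the ids and masks on self.padding_side. A simplified standalone sketch of that logic, not the library function itself (it ignores token_type_ids and the return_* switches for brevity):

    def pad_encoded(input_ids, pad_token_id, padding_side="right",
                    pad_to_max_length=False, max_length=None, model_max_len=int(1e12)):
        """Pad input_ids and build an attention mask, mirroring the hunk above."""
        target = max_length if max_length is not None else model_max_len
        # Pad only when asked to, when the input is short of the target, and when
        # the fallback model_max_len is a real limit rather than the huge default.
        needs_padding = pad_to_max_length and len(input_ids) < target and (
            max_length is not None or model_max_len <= 10000
        )
        attention_mask = [1] * len(input_ids)
        if not needs_padding:
            return input_ids, attention_mask

        difference = target - len(input_ids)
        if padding_side == "right":
            return input_ids + [pad_token_id] * difference, attention_mask + [0] * difference
        elif padding_side == "left":
            return [pad_token_id] * difference + input_ids, [0] * difference + attention_mask
        raise ValueError("Invalid padding strategy:" + str(padding_side))

    # pad_encoded([5, 6, 7], pad_token_id=0, padding_side="left",
    #             pad_to_max_length=True, max_length=5)
    # -> ([0, 0, 5, 6, 7], [0, 0, 1, 1, 1])
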
@@ -60,6 +60,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    padding_side = "left"

     def __init__(self, vocab_file,
                  do_lower_case=False, remove_space=True, keep_accents=False,
...
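
XLNetTokenizer is the only tokenizer in this commit that overrides the new class attribute; every other tokenizer inherits padding_side = "right" from PreTrainedTokenizer. Because __init__ also pops a padding_side kwarg (see the PreTrainedTokenizer hunk above), the default should be overridable at construction time, assuming from_pretrained forwards extra keyword arguments to the constructor as usual:

    from transformers import XLNetTokenizer

    tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    print(tok.padding_side)        # "left" -- the class-level default for XLNet

    # Assumed to work via kwargs.pop('padding_side', self.padding_side) in __init__
    tok_right = XLNetTokenizer.from_pretrained("xlnet-base-cased", padding_side="right")
    print(tok_right.padding_side)  # "right"
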