"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "edb6b369d76ff0a132f9e18c69aa06aebfe87fe0"
Commit f5bcde0b authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[multiple-choice] Simplify and use tokenizer.encode_plus

parent 2dc8cb87
@@ -9,7 +9,7 @@ similar API between the different models.
 | [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
 | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
 | [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
-| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
+| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
 
 ## Language model fine-tuning
@@ -283,17 +283,17 @@ The results are the following:
 loss = 0.04755385363816904
 ```
 
-##Multiple Choice
+## Multiple Choice
 
 Based on the script [`run_multiple_choice.py`]().
 
 #### Fine-tuning on SWAG
 
 Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
 
-```
+```bash
 #training on 4 tesla V100(16GB) GPUS
 export SWAG_DIR=/path/to/swag_data_dir
-python ./examples/single_model_scripts/run_multiple_choice.py \
+python ./examples/run_multiple_choice.py \
 --model_type roberta \
 --task_name swag \
 --model_name_or_path roberta-base \
...
@@ -306,14 +306,14 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
         else:
             examples = processor.get_train_examples(args.data_dir)
         logger.info("Training number: %s", str(len(examples)))
-        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
-                                                cls_token_at_end=bool(args.model_type in ['xlnet']),  # xlnet has a cls token at the end
-                                                cls_token=tokenizer.cls_token,
-                                                sep_token=tokenizer.sep_token,
-                                                sep_token_extra=bool(args.model_type in ['roberta']),
-                                                cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
-                                                pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
+        features = convert_examples_to_features(
+            examples,
+            label_list,
+            args.max_seq_length,
+            tokenizer,
+            pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
+            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0
+        )
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
@@ -362,7 +362,7 @@ def main():
                         help="Whether to run eval on the dev set.")
     parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set')
     parser.add_argument("--evaluate_during_training", action='store_true',
-                        help="Rul evaluation during training at each logging step.")
+                        help="Run evaluation during training at each logging step.")
     parser.add_argument("--do_lower_case", action='store_true',
                         help="Set this flag if you are using an uncased model.")
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
+""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
 from __future__ import absolute_import, division, print_function
@@ -26,6 +26,8 @@ import json
 import csv
 import glob
 import tqdm
+from typing import List
+from transformers import PreTrainedTokenizer
 
 logger = logging.getLogger(__name__)
@@ -34,13 +36,13 @@ logger = logging.getLogger(__name__)
 class InputExample(object):
     """A single training/test example for multiple choice"""
 
     def __init__(self, example_id, question, contexts, endings, label=None):
         """Constructs a InputExample.
 
         Args:
             example_id: Unique id for the example.
             contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
-            question: string. The untokenized text of the second sequence (qustion).
+            question: string. The untokenized text of the second sequence (question).
             endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
             label: (Optional) string. The label of the example. This should be
             specified for train and dev examples, but not for test examples.
@@ -66,7 +68,7 @@ class InputFeatures(object):
                 'input_mask': input_mask,
                 'segment_ids': segment_ids
             }
-            for _, input_ids, input_mask, segment_ids in choices_features
+            for input_ids, input_mask, segment_ids in choices_features
         ]
         self.label = label
@@ -192,7 +194,7 @@ class SwagProcessor(DataProcessor):
         return lines
 
-    def _create_examples(self, lines, type):
+    def _create_examples(self, lines: List[List[str]], type: str):
         """Creates examples for the training and dev sets."""
         if type == "train" and lines[0][-1] != 'label':
             raise ValueError(
@@ -300,24 +302,18 @@ class ArcProcessor(DataProcessor):
         return examples
 
-def convert_examples_to_features(examples, label_list, max_seq_length,
-                                 tokenizer,
-                                 cls_token_at_end=False,
-                                 cls_token='[CLS]',
-                                 cls_token_segment_id=1,
-                                 sep_token='[SEP]',
-                                 sequence_a_segment_id=0,
-                                 sequence_b_segment_id=1,
-                                 sep_token_extra=False,
-                                 pad_token_segment_id=0,
-                                 pad_on_left=False,
-                                 pad_token=0,
-                                 mask_padding_with_zero=True):
-    """ Loads a data file into a list of `InputBatch`s
-        `cls_token_at_end` define the location of the CLS token:
-            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
-            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
-        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
-    """
+def convert_examples_to_features(
+    examples: List[InputExample],
+    label_list: List[str],
+    max_length: int,
+    tokenizer: PreTrainedTokenizer,
+    pad_token_segment_id=0,
+    pad_on_left=False,
+    pad_token=0,
+    mask_padding_with_zero=True,
+) -> List[InputFeatures]:
+    """
+    Loads a data file into a list of `InputFeatures`
+    """
     label_map = {label : i for i, label in enumerate(label_list)}
@@ -328,125 +324,71 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
         choices_features = []
         for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
-            tokens_a = tokenizer.tokenize(context)
-            tokens_b = None
+            text_a = context
             if example.question.find("_") != -1:
-                #this is for cloze question
-                tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
+                # this is for cloze question
+                text_b = example.question.replace("_", ending)
             else:
-                tokens_b = tokenizer.tokenize(example.question + " " + ending)
-                # you can add seq token between quesiotn and ending. This does not make too much difference.
-                # tokens_b = tokenizer.tokenize(example.question)
-                # tokens_b += [sep_token]
-                # if sep_token_extra:
-                #     tokens_b += [sep_token]
-                # tokens_b += tokenizer.tokenize(ending)
-
-            special_tokens_count = 4 if sep_token_extra else 3
-            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
-
-            # The convention in BERT is:
-            # (a) For sequence pairs:
-            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
-            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
-            # (b) For single sequences:
-            #  tokens:   [CLS] the dog is hairy . [SEP]
-            #  type_ids:   0   0   0   0  0     0   0
-            #
-            # Where "type_ids" are used to indicate whether this is the first
-            # sequence or the second sequence. The embedding vectors for `type=0` and
-            # `type=1` were learned during pre-training and are added to the wordpiece
-            # embedding vector (and position vector). This is not *strictly* necessary
-            # since the [SEP] token unambiguously separates the sequences, but it makes
-            # it easier for the model to learn the concept of sequences.
-            #
-            # For classification tasks, the first vector (corresponding to [CLS]) is
-            # used as as the "sentence vector". Note that this only makes sense because
-            # the entire model is fine-tuned.
-            tokens = tokens_a + [sep_token]
-            if sep_token_extra:
-                # roberta uses an extra separator b/w pairs of sentences
-                tokens += [sep_token]
-            segment_ids = [sequence_a_segment_id] * len(tokens)
-            if tokens_b:
-                tokens += tokens_b + [sep_token]
-                segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
-            if cls_token_at_end:
-                tokens = tokens + [cls_token]
-                segment_ids = segment_ids + [cls_token_segment_id]
-            else:
-                tokens = [cls_token] + tokens
-                segment_ids = [cls_token_segment_id] + segment_ids
-            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+                text_b = example.question + " " + ending
+
+            inputs = tokenizer.encode_plus(
+                text_a,
+                text_b,
+                add_special_tokens=True,
+                max_length=max_length,
+                truncate_both_sequences=True
+            )
+            if 'overflowing_tokens' in inputs and len(inputs['overflowing_tokens']) > 0:
+                logger.info('Attention! you are cropping tokens (swag task is ok). '
+                            'If you are training ARC and RACE and you are poping question + options,'
+                            'you need to try to use a bigger max seq length!')
+
+            input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
 
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
             # tokens are attended to.
-            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
 
             # Zero-pad up to the sequence length.
-            padding_length = max_seq_length - len(input_ids)
+            padding_length = max_length - len(input_ids)
             if pad_on_left:
                 input_ids = ([pad_token] * padding_length) + input_ids
-                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
-                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+                attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
+                token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
             else:
                 input_ids = input_ids + ([pad_token] * padding_length)
-                input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
-                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
+                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+                token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
 
-            assert len(input_ids) == max_seq_length
-            assert len(input_mask) == max_seq_length
-            assert len(segment_ids) == max_seq_length
-            choices_features.append((tokens, input_ids, input_mask, segment_ids))
+            assert len(input_ids) == max_length
+            assert len(attention_mask) == max_length
+            assert len(token_type_ids) == max_length
+            choices_features.append((input_ids, attention_mask, token_type_ids))
 
         label = label_map[example.label]
 
         if ex_index < 2:
             logger.info("*** Example ***")
             logger.info("race_id: {}".format(example.example_id))
-            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
+            for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
                 logger.info("choice: {}".format(choice_idx))
-                logger.info("tokens: {}".format(' '.join(tokens)))
                 logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
-                logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
-                logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
+                logger.info("attention_mask: {}".format(' '.join(map(str, attention_mask))))
+                logger.info("token_type_ids: {}".format(' '.join(map(str, token_type_ids))))
                 logger.info("label: {}".format(label))
 
         features.append(
             InputFeatures(
-                example_id = example.example_id,
-                choices_features = choices_features,
-                label = label
+                example_id=example.example_id,
+                choices_features=choices_features,
+                label=label,
             )
         )
 
     return features
 
-def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-    """Truncates a sequence pair in place to the maximum length."""
-
-    # This is a simple heuristic which will always truncate the longer sequence
-    # one token at a time. This makes more sense than truncating an equal percent
-    # of tokens from each, since if one sequence is very short then each token
-    # that's truncated likely contains more information than a longer sequence.
-    # However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
-    # length or only pop from context
-    while True:
-        total_length = len(tokens_a) + len(tokens_b)
-        if total_length <= max_length:
-            break
-        if len(tokens_a) > len(tokens_b):
-            tokens_a.pop()
-        else:
-            logger.info('Attention! you are removing from token_b (swag task is ok). '
-                        'If you are training ARC and RACE (you are poping question + options), '
-                        'you need to try to use a bigger max seq length!')
-            tokens_b.pop()
@@ -456,7 +398,7 @@ processors = {
 }
 
-GLUE_TASKS_NUM_LABELS = {
+MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
     "race", 4,
     "swag", 4,
     "arc", 4
...
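For reference, here is a minimal sketch of what the refactored conversion above now does for a single (context, question, ending) triple. It assumes a BERT tokenizer and the tokenizer API exactly as of this commit: `truncate_both_sequences` is introduced by this very change, and the texts and `max_length` below are illustrative.

```python
# Sketch of the new per-choice encoding path (tokenizer API as of this commit).
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

context = "A woman is outside with a bucket and a dog."  # illustrative text
question = "The woman _ the bucket."                     # cloze-style: "_" is filled by the ending
ending = "lifts"

text_a = context
text_b = question.replace("_", ending)

max_length = 24
inputs = tokenizer.encode_plus(
    text_a,
    text_b,
    add_special_tokens=True,        # encode_plus inserts [CLS]/[SEP] itself now
    max_length=max_length,
    truncate_both_sequences=True,   # flag added by this commit: pop from the longer side first
)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

# Same right-padding as convert_examples_to_features above (pad_on_left=False).
padding_length = max_length - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * padding_length
input_ids = input_ids + [0] * padding_length
token_type_ids = token_type_ids + [0] * padding_length

assert len(input_ids) == len(attention_mask) == len(token_type_ids) == max_length
```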
@@ -699,6 +699,7 @@ class PreTrainedTokenizer(object):
                max_length=None,
                stride=0,
                truncate_first_sequence=True,
+               truncate_both_sequences=False,
                return_tensors=None,
                **kwargs):
         """
@@ -718,7 +719,7 @@ class PreTrainedTokenizer(object):
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                 If there are overflowing tokens, those will be added to the returned dictionary
             stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
-                from the main sequence returned. The value of this argument defined the number of additional tokens.
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
@@ -731,6 +732,7 @@ class PreTrainedTokenizer(object):
                 add_special_tokens=add_special_tokens,
                 stride=stride,
                 truncate_first_sequence=truncate_first_sequence,
+                truncate_both_sequences=truncate_both_sequences,
                 return_tensors=return_tensors,
                 **kwargs)
@@ -743,6 +745,7 @@ class PreTrainedTokenizer(object):
                max_length=None,
                stride=0,
                truncate_first_sequence=True,
+               truncate_both_sequences=False,
                return_tensors=None,
                **kwargs):
         """
@@ -761,7 +764,7 @@ class PreTrainedTokenizer(object):
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                 If there are overflowing tokens, those will be added to the returned dictionary
             stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
-                from the main sequence returned. The value of this argument defined the number of additional tokens.
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
@@ -788,11 +791,12 @@ class PreTrainedTokenizer(object):
             add_special_tokens=add_special_tokens,
             stride=stride,
             truncate_first_sequence=truncate_first_sequence,
+            truncate_both_sequences=truncate_both_sequences,
             return_tensors=return_tensors)
 
     def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
-                          truncate_first_sequence=True, return_tensors=None):
+                          truncate_first_sequence=True, truncate_both_sequences=True, return_tensors=None):
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
         It adds special tokens, truncates
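To make the plumbing explicit: the two public encoding entry points patched above (`encode` and `encode_plus` in this era of the API) merely forward the new flag to `prepare_for_model`, which performs the actual truncation. A short usage sketch with illustrative texts, valid only for the API as of this commit:

```python
# Both public entry points forward truncate_both_sequences to prepare_for_model(),
# which pops tokens from the longer sequence first when the pair exceeds max_length.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

out = tokenizer.encode_plus(
    "a fairly long first sequence describing a scene in some detail",
    "a shorter second sequence",
    add_special_tokens=True,
    max_length=16,
    truncate_both_sequences=True,
)
# Tokens removed from either side during pair truncation are surfaced here:
dropped = out.get("overflowing_tokens", [])
```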
@@ -825,21 +829,30 @@ class PreTrainedTokenizer(object):
         encoded_inputs = {}
         if max_length:
             n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
-            if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
-                logger.warning(
-                    "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
-                    "This pair of sequences will not be truncated.")
-            else:
-                if n_added_tokens + len_ids + len_pair_ids > max_length:
-                    if truncate_first_sequence or not pair:
-                        encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
-                        ids = ids[:max_length - len_pair_ids - n_added_tokens]
-                    elif not truncate_first_sequence and pair:
-                        encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
-                        pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
-                    else:
-                        logger.warning(
-                            "Cannot truncate second sequence as it is not provided. No truncation.")
+
+            if n_added_tokens + len_ids + len_pair_ids > max_length:
+                if truncate_both_sequences:
+                    tokens_a, tokens_b = self._truncate_seq_pair(
+                        copy.deepcopy(ids),
+                        copy.deepcopy(pair_ids),
+                        max_length=max_length - n_added_tokens
+                    )
+                    encoded_inputs["overflowing_tokens"] = ids[- (len_ids - len(tokens_a)):] + pair_ids[- (len_pair_ids - len(tokens_b)):]
+                    ids = tokens_a
+                    pair_ids = tokens_b
+                elif pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
+                    logger.warning(
+                        "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
+                        "This pair of sequences will not be truncated.")
+                elif truncate_first_sequence or not pair:
+                    encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
+                    ids = ids[:max_length - len_pair_ids - n_added_tokens]
+                elif not truncate_first_sequence and pair:
+                    encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
+                    pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
+                else:
+                    logger.warning(
+                        "Cannot truncate second sequence as it is not provided. No truncation.")
 
         if add_special_tokens:
             sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
@@ -862,6 +875,25 @@ class PreTrainedTokenizer(object):
         return encoded_inputs
 
+    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
+        """Truncates a sequence pair in place to the maximum length."""
+
+        # This is a simple heuristic which will always truncate the longer sequence
+        # one token at a time. This makes more sense than truncating an equal percent
+        # of tokens from each, since if one sequence is very short then each token
+        # that's truncated likely contains more information than a longer sequence.
+        # However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
+        # length or only pop from context
+        while True:
+            total_length = len(tokens_a) + len(tokens_b)
+            if total_length <= max_length:
+                return (tokens_a, tokens_b)
+            if len(tokens_a) > len(tokens_b):
+                tokens_a.pop()
+            else:
+                tokens_b.pop()
+
     def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
         logger.warning("This tokenizer does not make use of special tokens.")
         return [0] * len(token_ids_0) + [1] * len(token_ids_1)
...
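The relocated `_truncate_seq_pair` keeps the longest-first heuristic. A tiny worked example of that behavior, written as a standalone function (pure Python; the names here are ours, not the library's):

```python
# Longest-first pair truncation, mirroring _truncate_seq_pair above.
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    # Pop one token at a time from whichever sequence is currently longer
    # until the combined length fits within max_length.
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
    return tokens_a, tokens_b

a = list("abcdefgh")  # 8 tokens
b = list("xyz")       # 3 tokens
a, b = truncate_seq_pair(a, b, max_length=6)
print(a, b)  # ['a', 'b', 'c'] ['x', 'y', 'z'] -- only the longer sequence lost tokens
```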