# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE task processors and conversion of examples into model input features."""

import logging
import os

from .utils import DataProcessor, EncodePattern, InputExample, InputFeatures

logger = logging.getLogger(__name__)


def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length,
    task=None,
    pattern=EncodePattern.bert_pattern,
    label_list=None,
    output_mode=None,
):
    """Convert GLUE examples into padded ``InputFeatures``.

    Args:
        examples: Sized iterable of examples exposing ``guid``, ``text_a``,
            ``text_b`` and ``label`` attributes.
        tokenizer: Object exposing ``tokenize``, ``convert_tokens_to_ids``,
            ``start_token``, ``end_token`` and ``pad_token_id``.
        max_length: Length every encoded sequence is padded to. Token
            sequences are truncated so that they fit together with the
            special tokens added by ``pattern``.
        task: Optional key of ``glue_processors``; when given, it supplies
            the defaults for ``label_list`` and ``output_mode``.
        pattern: ``EncodePattern.bert_pattern`` (``start A end [B end]``) or
            ``EncodePattern.roberta_pattern`` (``start A end [end B end]``).
        label_list: Labels in id order. May be omitted when ``task`` is
            given, when the examples are unlabeled, or for regression.
        output_mode: ``"classification"`` or ``"regression"``.

    Returns:
        A list with one ``InputFeatures`` per example.

    Raises:
        KeyError: If ``pattern`` is not a supported ``EncodePattern``, if
            ``task`` is unknown, if a labeled example is met while
            ``output_mode`` is neither ``"classification"`` nor
            ``"regression"``, or if a classification label is not in the
            label list.
    """
    if task is not None:
        processor = glue_processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info(f"Using label list {label_list} for task {task}")
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info(f"Using output mode {output_mode} for task {task}")

    # An absent label list yields an empty map instead of crashing, so that
    # unlabeled or regression data can be converted without one.
    label_map = {label: i for i, label in enumerate(label_list or ())}

    start_token = [] if tokenizer.start_token is None else [tokenizer.start_token]
    end_token = [] if tokenizer.end_token is None else [tokenizer.end_token]
    pad_id = tokenizer.pad_token_id

    # Resolve the pattern once so that validation and per-example encoding
    # cannot disagree. Budgets are the room left for real tokens after the
    # special tokens of a single sentence / a sentence pair are reserved.
    if pattern == EncodePattern.bert_pattern:
        roberta_layout = False
        added_special_tokens = [2, 3]
    elif pattern == EncodePattern.roberta_pattern:
        roberta_layout = True
        added_special_tokens = [2, 4]
    else:
        raise KeyError("pattern is not a valid EncodePattern")
    single_budget = max_length - added_special_tokens[0]
    pair_budget = max_length - added_special_tokens[1]

    features = []
    for ex_index, example in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens_a = tokenizer.tokenize(example.text_a)
        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            _truncate_seq_pair(tokens_a, tokens_b, pair_budget)
        elif len(tokens_a) > single_budget:
            tokens_a = tokens_a[:single_budget]

        # First segment (type 0): start + A + end.
        tokens = start_token + tokens_a + end_token
        token_type_ids = [0] * len(tokens)
        if tokens_b:
            # Second segment (type 1): B + end, preceded by an extra end
            # token in the RoBERTa layout.
            second_segment = tokens_b + end_token
            if roberta_layout:
                second_segment = end_token + second_segment
            tokens += second_segment
            token_type_ids += [1] * len(second_segment)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)

        # Right-pad everything up to max_length; padding is masked out.
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [pad_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length

        label = None
        if example.label is not None:
            if output_mode == "classification":
                label = label_map[example.label]
            elif output_mode == "regression":
                label = float(example.label)
            else:
                raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("input_ids: %s", " ".join(map(str, input_ids)))
            logger.info("attention_mask: %s", " ".join(map(str, attention_mask)))
            logger.info("token_type_ids: %s", " ".join(map(str, token_type_ids)))
            # %s rather than %d: the id is None for unlabeled examples and a
            # float for regression, both of which %d cannot render faithfully.
            logger.info("label: %s (id = %s)", example.label, label)

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=label,
            )
        )

    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate a pair of token lists in place to a combined ``max_length``.

    One token at a time is removed from the end of the currently longer
    list (``tokens_b`` on ties). The loop stops once both lists are empty,
    so a budget below zero no longer pops from an empty list.
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        if not longer:
            # Both lists are empty; nothing left to remove.
            break
        longer.pop()


class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version).

    Sentence pair classification task.
    Determine whether the two sentences have the same meaning.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        Sentences come from columns 3 and 4 and the label from column 0;
        the test split is left unlabeled.
        """
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{i}",
                    text_a=line[3],
                    text_b=line[4],
                    label=None if is_test else line[0],
                )
            )
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version).

    Sentence pair classification task.
    Given a premise sentence and a hypothesis sentence,
    the task is to predict whether the premise entails the hypothesis (entailment),
    contradicts the hypothesis (contradiction), or neither (neutral).
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev_matched.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev_matched.tsv")
        return self._create_examples(self._read_tsv(path), "dev_matched")

    def get_test_examples(self, data_dir):
        """Read ``test_matched.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test_matched.tsv")
        return self._create_examples(self._read_tsv(path), "test_matched")

    def get_labels(self):
        """Return the label list of this task."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The guid uses column 0, sentences come from columns 8 and 9 and the
        label from the last column; any split whose name starts with
        ``"test"`` is left unlabeled.
        """
        is_test = set_type.startswith("test")
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{line[0]}",
                    text_a=line[8],
                    text_b=line[9],
                    label=None if is_test else line[-1],
                )
            )
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """Read ``dev_mismatched.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev_mismatched.tsv")
        return self._create_examples(self._read_tsv(path), "dev_mismatched")

    def get_test_examples(self, data_dir):
        """Read ``test_mismatched.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test_mismatched.tsv")
        return self._create_examples(self._read_tsv(path), "test_mismatched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version).

    Single sentence classification task.
    Each example is a sequence of words annotated with whether it is a
    grammatical English sentence.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows.

        Only the test split has its first row dropped as a header; train and
        dev rows are all used. The sentence is column 1 for test and column 3
        otherwise, and the label (train/dev only) is column 1. Guids are
        numbered after the header has been dropped.
        """
        test_mode = set_type == "test"
        if test_mode:
            lines = lines[1:]
        text_index = 1 if test_mode else 3
        examples = []
        for i, line in enumerate(lines):
            examples.append(
                InputExample(
                    guid=f"{set_type}-{i}",
                    text_a=line[text_index],
                    text_b=None,
                    label=None if test_mode else line[1],
                )
            )
        return examples


class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version).

    Single sentence classification task.
    The task is to predict the sentiment of a given sentence.
    We use the two-way (positive/negative) class split, and use only
    sentence-level labels.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The sentence is column 1 for test and column 0 otherwise; the label
        (train/dev only) is column 1.
        """
        is_test = set_type == "test"
        text_index = 1 if is_test else 0
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{i}",
                    text_a=line[text_index],
                    text_b=None,
                    label=None if is_test else line[1],
                )
            )
        return examples


class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version).

    Sentence pair task but it is a regression task.
    This task is to predict the similarity score of two sentences.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return ``[None]``: this regression task has no discrete labels."""
        return [None]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The guid uses column 0, sentences come from columns 7 and 8 and the
        score from the last column; the test split is left unlabeled.
        """
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{line[0]}",
                    text_a=line[7],
                    text_b=line[8],
                    label=None if is_test else line[-1],
                )
            )
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version).

    Sentence pair classification task.
    The task is to determine whether a pair of questions are semantically
    equivalent.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        Questions come from columns 1 and 2 for test and columns 3 and 4
        otherwise; the label (train/dev only) is column 5. Rows that are too
        short to hold the required columns are skipped.
        """
        test_mode = set_type == "test"
        q1_index = 1 if test_mode else 3
        q2_index = 2 if test_mode else 4
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            try:
                text_a = line[q1_index]
                text_b = line[q2_index]
                label = None if test_mode else line[5]
            except IndexError:
                # Deliberate best effort: drop malformed (short) rows.
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version).

    Sentence pair classification task.
    The task is to determine whether the context sentence contains the
    answer to the question.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The guid uses column 0, sentences come from columns 1 and 2 and the
        label from the last column; the test split is left unlabeled.
        """
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{line[0]}",
                    text_a=line[1],
                    text_b=line[2],
                    label=None if is_test else line[-1],
                )
            )
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version).

    Sentence pair classification task.
    Recognizing Textual Entailment. Predict whether the two sentences are
    entailment or not entailment (neutral and contradiction).
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The guid uses column 0, sentences come from columns 1 and 2 and the
        label from the last column; the test split is left unlabeled.
        """
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{line[0]}",
                    text_a=line[1],
                    text_b=line[2],
                    label=None if is_test else line[-1],
                )
            )
        return examples


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version).

    Sentence pair classification task.
    The task is to predict if the sentence with the pronoun substituted is
    entailed by the original sentence.
    """

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(path), "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(path), "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from ``data_dir`` and return its examples."""
        path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(path), "test")

    def get_labels(self):
        """Return the label list of this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        The guid uses column 0, sentences come from columns 1 and 2 and the
        label from the last column; the test split is left unlabeled.
        """
        is_test = set_type == "test"
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            examples.append(
                InputExample(
                    guid=f"{set_type}-{line[0]}",
                    text_a=line[1],
                    text_b=line[2],
                    label=None if is_test else line[-1],
                )
            )
        return examples


# Number of output labels per task (1 for the STS-B regression task).
glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}

# Task name -> processor class used to read that task's TSV files.
glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

# Task name -> how labels are converted in glue_convert_examples_to_features.
glue_output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}