Commit 2bd8de61 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 381396116
parent 59c35095
@@ -1299,6 +1299,139 @@ class SuperGLUERTEProcessor(DefaultGLUEDataProcessor):
return examples
class WiCInputExample(InputExample):
  """An input example for the WiC dataset (SuperGLUE version)."""

  def __init__(self,
               guid,
               text_a,
               text_b=None,
               label=None,
               word=None,
               weight=None,
               example_id=None):
    """A single training/test example, extended with the WiC target word."""
    super(WiCInputExample, self).__init__(guid, text_a, text_b, label, weight,
                                          example_id)
    self.word = word
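For context, WiC asks whether the same word is used in the same sense in two sentences, so each example carries the target word alongside the usual sentence pair. A minimal sketch of constructing one of these examples directly, with made-up field values (per SuperGLUE WiC, label 1 means the senses match):

# Hypothetical values, only to illustrate the extra `word` field.
example = WiCInputExample(
    guid="train-0",
    text_a="He sat on the bank of the river.",
    text_b="She deposited the check at the bank.",
    word="bank",
    label=0)  # 0: different senses; 1: same sense.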
class WiCProcessor(DefaultGLUEDataProcessor):
  """Processor for the WiC dataset (SuperGLUE version)."""

  def get_labels(self):
    """Not used."""
    return []

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "WiCSuperGLUE"
  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    dataset = tfds.load(
        "super_glue/wic", split=set_type, try_gcs=True).as_numpy_iterator()
    for example in dataset:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      word = self.process_text_fn(example["word"])
      label = 0
      if set_type != "test":
        label = example["label"]
      examples.append(
          WiCInputExample(
              guid=guid, text_a=text_a, text_b=text_b, word=word, label=label))
    return examples
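The loop above relies on each TFDS record exposing `idx`, `sentence1`, `sentence2`, `word`, and `label` features. A quick standalone way to inspect one raw record (a sketch, not part of this change; the split name is an assumption):

import tensorflow_datasets as tfds

ds = tfds.load("super_glue/wic", split="validation", try_gcs=True)
record = next(iter(ds.as_numpy_iterator()))
# Expect at least: idx, label, sentence1, sentence2, word. The dataset
# also carries character span indices locating the word in each
# sentence, which this processor does not use.
print(sorted(record.keys()))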
  def featurize_example(self, ex_index, example, label_list, max_seq_length,
                        tokenizer):
    """Concatenates sentence1, sentence2 and the word with [SEP] tokens."""
    del label_list
    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = tokenizer.tokenize(example.text_b)
    tokens_word = tokenizer.tokenize(example.word)

    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Accounts for [CLS], [SEP], [SEP], [SEP] with "- 4".
    # Only the two sentences are truncated; the word is kept intact.
    _truncate_seq_pair(tokens_a, tokens_b,
                       max_seq_length - 4 - len(tokens_word))
    seg_id_a = 0
    seg_id_b = 1
    seg_id_c = 2
    seg_id_cls = 0
    seg_id_pad = 0

    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(seg_id_cls)
    for token in tokens_a:
      tokens.append(token)
      segment_ids.append(seg_id_a)
    tokens.append("[SEP]")
    segment_ids.append(seg_id_a)

    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(seg_id_b)
    tokens.append("[SEP]")
    segment_ids.append(seg_id_b)

    for token in tokens_word:
      tokens.append(token)
      segment_ids.append(seg_id_c)
    tokens.append("[SEP]")
    segment_ids.append(seg_id_c)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(seg_id_pad)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    label_id = example.label
    if ex_index < 5:
      logging.info("*** Example ***")
      logging.info("guid: %s", example.guid)
      logging.info("tokens: %s",
                   " ".join([tokenization.printable_text(x) for x in tokens]))
      logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
      logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
      logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
      logging.info("label: %s (id = %s)", example.label, str(label_id))
      logging.info("weight: %s", example.weight)
      logging.info("example_id: %s", example.example_id)

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id,
        is_real_example=True,
        weight=example.weight,
        example_id=example.example_id)
    return feature
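To make the three-segment packing concrete: for a toy pair of pre-tokenized sentences (standing in for real WordPiece output), the method lays out tokens and segment ids as below.

# Toy, already-tokenized inputs; a real run would use WordPiece tokens.
tokens_a = ["he", "sat", "by", "the", "bank"]
tokens_b = ["the", "bank", "was", "closed"]
tokens_word = ["bank"]

# Layout: [CLS] sentence1 [SEP] sentence2 [SEP] word [SEP]
tokens = (["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
          + tokens_word + ["[SEP]"])
segment_ids = ([0] * (len(tokens_a) + 2)        # [CLS], sentence1, [SEP]
               + [1] * (len(tokens_b) + 1)      # sentence2, [SEP]
               + [2] * (len(tokens_word) + 1))  # word, [SEP]
assert len(tokens) == len(segment_ids) == 14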
def file_based_convert_examples_to_features(examples,
                                            label_list,
                                            max_seq_length,
......
@@ -39,6 +39,7 @@ class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
"CB": classifier_data_lib.CBProcessor,
"SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
"BOOLQ": classifier_data_lib.BoolQProcessor,
"WIC": classifier_data_lib.WiCProcessor,
}
vocab_tokens = [
@@ -55,6 +56,7 @@ class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
{"task_type": "CB"},
{"task_type": "BOOLQ"},
{"task_type": "SUPERGLUE-RTE"},
{"task_type": "WIC"},
)
def test_generate_dataset_from_tfds_processor(self, task_type):
with tfds.testing.mock_data(num_examples=5):
......
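With TFDS mocked, the processor's example creation runs without downloading real data, which is what the test above exercises. A standalone sketch of the same idea (assuming the processor's default process_text_fn, and that the GLUE-style get_train_examples ignores its data_dir argument for TFDS-backed processors):

import tensorflow_datasets as tfds

with tfds.testing.mock_data(num_examples=5):
  processor = classifier_data_lib.WiCProcessor()
  examples = processor.get_train_examples(None)  # data_dir unused here.
  assert len(examples) == 5
  assert all(ex.word is not None for ex in examples)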
@@ -50,7 +50,7 @@ flags.DEFINE_enum(
"classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
"AX-g", "SUPERGLUE-RTE", "CB", "BoolQ"
"AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
@@ -174,8 +174,20 @@ flags.DEFINE_string(
def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  if FLAGS.classification_task_name in [
-     "COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
-     "AX", "SUPERGLUE-RTE", "CB", "BoolQ"
+     "COLA",
+     "WNLI",
+     "SST-2",
+     "MRPC",
+     "QQP",
+     "STS-B",
+     "MNLI",
+     "QNLI",
+     "RTE",
+     "AX",
+     "SUPERGLUE-RTE",
+     "CB",
+     "BoolQ",
+     "WIC",
  ]:
    assert not FLAGS.input_data_dir or FLAGS.tfds_params
  else:
@@ -254,6 +266,8 @@ def generate_classifier_dataset():
          classifier_data_lib.CBProcessor,
      "boolq":
          classifier_data_lib.BoolQProcessor,
+     "wic":
+         classifier_data_lib.WiCProcessor,
  }
  task_name = FLAGS.classification_task_name.lower()
  if task_name not in processors:
......
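End to end, the registered processor can featurize a WiC example directly. A sketch assuming a WordPiece vocab file at a placeholder path and the BERT FullTokenizer from this repo; the module path and vocab file are assumptions:

from official.nlp.bert import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="vocab.txt", do_lower_case=True)  # placeholder path
processor = classifier_data_lib.WiCProcessor()
example = processor.get_dev_examples(None)[0]  # data_dir unused here.
feature = processor.featurize_example(
    0, example, label_list=[], max_seq_length=128, tokenizer=tokenizer)
print(feature.input_ids[:10], feature.segment_ids[:10])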