Commit c413c90e authored by Chen Chen, committed by A. Unique TensorFlower

Support creating classification datasets using the sentence piece tokenizer.

PiperOrigin-RevId: 286805889
parent 01a51ee2
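The core idea of this change is that each DataProcessor now receives a text preprocessing function instead of hard-coding tokenization.convert_to_unicode, so the same classification data pipeline can serve both the WordPiece and the SentencePiece tokenizers. A minimal sketch of the two preprocessing choices the commit wires up (output comments are illustrative):

import functools

from official.nlp.bert import tokenization

# WordPiece path: raw tsv fields are only converted to unicode.
word_piece_text_fn = tokenization.convert_to_unicode

# SentencePiece path: fields are additionally normalized and optionally
# lower-cased via the ALBERT-style preprocess_text helper.
sentence_piece_text_fn = functools.partial(
    tokenization.preprocess_text, lower=True)

print(word_piece_text_fn("Hello,  World!"))      # "Hello,  World!"
print(sentence_piece_text_fn("Hello,  World!"))  # e.g. "hello, world!"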
@@ -68,6 +68,9 @@ class InputFeatures(object):
 class DataProcessor(object):
   """Base class for data converters for sequence classification data sets."""

+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    self.process_text_fn = process_text_fn
+
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
     raise NotImplementedError()
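Because the base class now stores process_text_fn, subclasses that do not override __init__ pick it up automatically. A hypothetical toy processor (not part of the commit) showing the pattern the real processors below follow:

class ToyProcessor(DataProcessor):
  """Illustrative processor for a tsv with (sentence, label) columns."""

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%d" % (set_type, i)
      # Every raw field is routed through the injected preprocessing
      # function, so the processor stays tokenizer-agnostic.
      text_a = self.process_text_fn(line[0])
      label = self.process_text_fn(line[1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples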
@@ -103,7 +106,8 @@ class DataProcessor(object):
 class XnliProcessor(DataProcessor):
   """Processor for the XNLI data set."""

-  def __init__(self):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
     self.language = "zh"

   def get_train_examples(self, data_dir):
@@ -116,11 +120,11 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "train-%d" % (i)
-      text_a = tokenization.convert_to_unicode(line[0])
-      text_b = tokenization.convert_to_unicode(line[1])
-      label = tokenization.convert_to_unicode(line[2])
-      if label == tokenization.convert_to_unicode("contradictory"):
-        label = tokenization.convert_to_unicode("contradiction")
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
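Note that the literal strings "contradictory" and "contradiction" are also run through self.process_text_fn; with the sentence piece preprocessing, which may lower-case and normalize text, this keeps the constants comparable to the processed label. A small sanity check under that assumption:

import functools

from official.nlp.bert import tokenization

process_text_fn = functools.partial(tokenization.preprocess_text, lower=True)

# The label read from the file and the hard-coded constant get the same
# normalization, so the equality check behaves as before.
assert process_text_fn("Contradictory") == process_text_fn("contradictory")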
@@ -133,12 +137,12 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "dev-%d" % (i)
-      language = tokenization.convert_to_unicode(line[0])
-      if language != tokenization.convert_to_unicode(self.language):
+      language = self.process_text_fn(line[0])
+      if language != self.process_text_fn(self.language):
         continue
-      text_a = tokenization.convert_to_unicode(line[6])
-      text_b = tokenization.convert_to_unicode(line[7])
-      label = tokenization.convert_to_unicode(line[1])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -187,13 +191,13 @@ class MnliProcessor(DataProcessor):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
-      text_a = tokenization.convert_to_unicode(line[8])
-      text_b = tokenization.convert_to_unicode(line[9])
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
       if set_type == "test":
         label = "contradiction"
       else:
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -233,12 +237,12 @@ class MrpcProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
-      text_a = tokenization.convert_to_unicode(line[3])
-      text_b = tokenization.convert_to_unicode(line[4])
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
       if set_type == "test":
         label = "0"
       else:
-        label = tokenization.convert_to_unicode(line[0])
+        label = self.process_text_fn(line[0])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -280,11 +284,11 @@ class ColaProcessor(DataProcessor):
         continue
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[1])
         label = "0"
       else:
-        text_a = tokenization.convert_to_unicode(line[3])
-        label = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
@@ -525,11 +529,10 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
 def generate_tf_record_from_data_file(processor,
                                       data_dir,
-                                      vocab_file,
+                                      tokenizer,
                                       train_data_output_path=None,
                                       eval_data_output_path=None,
-                                      max_seq_length=128,
-                                      do_lower_case=True):
   """Generates and saves training data into a tf record file.

   Arguments:
@@ -537,14 +540,13 @@ def generate_tf_record_from_data_file(processor,
       of `DataProcessor`.
     data_dir: Directory that contains train/eval data to process. Data files
       should be in from "dev.tsv", "test.tsv", or "train.tsv".
-    vocab_file: Text file with words to be used for training/evaluation.
+    tokenizer: The tokenizer to be applied on the data.
     train_data_output_path: Output to which processed tf record for training
       will be saved.
    eval_data_output_path: Output to which processed tf record for evaluation
       will be saved.
     max_seq_length: Maximum sequence length of the to be generated
       training/eval data.
-    do_lower_case: Whether to lower case input text.

   Returns:
     A dictionary containing input meta data.
@@ -552,8 +554,6 @@ def generate_tf_record_from_data_file(processor,
   assert train_data_output_path or eval_data_output_path

   label_list = processor.get_labels()
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=vocab_file, do_lower_case=do_lower_case)
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
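With the in-function FullTokenizer construction removed, generate_tf_record_from_data_file now expects the caller to pass a ready-made tokenizer that matches the processor's process_text_fn. A usage sketch for the WordPiece case (paths are placeholders):

from official.nlp.bert import classifier_data_lib
from official.nlp.bert import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)
processor = classifier_data_lib.ColaProcessor(
    process_text_fn=tokenization.convert_to_unicode)

input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    "/path/to/CoLA",
    tokenizer,
    train_data_output_path="/tmp/cola_train.tf_record",
    eval_data_output_path="/tmp/cola_eval.tf_record",
    max_seq_length=128)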
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
 import json

 from absl import app
@@ -29,6 +30,7 @@ from official.nlp.bert import classifier_data_lib
 from official.nlp.bert import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib_sp
+from official.nlp.bert import tokenization

 FLAGS = flags.FLAGS
@@ -120,15 +122,24 @@ def generate_classifier_dataset():
   if task_name not in processors:
     raise ValueError("Task not found: %s" % (task_name))

-  processor = processors[task_name]()
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processor = processors[task_name](processor_text_fn)
   return classifier_data_lib.generate_tf_record_from_data_file(
       processor,
       FLAGS.input_data_dir,
-      FLAGS.vocab_file,
+      tokenizer,
       train_data_output_path=FLAGS.train_data_output_path,
       eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length,
-      do_lower_case=FLAGS.do_lower_case)
+      max_seq_length=FLAGS.max_seq_length)


 def generate_squad_dataset():
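The sentence_piece branch above is the new path: a FullSentencePieceTokenizer built from the SentencePiece model, paired with a preprocess_text-based processor_text_fn. The equivalent wiring without the flags machinery might look like this (model and data paths are placeholders):

import functools

from official.nlp.bert import classifier_data_lib
from official.nlp.bert import tokenization

tokenizer = tokenization.FullSentencePieceTokenizer("/path/to/spiece.model")
processor_text_fn = functools.partial(tokenization.preprocess_text, lower=True)

processor = classifier_data_lib.MnliProcessor(processor_text_fn)
input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    "/path/to/MNLI",
    tokenizer,
    train_data_output_path="/tmp/mnli_train.tf_record",
    eval_data_output_path="/tmp/mnli_eval.tf_record",
    max_seq_length=128)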