Commit c413c90e authored by Chen Chen, committed by A. Unique TensorFlower

Support creating classification datasets with the SentencePiece tokenizer.

PiperOrigin-RevId: 286805889
parent 01a51ee2
@@ -68,6 +68,9 @@ class InputFeatures(object):
 class DataProcessor(object):
   """Base class for data converters for sequence classification data sets."""
 
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    self.process_text_fn = process_text_fn
+
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
     raise NotImplementedError()
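The hunk above adds a text-processing hook to the base class; subclasses now run raw TSV fields through self.process_text_fn instead of calling tokenization.convert_to_unicode directly. For illustration, a minimal sketch of how a processor might be constructed with each kind of text function (not part of this commit, and it assumes MnliProcessor does not override __init__ and therefore inherits the new constructor):

    import functools

    from official.nlp.bert import classifier_data_lib
    from official.nlp.bert import tokenization

    # Default: raw TSV fields are only converted to unicode, as before.
    mnli_wordpiece = classifier_data_lib.MnliProcessor()

    # SentencePiece flow: lower-casing/normalization is applied to the raw
    # text up front via preprocess_text, mirroring the sentence_piece branch
    # added later in this commit.
    mnli_sentencepiece = classifier_data_lib.MnliProcessor(
        process_text_fn=functools.partial(
            tokenization.preprocess_text, lower=True))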
@@ -103,7 +106,8 @@ class DataProcessor(object):
 class XnliProcessor(DataProcessor):
   """Processor for the XNLI data set."""
 
-  def __init__(self):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
     self.language = "zh"
 
   def get_train_examples(self, data_dir):
@@ -116,11 +120,11 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "train-%d" % (i)
-      text_a = tokenization.convert_to_unicode(line[0])
-      text_b = tokenization.convert_to_unicode(line[1])
-      label = tokenization.convert_to_unicode(line[2])
-      if label == tokenization.convert_to_unicode("contradictory"):
-        label = tokenization.convert_to_unicode("contradiction")
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -133,12 +137,12 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "dev-%d" % (i)
-      language = tokenization.convert_to_unicode(line[0])
-      if language != tokenization.convert_to_unicode(self.language):
+      language = self.process_text_fn(line[0])
+      if language != self.process_text_fn(self.language):
         continue
-      text_a = tokenization.convert_to_unicode(line[6])
-      text_b = tokenization.convert_to_unicode(line[7])
-      label = tokenization.convert_to_unicode(line[1])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -187,13 +191,13 @@ class MnliProcessor(DataProcessor):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
-      text_a = tokenization.convert_to_unicode(line[8])
-      text_b = tokenization.convert_to_unicode(line[9])
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
       if set_type == "test":
         label = "contradiction"
       else:
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -233,12 +237,12 @@ class MrpcProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
-      text_a = tokenization.convert_to_unicode(line[3])
-      text_b = tokenization.convert_to_unicode(line[4])
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
       if set_type == "test":
         label = "0"
       else:
-        label = tokenization.convert_to_unicode(line[0])
+        label = self.process_text_fn(line[0])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -280,11 +284,11 @@ class ColaProcessor(DataProcessor):
         continue
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[1])
         label = "0"
       else:
-        text_a = tokenization.convert_to_unicode(line[3])
-        label = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
@@ -525,11 +529,10 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
 
 def generate_tf_record_from_data_file(processor,
                                       data_dir,
-                                      vocab_file,
+                                      tokenizer,
                                       train_data_output_path=None,
                                       eval_data_output_path=None,
-                                      max_seq_length=128,
-                                      do_lower_case=True):
+                                      max_seq_length=128):
   """Generates and saves training data into a tf record file.
 
   Arguments:
@@ -537,14 +540,13 @@ def generate_tf_record_from_data_file(processor,
       of `DataProcessor`.
     data_dir: Directory that contains train/eval data to process. Data files
       should be in from "dev.tsv", "test.tsv", or "train.tsv".
-    vocab_file: Text file with words to be used for training/evaluation.
+    tokenizer: The tokenizer to be applied on the data.
     train_data_output_path: Output to which processed tf record for training
       will be saved.
     eval_data_output_path: Output to which processed tf record for evaluation
       will be saved.
     max_seq_length: Maximum sequence length of the to be generated
       training/eval data.
-    do_lower_case: Whether to lower case input text.
 
   Returns:
     A dictionary containing input meta data.
@@ -552,8 +554,6 @@ def generate_tf_record_from_data_file(processor,
   assert train_data_output_path or eval_data_output_path
 
   label_list = processor.get_labels()
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=vocab_file, do_lower_case=do_lower_case)
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
...
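With the hunks above, generate_tf_record_from_data_file no longer builds a FullTokenizer from vocab_file/do_lower_case; the caller constructs the tokenizer and passes it in. A hedged sketch of the new call contract, with placeholder paths and assuming ColaProcessor inherits the new DataProcessor constructor:

    import functools

    from official.nlp.bert import classifier_data_lib
    from official.nlp.bert import tokenization

    # Caller-built SentencePiece tokenizer; the .model path is a placeholder.
    tokenizer = tokenization.FullSentencePieceTokenizer("/path/to/sp.model")

    # Assumes ColaProcessor inherits the new DataProcessor constructor.
    processor = classifier_data_lib.ColaProcessor(
        process_text_fn=functools.partial(
            tokenization.preprocess_text, lower=True))

    # Returns the input meta data dictionary described in the docstring.
    input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        "/path/to/CoLA",  # placeholder directory holding train.tsv / dev.tsv
        tokenizer,
        train_data_output_path="/tmp/cola_train.tf_record",
        eval_data_output_path="/tmp/cola_eval.tf_record",
        max_seq_length=128)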
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import json
 
 from absl import app
@@ -29,6 +30,7 @@ from official.nlp.bert import classifier_data_lib
 from official.nlp.bert import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib_sp
+from official.nlp.bert import tokenization
 
 
 FLAGS = flags.FLAGS
@@ -120,15 +122,24 @@ def generate_classifier_dataset():
   if task_name not in processors:
     raise ValueError("Task not found: %s" % (task_name))
 
-  processor = processors[task_name]()
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processor = processors[task_name](processor_text_fn)
   return classifier_data_lib.generate_tf_record_from_data_file(
       processor,
       FLAGS.input_data_dir,
-      FLAGS.vocab_file,
+      tokenizer,
       train_data_output_path=FLAGS.train_data_output_path,
       eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length,
-      do_lower_case=FLAGS.do_lower_case)
+      max_seq_length=FLAGS.max_seq_length)
 
 
 def generate_squad_dataset():
...
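The new branch in generate_classifier_dataset chooses both the tokenizer and the matching per-field text function: FullTokenizer lower-cases internally via do_lower_case, while for SentencePiece the lower-casing/normalization is applied to the raw text before tokenization through tokenization.preprocess_text, bound with functools.partial. A standalone sketch of that selection with the flag values inlined as plain arguments (the helper name and defaults are illustrative, not part of the commit):

    import functools

    from official.nlp.bert import tokenization

    def build_tokenizer_and_text_fn(tokenizer_impl, do_lower_case,
                                    vocab_file=None, sp_model_file=None):
      """Mirrors the word_piece / sentence_piece branch added above."""
      if tokenizer_impl == "word_piece":
        # WordPiece: FullTokenizer lower-cases internally, so the processors
        # only need a plain unicode conversion on the raw TSV fields.
        tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=do_lower_case)
        text_fn = tokenization.convert_to_unicode
      else:
        assert tokenizer_impl == "sentence_piece"
        # SentencePiece: lower-casing/normalization happens before
        # tokenization, so it is bound into the per-field text function.
        tokenizer = tokenization.FullSentencePieceTokenizer(sp_model_file)
        text_fn = functools.partial(
            tokenization.preprocess_text, lower=do_lower_case)
      return tokenizer, text_fn

The processor classes then receive the text function positionally, exactly as processors[task_name](processor_text_fn) does in the hunk above.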