"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "4040a074e0d37e5d9ac0289e8a4359609f0df3df"
Commit c413c90e authored by Chen Chen, committed by A. Unique TensorFlower

Support creating a classification dataset using the sentence piece tokenizer.

PiperOrigin-RevId: 286805889
parent 01a51ee2
@@ -68,6 +68,9 @@ class InputFeatures(object):
 class DataProcessor(object):
   """Base class for data converters for sequence classification data sets."""
 
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    self.process_text_fn = process_text_fn
+
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
     raise NotImplementedError()
@@ -103,7 +106,8 @@ class DataProcessor(object):
 class XnliProcessor(DataProcessor):
   """Processor for the XNLI data set."""
 
-  def __init__(self):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
     self.language = "zh"
 
   def get_train_examples(self, data_dir):
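The new process_text_fn hook is what lets both tokenizers share the same processors: the WordPiece path keeps the old convert_to_unicode behaviour, while the SentencePiece path uses preprocess_text for normalization and optional lower-casing. A minimal standalone sketch of the two choices (illustrative, not part of this commit):

import functools

from official.nlp.bert import tokenization

raw = "An Example Sentence"

# WordPiece path: keep the historical behaviour of decoding bytes to unicode.
wordpiece_text = tokenization.convert_to_unicode(raw)

# SentencePiece path: normalize and optionally lower-case before tokenizing.
sentencepiece_text_fn = functools.partial(tokenization.preprocess_text, lower=True)
sentencepiece_text = sentencepiece_text_fn(raw)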
@@ -116,11 +120,11 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "train-%d" % (i)
-      text_a = tokenization.convert_to_unicode(line[0])
-      text_b = tokenization.convert_to_unicode(line[1])
-      label = tokenization.convert_to_unicode(line[2])
-      if label == tokenization.convert_to_unicode("contradictory"):
-        label = tokenization.convert_to_unicode("contradiction")
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -133,12 +137,12 @@ class XnliProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "dev-%d" % (i)
-      language = tokenization.convert_to_unicode(line[0])
-      if language != tokenization.convert_to_unicode(self.language):
+      language = self.process_text_fn(line[0])
+      if language != self.process_text_fn(self.language):
         continue
-      text_a = tokenization.convert_to_unicode(line[6])
-      text_b = tokenization.convert_to_unicode(line[7])
-      label = tokenization.convert_to_unicode(line[1])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -187,13 +191,13 @@ class MnliProcessor(DataProcessor):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
-      text_a = tokenization.convert_to_unicode(line[8])
-      text_b = tokenization.convert_to_unicode(line[9])
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
       if set_type == "test":
         label = "contradiction"
       else:
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -233,12 +237,12 @@ class MrpcProcessor(DataProcessor):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
-      text_a = tokenization.convert_to_unicode(line[3])
-      text_b = tokenization.convert_to_unicode(line[4])
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
       if set_type == "test":
         label = "0"
       else:
-        label = tokenization.convert_to_unicode(line[0])
+        label = self.process_text_fn(line[0])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -280,11 +284,11 @@ class ColaProcessor(DataProcessor):
         continue
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[1])
         label = "0"
       else:
-        text_a = tokenization.convert_to_unicode(line[3])
-        label = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
@@ -525,11 +529,10 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
 def generate_tf_record_from_data_file(processor,
                                       data_dir,
-                                      vocab_file,
+                                      tokenizer,
                                       train_data_output_path=None,
                                       eval_data_output_path=None,
-                                      max_seq_length=128,
-                                      do_lower_case=True):
+                                      max_seq_length=128):
   """Generates and saves training data into a tf record file.
 
   Arguments:
@@ -537,14 +540,13 @@ def generate_tf_record_from_data_file(processor,
       of `DataProcessor`.
     data_dir: Directory that contains train/eval data to process. Data files
       should be in from "dev.tsv", "test.tsv", or "train.tsv".
-    vocab_file: Text file with words to be used for training/evaluation.
+    tokenizer: The tokenizer to be applied on the data.
     train_data_output_path: Output to which processed tf record for training
      will be saved.
    eval_data_output_path: Output to which processed tf record for evaluation
      will be saved.
    max_seq_length: Maximum sequence length of the to be generated
      training/eval data.
-    do_lower_case: Whether to lower case input text.
 
   Returns:
     A dictionary containing input meta data.
@@ -552,8 +554,6 @@ def generate_tf_record_from_data_file(processor,
   assert train_data_output_path or eval_data_output_path
 
   label_list = processor.get_labels()
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=vocab_file, do_lower_case=do_lower_case)
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
...
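With vocab_file and do_lower_case removed from the signature, the caller now builds the tokenizer itself and hands it to generate_tf_record_from_data_file together with a processor. A hedged sketch of the new call shape with the WordPiece tokenizer and placeholder paths (ColaProcessor, FullTokenizer, and convert_to_unicode are the names touched or referenced in this commit; the paths are not):

from official.nlp.bert import classifier_data_lib
from official.nlp.bert import tokenization

# Tokenizer construction moves out of generate_tf_record_from_data_file and
# into the caller; vocab_file/do_lower_case are no longer parameters there.
tokenizer = tokenization.FullTokenizer(
    vocab_file="vocab.txt", do_lower_case=True)  # placeholder vocab path

processor = classifier_data_lib.ColaProcessor(
    process_text_fn=tokenization.convert_to_unicode)

input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    "CoLA/",  # placeholder data_dir containing train.tsv/dev.tsv
    tokenizer,
    train_data_output_path="cola_train.tf_record",
    eval_data_output_path="cola_eval.tf_record",
    max_seq_length=128)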
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import json
 
 from absl import app
@@ -29,6 +30,7 @@ from official.nlp.bert import classifier_data_lib
 from official.nlp.bert import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib_sp
+from official.nlp.bert import tokenization
 
 FLAGS = flags.FLAGS
@@ -120,15 +122,24 @@ def generate_classifier_dataset():
   if task_name not in processors:
     raise ValueError("Task not found: %s" % (task_name))
 
-  processor = processors[task_name]()
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processor = processors[task_name](processor_text_fn)
   return classifier_data_lib.generate_tf_record_from_data_file(
       processor,
       FLAGS.input_data_dir,
-      FLAGS.vocab_file,
+      tokenizer,
      train_data_output_path=FLAGS.train_data_output_path,
      eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length,
-      do_lower_case=FLAGS.do_lower_case)
+      max_seq_length=FLAGS.max_seq_length)
 
 
 def generate_squad_dataset():
...
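Outside the absl flag plumbing, the tokenizer selection added to generate_classifier_dataset can be read as a small helper. This is a hypothetical rewrite for illustration only, not code from this commit (the function name and its arguments are invented; the calls into tokenization are the ones the diff uses):

import functools

from official.nlp.bert import tokenization


def build_tokenizer_and_text_fn(tokenizer_impl,
                                vocab_file=None,
                                sp_model_file=None,
                                do_lower_case=True):
  """Hypothetical helper mirroring the branch added to generate_classifier_dataset."""
  if tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=vocab_file, do_lower_case=do_lower_case)
    process_text_fn = tokenization.convert_to_unicode
  else:
    assert tokenizer_impl == "sentence_piece"
    tokenizer = tokenization.FullSentencePieceTokenizer(sp_model_file)
    process_text_fn = functools.partial(
        tokenization.preprocess_text, lower=do_lower_case)
  return tokenizer, process_text_fn

Either pair can then be passed to a processor constructor and to generate_tf_record_from_data_file, as shown in the diff above.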