Commit 19e8d0a0 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327983381
parent 56186d78
@@ -148,7 +148,7 @@ python ../data/create_finetuning_data.py \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME} \
- --tokenizer_impl=sentence_piece
+ --tokenization=SentencePiece
```
* SQUAD
@@ -177,7 +177,7 @@ python ../data/create_finetuning_data.py \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384 \
- --tokenizer_impl=sentence_piece
+ --tokenization=SentencePiece
```
## Fine-tuning with ALBERT
@@ -142,10 +142,10 @@ flags.DEFINE_string("sp_model_file", "",
"The path to the model used by sentence piece tokenizer.")
flags.DEFINE_enum(
"tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
"Specifies the tokenizer implementation, i.e., whehter to use word_piece "
"or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
"while ALBERT uses sentence_piece tokenizer.")
"tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
"Specifies the tokenizer implementation, i.e., whether to use WordPiece "
"or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
"while ALBERT uses SentencePiece tokenizer.")
flags.DEFINE_string("tfds_params", "",
"Comma-separated list of TFDS parameter assigments for "
@@ -158,12 +158,12 @@ def generate_classifier_dataset():
assert (FLAGS.input_data_dir and FLAGS.classification_task_name
or FLAGS.tfds_params)
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
- assert FLAGS.tokenizer_impl == "sentence_piece"
+ assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
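
Besides the tokenizer class, the two branches above differ only in the text-preprocessing function. A hedged sketch of that split, under the same module assumption as the previous sketch:

```python
# Hedged sketch of the processor_text_fn split used above.
import functools

from official.nlp.bert import tokenization

# WordPiece branch: inputs are only normalized to unicode strings.
wp_text_fn = tokenization.convert_to_unicode

# SentencePiece branch: inputs are whitespace-cleaned and, when
# do_lower_case is set, lower-cased before tokenization.
sp_text_fn = functools.partial(tokenization.preprocess_text, lower=True)

print(wp_text_fn(b"Hello  World"))  # -> "Hello  World"
print(sp_text_fn("Hello  World"))   # -> "hello world"
```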
@@ -226,12 +226,12 @@ def generate_classifier_dataset():
def generate_regression_dataset():
"""Generates regression dataset and returns input meta data."""
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
- assert FLAGS.tokenizer_impl == "sentence_piece"
+ assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
@@ -255,13 +255,13 @@ def generate_regression_dataset():
def generate_squad_dataset():
"""Generates squad training dataset and returns input meta data."""
assert FLAGS.squad_data_file
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
return squad_lib_wp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
else:
- assert FLAGS.tokenizer_impl == "sentence_piece"
+ assert FLAGS.tokenization == "SentencePiece"
return squad_lib_sp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.sp_model_file,
FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
@@ -271,12 +271,12 @@ def generate_squad_dataset():
def generate_retrieval_dataset():
"""Generate retrieval test and dev dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
- assert FLAGS.tokenizer_impl == "sentence_piece"
+ assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
@@ -311,16 +311,16 @@ def generate_tagging_dataset():
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
- elif FLAGS.tokenizer_impl == "sentence_piece":
+ elif FLAGS.tokenization == "SentencePiece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
@@ -330,12 +330,12 @@ def generate_tagging_dataset():
def main(_):
- if FLAGS.tokenizer_impl == "word_piece":
+ if FLAGS.tokenization == "WordPiece":
if not FLAGS.vocab_file:
raise ValueError(
"FLAG vocab_file for word-piece tokenizer is not specified.")
else:
- assert FLAGS.tokenizer_impl == "sentence_piece"
+ assert FLAGS.tokenization == "SentencePiece"
if not FLAGS.sp_model_file:
raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")