Commit 19e8d0a0 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327983381
parent 56186d78
@@ -148,7 +148,7 @@ python ../data/create_finetuning_data.py \
   --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
   --fine_tuning_task_type=classification --max_seq_length=128 \
   --classification_task_name=${TASK_NAME} \
-  --tokenizer_impl=sentence_piece
+  --tokenization=SentencePiece
 ```
 * SQUAD
@@ -177,7 +177,7 @@ python ../data/create_finetuning_data.py \
   --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
   --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
   --fine_tuning_task_type=squad --max_seq_length=384 \
-  --tokenizer_impl=sentence_piece
+  --tokenization=SentencePiece
 ```
 ## Fine-tuning with ALBERT
...
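For reference, the classification command above shows the ALBERT (SentencePiece) path; the canonical BERT path would presumably pass the renamed flag's other value together with a WordPiece vocab file instead of an SP model. A minimal sketch, assuming hypothetical `${GLUE_DIR}` and `${BERT_DIR}` locations (other task-specific output flags may also be needed, depending on the task):

```shell
export GLUE_DIR=~/glue                     # hypothetical: directory with the GLUE task data
export BERT_DIR=~/uncased_bert_checkpoint  # hypothetical: directory containing vocab.txt
export TASK_NAME=MNLI
export OUTPUT_DIR=/tmp/finetuning_data

python ../data/create_finetuning_data.py \
  --input_data_dir=${GLUE_DIR}/${TASK_NAME} \
  --vocab_file=${BERT_DIR}/vocab.txt \
  --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
  --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
  --fine_tuning_task_type=classification --max_seq_length=128 \
  --classification_task_name=${TASK_NAME} \
  --tokenization=WordPiece
```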
@@ -142,10 +142,10 @@ flags.DEFINE_string("sp_model_file", "",
                     "The path to the model used by sentence piece tokenizer.")
 flags.DEFINE_enum(
-    "tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
-    "Specifies the tokenizer implementation, i.e., whehter to use word_piece "
-    "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
-    "while ALBERT uses sentence_piece tokenizer.")
+    "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
+    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
+    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
+    "while ALBERT uses SentencePiece tokenizer.")
 flags.DEFINE_string("tfds_params", "",
                     "Comma-separated list of TFDS parameter assigments for "
@@ -158,12 +158,12 @@ def generate_classifier_dataset():
   assert (FLAGS.input_data_dir and FLAGS.classification_task_name
           or FLAGS.tfds_params)
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     tokenizer = tokenization.FullTokenizer(
         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
     processor_text_fn = tokenization.convert_to_unicode
   else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
     tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
     processor_text_fn = functools.partial(
         tokenization.preprocess_text, lower=FLAGS.do_lower_case)
@@ -226,12 +226,12 @@ def generate_classifier_dataset():
 def generate_regression_dataset():
   """Generates regression dataset and returns input meta data."""
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     tokenizer = tokenization.FullTokenizer(
         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
     processor_text_fn = tokenization.convert_to_unicode
   else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
     tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
     processor_text_fn = functools.partial(
         tokenization.preprocess_text, lower=FLAGS.do_lower_case)
@@ -255,13 +255,13 @@ def generate_regression_dataset():
 def generate_squad_dataset():
   """Generates squad training dataset and returns input meta data."""
   assert FLAGS.squad_data_file
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     return squad_lib_wp.generate_tf_record_from_json_file(
         FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
         FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
         FLAGS.doc_stride, FLAGS.version_2_with_negative)
   else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
     return squad_lib_sp.generate_tf_record_from_json_file(
         FLAGS.squad_data_file, FLAGS.sp_model_file,
         FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
@@ -271,12 +271,12 @@ def generate_squad_dataset():
 def generate_retrieval_dataset():
   """Generate retrieval test and dev dataset and returns input meta data."""
   assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     tokenizer = tokenization.FullTokenizer(
         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
     processor_text_fn = tokenization.convert_to_unicode
   else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
     tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
     processor_text_fn = functools.partial(
         tokenization.preprocess_text, lower=FLAGS.do_lower_case)
@@ -311,16 +311,16 @@ def generate_tagging_dataset():
   if task_name not in processors:
     raise ValueError("Task not found: %s" % task_name)
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     tokenizer = tokenization.FullTokenizer(
         vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
     processor_text_fn = tokenization.convert_to_unicode
-  elif FLAGS.tokenizer_impl == "sentence_piece":
+  elif FLAGS.tokenization == "SentencePiece":
     tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
     processor_text_fn = functools.partial(
         tokenization.preprocess_text, lower=FLAGS.do_lower_case)
   else:
-    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
+    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)
   processor = processors[task_name]()
   return tagging_data_lib.generate_tf_record_from_data_file(
@@ -330,12 +330,12 @@ def generate_tagging_dataset():
 def main(_):
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
     if not FLAGS.vocab_file:
       raise ValueError(
           "FLAG vocab_file for word-piece tokenizer is not specified.")
   else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
     if not FLAGS.sp_model_file:
       raise ValueError(
           "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
...
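For existing pipelines, this commit is purely a flag rename plus new enum spellings: `--tokenizer_impl=word_piece|sentence_piece` becomes `--tokenization=WordPiece|SentencePiece`, and, per the validation in `main()` above, WordPiece still requires `--vocab_file` while SentencePiece still requires `--sp_model_file`. A hedged before/after sketch, with a hypothetical `${ALBERT_DIR}` SP-model path and the task/data/output flags omitted for brevity:

```shell
# Before this change: old flag name with lower_snake_case values
# (remaining task/data/output flags unchanged and omitted here).
python ../data/create_finetuning_data.py \
  --sp_model_file=${ALBERT_DIR}/sp_model.model \
  --tokenizer_impl=sentence_piece

# After this change: renamed flag with CamelCase values.
python ../data/create_finetuning_data.py \
  --sp_model_file=${ALBERT_DIR}/sp_model.model \
  --tokenization=SentencePiece
```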