Commit 89031e1a authored by Allen Wang, committed by A. Unique TensorFlower

tf.compat.v1 for preprocess_pretrain_data.

PiperOrigin-RevId: 359103541
parent d2d32f46
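This change swaps the internal `tensorflow.google` import for the public `tensorflow.compat.v1` alias and routes all logging through `absl.logging` instead of the deprecated `tf.logging` shim. A minimal before/after sketch of the migration pattern (the sample log messages are illustrative, not lines taken from this diff):

    # Before: TF1-style logging through the TensorFlow shim.
    #   import tensorflow.google as tf
    #   tf.logging.set_verbosity(tf.logging.INFO)
    #   tf.logging.info("Processing %s", input_path)

    # After: absl logging, with tf.compat.v1 kept for the remaining TF1 APIs
    # (tf.gfile, tf.python_io.TFRecordWriter, etc.).
    import tensorflow.compat.v1 as tf
    from absl import logging

    logging.set_verbosity(logging.INFO)
    logging.info("TensorFlow version: %s", tf.VERSION)

The absl calls accept the same printf-style arguments, so each `tf.logging.info(...)` in the diff below becomes `logging.info(...)` with its arguments unchanged.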
@@ -22,14 +22,15 @@ import random
 # Import libraries
 from absl import app
 from absl import flags
-import absl.logging as _logging  # pylint: disable=unused-import
+from absl import logging
 import numpy as np
-import tensorflow.google as tf
+import tensorflow.compat.v1 as tf
+from official.nlp.xlnet import preprocess_utils
 import sentencepiece as spm
-from official.nlp.xlnet import preprocess_utils
+
+FLAGS = flags.FLAGS
 special_symbols = {
@@ -89,6 +90,7 @@ def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
 def _create_data(idx, input_paths):
+  """Creates data."""
   # Load sentence-piece model
   sp = spm.SentencePieceProcessor()
   sp.Load(FLAGS.sp_path)
@@ -98,10 +100,10 @@ def _create_data(idx, input_paths):
   for input_path in input_paths:
     input_data, sent_ids = [], []
     sent_id, line_cnt = True, 0
-    tf.logging.info("Processing %s", input_path)
+    logging.info("Processing %s", input_path)
     for line in tf.gfile.Open(input_path):
       if line_cnt % 100000 == 0:
-        tf.logging.info("Loading line %d", line_cnt)
+        logging.info("Loading line %d", line_cnt)
       line_cnt += 1
       if not line.strip():
@@ -122,7 +124,7 @@ def _create_data(idx, input_paths):
       sent_ids.extend([sent_id] * len(cur_sent))
       sent_id = not sent_id
-    tf.logging.info("Finish with line %d", line_cnt)
+    logging.info("Finish with line %d", line_cnt)
     if line_cnt == 0:
       continue
@@ -132,7 +134,7 @@ def _create_data(idx, input_paths):
     total_line_cnt += line_cnt
     input_shards.append((input_data, sent_ids))
-  tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
+  logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
   tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
@@ -142,7 +144,7 @@ def _create_data(idx, input_paths):
   np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
   perm_indices = np.random.permutation(len(input_shards))
-  tf.logging.info("Using perm indices %s for pass %d",
+  logging.info("Using perm indices %s for pass %d",
                perm_indices.tolist(), FLAGS.pass_id)
   input_data_list, sent_ids_list = [], []
@@ -185,6 +187,7 @@ def _create_data(idx, input_paths):
 def create_data(_):
+  """Creates pretrain data."""
   # Validate FLAGS
   assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
   if not FLAGS.use_tpu:
@@ -221,15 +224,15 @@ def create_data(_):
   # Interleavely split the work into FLAGS.num_task splits
   file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
-  tf.logging.info("Use glob: %s", FLAGS.input_glob)
-  tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
+  logging.info("Use glob: %s", FLAGS.input_glob)
+  logging.info("Find %d files: %s", len(file_paths), file_paths)
   task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
   if not task_file_paths:
-    tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
+    logging.info("Exit: task %d has no file to process.", FLAGS.task)
     return
-  tf.logging.info("Task %d process %d files: %s",
+  logging.info("Task %d process %d files: %s",
                FLAGS.task, len(task_file_paths), task_file_paths)
   record_info = _create_data(FLAGS.task, task_file_paths)
@@ -253,6 +256,7 @@ def create_data(_):
 def batchify(data, bsz_per_host, sent_ids=None):
+  """Creates batches."""
   num_step = len(data) // bsz_per_host
   data = data[:bsz_per_host * num_step]
   data = data.reshape(bsz_per_host, num_step)
@@ -270,7 +274,7 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
   data_len = data.shape[0]
   if begin_idx + tot_len >= data_len:
-    tf.logging.info("[_split_a_and_b] returns None: "
+    logging.info("[_split_a_and_b] returns None: "
                  "begin_idx %d + tot_len %d >= data_len %d",
                  begin_idx, tot_len, data_len)
     return None
@@ -284,9 +288,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
     end_idx += 1
   a_begin = begin_idx
-  if len(cut_points) == 0 or random.random() < 0.5:
+  if len(cut_points) == 0 or random.random() < 0.5:  # pylint:disable=g-explicit-length-test
     label = 0
-    if len(cut_points) == 0:
+    if len(cut_points) == 0:  # pylint:disable=g-explicit-length-test
       a_end = end_idx
     else:
       a_end = random.choice(cut_points)
@@ -321,7 +325,7 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
   if extend_target:
     if a_end >= data_len or b_end >= data_len:
-      tf.logging.info("[_split_a_and_b] returns None: "
+      logging.info("[_split_a_and_b] returns None: "
                    "a_end %d or b_end %d >= data_len %d",
                    a_end, b_end, data_len)
       return None
@@ -342,9 +346,7 @@ def _is_start_piece(piece):
 def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
-  """Sample `goal_num_predict` tokens for partial prediction.
-
-  About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+  """Samples `goal_num_predict` tokens for partial prediction."""
   seg_len = len(seg)
   mask = np.array([False] * seg_len, dtype=np.bool)
@@ -406,8 +408,7 @@ def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
 def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
                        goal_num_predict=None):
-  """Sample `goal_num_predict` tokens for partial prediction.
-  About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+  """Sample `goal_num_predict` tokens for partial prediction."""
   seg_len = len(seg)
   mask = np.array([False] * seg_len, dtype=np.bool)
@@ -474,6 +475,7 @@ def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
 def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                      bi_data, sp):
+  """Creates TFRecords."""
   data, sent_ids = data[0], data[1]
   num_core = FLAGS.num_core_per_host
@@ -496,7 +498,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   else:
     data, sent_ids = batchify(data, bsz_per_host, sent_ids)
-  tf.logging.info("Raw data shape %s.", data.shape)
+  logging.info("Raw data shape %s.", data.shape)
   file_name = format_filename(
       prefix=basename,
@@ -512,7 +514,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   )
   save_path = os.path.join(save_dir, file_name)
   record_writer = tf.python_io.TFRecordWriter(save_path)
-  tf.logging.info("Start writing %s.", save_path)
+  logging.info("Start writing %s.", save_path)
   num_batch = 0
   reuse_len = FLAGS.reuse_len
@@ -527,7 +529,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   i = 0
   while i + seq_len <= data_len:
     if num_batch % 500 == 0:
-      tf.logging.info("Processing batch %d", num_batch)
+      logging.info("Processing batch %d", num_batch)
     all_ok = True
     features = []
@@ -542,7 +544,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
           tot_len=seq_len - reuse_len - 3,
           extend_target=True)
       if results is None:
-        tf.logging.info("Break out with seq idx %d", i)
+        logging.info("Break out with seq idx %d", i)
         all_ok = False
         break
@@ -600,7 +602,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
     i += reuse_len
   record_writer.close()
-  tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
+  logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
   return save_path, num_batch
@@ -624,6 +626,7 @@ def _convert_example(example, use_bfloat16):
 def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                            host_id, num_core_per_host, bsz_per_core):
+  """Parses files to a dataset."""
   # list of file pathes
   num_files = len(file_names)
   num_files_per_host = num_files // num_hosts
@@ -632,7 +635,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
   if host_id == num_hosts - 1:
     my_end_file_id = num_files
   file_paths = file_names[my_start_file_id: my_end_file_id]
-  tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
+  logging.info("Host %d handles %d files", host_id, len(file_paths))
   assert split == "train"
   dataset = tf.data.Dataset.from_tensor_slices(file_paths)
@@ -657,9 +660,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
 def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
-  """
-  Sample a permutation of the factorization order, and create an
-  attention mask accordingly.
+  """Samples a permutation of the factorization order, and create a mask.
   Args:
     inputs: int64 Tensor in shape [seq_len], input ids.
@@ -669,6 +670,10 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
     perm_size: the length of longest permutation. Could be set to be reuse_len.
       Should not be larger than reuse_len or there will be data leaks.
     seq_len: int, sequence length.
+
+  Returns:
+    The permutation mask, new targets, target mask, and new inputs.
+
   """
   # Generate permutation indices
@@ -726,6 +731,7 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
 def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                 num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                 mask_beta, use_bfloat16=False, num_predict=None):
+  """Gets the dataset."""
   bsz_per_core = params["batch_size"]
   if num_hosts > 1:
@@ -821,7 +827,7 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
     _convert_example(example, use_bfloat16)
     for k, v in example.items():
-      tf.logging.info("%s: %s", k, v)
+      logging.info("%s: %s", k, v)
     return example
@@ -855,6 +861,7 @@ def get_input_fn(
     num_passes=None,
     use_bfloat16=False,
     num_predict=None):
+  """Gets the input function."""
   # Merge all record infos into a single one
   record_glob_base = format_filename(
@@ -872,15 +879,14 @@ def get_input_fn(
   record_info = {"num_batch": 0, "filenames": []}
   tfrecord_dirs = tfrecord_dir.split(",")
-  tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
   for idx, record_dir in enumerate(tfrecord_dirs):
     record_glob = os.path.join(record_dir, record_glob_base)
-    tf.logging.info("[%d] Record glob: %s", idx, record_glob)
+    logging.info("[%d] Record glob: %s", idx, record_glob)
     record_paths = sorted(tf.gfile.Glob(record_glob))
-    tf.logging.info("[%d] Num of record info path: %d",
-                    idx, len(record_paths))
+    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
     cur_record_info = {"num_batch": 0, "filenames": []}
@@ -890,7 +896,7 @@ def get_input_fn(
         fields = record_info_name.split(".")[0].split("-")
         pass_id = int(fields[-1])
         if len(fields) == 5 and pass_id >= num_passes:
-          tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
+          logging.info("Skip pass %d: %s", pass_id, record_info_name)
           continue
       with tf.gfile.Open(record_info_path, "r") as fp:
@@ -912,21 +918,19 @@ def get_input_fn(
       new_filenames.append(new_filename)
     cur_record_info["filenames"] = new_filenames
-    tf.logging.info("[Dir %d] Number of chosen batches: %s",
+    logging.info("[Dir %d] Number of chosen batches: %s",
                  idx, cur_record_info["num_batch"])
-    tf.logging.info("[Dir %d] Number of chosen files: %s",
+    logging.info("[Dir %d] Number of chosen files: %s",
                  idx, len(cur_record_info["filenames"]))
-    tf.logging.info(cur_record_info["filenames"])
+    logging.info(cur_record_info["filenames"])
     # add `cur_record_info` to global `record_info`
     record_info["num_batch"] += cur_record_info["num_batch"]
     record_info["filenames"] += cur_record_info["filenames"]
-  tf.logging.info("Total number of batches: %d",
-                  record_info["num_batch"])
-  tf.logging.info("Total number of files: %d",
-                  len(record_info["filenames"]))
-  tf.logging.info(record_info["filenames"])
+  logging.info("Total number of batches: %d", record_info["num_batch"])
+  logging.info("Total number of files: %d", len(record_info["filenames"]))
+  logging.info(record_info["filenames"])
   def input_fn(params):
     """docs."""
@@ -952,8 +956,8 @@ def get_input_fn(
   return input_fn, record_info
-if __name__ == "__main__":
-  FLAGS = flags.FLAGS
+def define_flags():
+  """Defines relevant flags."""
   flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
   flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
   flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
@@ -991,5 +995,8 @@ if __name__ == "__main__":
   flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                        "using multiple workers to identify each worker.")
-  tf.logging.set_verbosity(tf.logging.INFO)
+
+if __name__ == "__main__":
+  define_flags()
+  logging.set_verbosity(logging.INFO)
   app.run(create_data)
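The tail of the diff pulls the flag definitions out of the `if __name__ == "__main__":` block into a `define_flags()` helper and switches the entry point to absl logging. A self-contained sketch of that entry-point pattern (the single flag and the `main` function here are placeholders, not the script's real flags):

    from absl import app
    from absl import flags
    from absl import logging

    FLAGS = flags.FLAGS


    def define_flags():
      """Defines flags before app.run() parses argv."""
      flags.DEFINE_string("input_glob", "data/*.txt", help="Input file glob.")


    def main(_):
      # FLAGS are populated once app.run() has parsed the command line.
      logging.info("Using glob: %s", FLAGS.input_glob)


    if __name__ == "__main__":
      define_flags()
      logging.set_verbosity(logging.INFO)
      app.run(main)

Defining the flags in a function keeps module import side-effect free while still registering every flag before `app.run()` parses `sys.argv`.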