Commit 89031e1a authored by Allen Wang, committed by A. Unique TensorFlower

tf.compat.v1 for preprocess_pretrain_data.

PiperOrigin-RevId: 359103541
parent d2d32f46
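
In short, the commit swaps the Google-internal TensorFlow import and TF1-style tf.logging calls for tensorflow.compat.v1 and absl logging. A minimal sketch of the migration pattern, assuming a TF release that still ships the compat.v1 surface (the file names below are illustrative, not part of the change):

# Old: import tensorflow.google as tf; tf.logging.info(...)
# New: absl logging for messages, tf.compat.v1 for the TF1 APIs.
from absl import logging
import tensorflow.compat.v1 as tf

logging.set_verbosity(logging.INFO)
logging.info("Processing %s", "example_input.txt")

# tf.gfile and tf.python_io remain available on the compat.v1 module, so
# calls such as tf.gfile.Glob(...) and tf.python_io.TFRecordWriter(...) in
# this file are left unchanged by the commit.
file_paths = sorted(tf.gfile.Glob("*.txt"))
logging.info("Find %d files: %s", len(file_paths), file_paths)
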
@@ -22,14 +22,15 @@ import random
# Import libraries
from absl import app
from absl import flags
import absl.logging as _logging # pylint: disable=unused-import
from absl import logging
import numpy as np
import tensorflow.google as tf
from official.nlp.xlnet import preprocess_utils
import tensorflow.compat.v1 as tf
import sentencepiece as spm
from official.nlp.xlnet import preprocess_utils
FLAGS = flags.FLAGS
special_symbols = {
@@ -89,6 +90,7 @@ def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
def _create_data(idx, input_paths):
"""Creates data."""
# Load sentence-piece model
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.sp_path)
@@ -98,10 +100,10 @@ def _create_data(idx, input_paths):
for input_path in input_paths:
input_data, sent_ids = [], []
sent_id, line_cnt = True, 0
tf.logging.info("Processing %s", input_path)
logging.info("Processing %s", input_path)
for line in tf.gfile.Open(input_path):
if line_cnt % 100000 == 0:
tf.logging.info("Loading line %d", line_cnt)
logging.info("Loading line %d", line_cnt)
line_cnt += 1
if not line.strip():
@@ -122,7 +124,7 @@ def _create_data(idx, input_paths):
sent_ids.extend([sent_id] * len(cur_sent))
sent_id = not sent_id
tf.logging.info("Finish with line %d", line_cnt)
logging.info("Finish with line %d", line_cnt)
if line_cnt == 0:
continue
@@ -132,7 +134,7 @@ def _create_data(idx, input_paths):
total_line_cnt += line_cnt
input_shards.append((input_data, sent_ids))
tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
@@ -142,8 +144,8 @@ def _create_data(idx, input_paths):
np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
perm_indices = np.random.permutation(len(input_shards))
tf.logging.info("Using perm indices %s for pass %d",
perm_indices.tolist(), FLAGS.pass_id)
logging.info("Using perm indices %s for pass %d",
perm_indices.tolist(), FLAGS.pass_id)
input_data_list, sent_ids_list = [], []
prev_sent_id = None
@@ -185,6 +187,7 @@ def _create_data(idx, input_paths):
def create_data(_):
"""Creates pretrain data."""
# Validate FLAGS
assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
if not FLAGS.use_tpu:
@@ -221,16 +224,16 @@ def create_data(_):
# Split the work into FLAGS.num_task interleaved splits
file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
tf.logging.info("Use glob: %s", FLAGS.input_glob)
tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
logging.info("Use glob: %s", FLAGS.input_glob)
logging.info("Find %d files: %s", len(file_paths), file_paths)
task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
if not task_file_paths:
tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
logging.info("Exit: task %d has no file to process.", FLAGS.task)
return
tf.logging.info("Task %d process %d files: %s",
FLAGS.task, len(task_file_paths), task_file_paths)
logging.info("Task %d process %d files: %s",
FLAGS.task, len(task_file_paths), task_file_paths)
record_info = _create_data(FLAGS.task, task_file_paths)
record_prefix = "record_info-{}-{}-{}".format(
@@ -253,6 +256,7 @@ def create_data(_):
def batchify(data, bsz_per_host, sent_ids=None):
"""Creates batches."""
num_step = len(data) // bsz_per_host
data = data[:bsz_per_host * num_step]
data = data.reshape(bsz_per_host, num_step)
@@ -270,9 +274,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
data_len = data.shape[0]
if begin_idx + tot_len >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"begin_idx %d + tot_len %d >= data_len %d",
begin_idx, tot_len, data_len)
logging.info("[_split_a_and_b] returns None: "
"begin_idx %d + tot_len %d >= data_len %d",
begin_idx, tot_len, data_len)
return None
end_idx = begin_idx + 1
@@ -284,9 +288,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
end_idx += 1
a_begin = begin_idx
if len(cut_points) == 0 or random.random() < 0.5:
if len(cut_points) == 0 or random.random() < 0.5: # pylint:disable=g-explicit-length-test
label = 0
if len(cut_points) == 0:
if len(cut_points) == 0: # pylint:disable=g-explicit-length-test
a_end = end_idx
else:
a_end = random.choice(cut_points)
@@ -321,9 +325,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
if extend_target:
if a_end >= data_len or b_end >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"a_end %d or b_end %d >= data_len %d",
a_end, b_end, data_len)
logging.info("[_split_a_and_b] returns None: "
"a_end %d or b_end %d >= data_len %d",
a_end, b_end, data_len)
return None
a_target = data[a_begin + 1: a_end + 1]
b_target = data[b_begin: b_end + 1]
@@ -342,9 +346,7 @@ def _is_start_piece(piece):
def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
"""Samples `goal_num_predict` tokens for partial prediction."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
@@ -406,8 +408,7 @@ def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
"""Sample `goal_num_predict` tokens for partial prediction."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
@@ -474,6 +475,7 @@ def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
bi_data, sp):
"""Creates TFRecords."""
data, sent_ids = data[0], data[1]
num_core = FLAGS.num_core_per_host
@@ -496,7 +498,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
else:
data, sent_ids = batchify(data, bsz_per_host, sent_ids)
tf.logging.info("Raw data shape %s.", data.shape)
logging.info("Raw data shape %s.", data.shape)
file_name = format_filename(
prefix=basename,
@@ -512,7 +514,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
)
save_path = os.path.join(save_dir, file_name)
record_writer = tf.python_io.TFRecordWriter(save_path)
tf.logging.info("Start writing %s.", save_path)
logging.info("Start writing %s.", save_path)
num_batch = 0
reuse_len = FLAGS.reuse_len
@@ -527,7 +529,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
i = 0
while i + seq_len <= data_len:
if num_batch % 500 == 0:
tf.logging.info("Processing batch %d", num_batch)
logging.info("Processing batch %d", num_batch)
all_ok = True
features = []
@@ -542,7 +544,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
tot_len=seq_len - reuse_len - 3,
extend_target=True)
if results is None:
tf.logging.info("Break out with seq idx %d", i)
logging.info("Break out with seq idx %d", i)
all_ok = False
break
@@ -600,7 +602,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
i += reuse_len
record_writer.close()
tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
return save_path, num_batch
@@ -624,6 +626,7 @@ def _convert_example(example, use_bfloat16):
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
host_id, num_core_per_host, bsz_per_core):
"""Parses files to a dataset."""
# List of file paths
num_files = len(file_names)
num_files_per_host = num_files // num_hosts
@@ -632,7 +635,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
if host_id == num_hosts - 1:
my_end_file_id = num_files
file_paths = file_names[my_start_file_id: my_end_file_id]
tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
logging.info("Host %d handles %d files", host_id, len(file_paths))
assert split == "train"
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
@@ -657,9 +660,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
"""
Sample a permutation of the factorization order, and create an
attention mask accordingly.
"""Samples a permutation of the factorization order, and create a mask.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
@@ -669,6 +670,10 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
perm_size: the length of longest permutation. Could be set to be reuse_len.
Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length.
Returns:
The permutation mask, new targets, target mask, and new inputs.
"""
# Generate permutation indices
@@ -726,6 +731,7 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
num_batch, seq_len, reuse_len, perm_size, mask_alpha,
mask_beta, use_bfloat16=False, num_predict=None):
"""Gets the dataset."""
bsz_per_core = params["batch_size"]
if num_hosts > 1:
@@ -821,7 +827,7 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
_convert_example(example, use_bfloat16)
for k, v in example.items():
tf.logging.info("%s: %s", k, v)
logging.info("%s: %s", k, v)
return example
@@ -855,6 +861,7 @@ def get_input_fn(
num_passes=None,
use_bfloat16=False,
num_predict=None):
"""Gets the input function."""
# Merge all record infos into a single one
record_glob_base = format_filename(
@@ -872,15 +879,14 @@ def get_input_fn(
record_info = {"num_batch": 0, "filenames": []}
tfrecord_dirs = tfrecord_dir.split(",")
tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
for idx, record_dir in enumerate(tfrecord_dirs):
record_glob = os.path.join(record_dir, record_glob_base)
tf.logging.info("[%d] Record glob: %s", idx, record_glob)
logging.info("[%d] Record glob: %s", idx, record_glob)
record_paths = sorted(tf.gfile.Glob(record_glob))
tf.logging.info("[%d] Num of record info path: %d",
idx, len(record_paths))
logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
cur_record_info = {"num_batch": 0, "filenames": []}
@@ -890,7 +896,7 @@ def get_input_fn(
fields = record_info_name.split(".")[0].split("-")
pass_id = int(fields[-1])
if len(fields) == 5 and pass_id >= num_passes:
tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
logging.info("Skip pass %d: %s", pass_id, record_info_name)
continue
with tf.gfile.Open(record_info_path, "r") as fp:
@@ -912,21 +918,19 @@ def get_input_fn(
new_filenames.append(new_filename)
cur_record_info["filenames"] = new_filenames
tf.logging.info("[Dir %d] Number of chosen batches: %s",
idx, cur_record_info["num_batch"])
tf.logging.info("[Dir %d] Number of chosen files: %s",
idx, len(cur_record_info["filenames"]))
tf.logging.info(cur_record_info["filenames"])
logging.info("[Dir %d] Number of chosen batches: %s",
idx, cur_record_info["num_batch"])
logging.info("[Dir %d] Number of chosen files: %s",
idx, len(cur_record_info["filenames"]))
logging.info(cur_record_info["filenames"])
# add `cur_record_info` to global `record_info`
record_info["num_batch"] += cur_record_info["num_batch"]
record_info["filenames"] += cur_record_info["filenames"]
tf.logging.info("Total number of batches: %d",
record_info["num_batch"])
tf.logging.info("Total number of files: %d",
len(record_info["filenames"]))
tf.logging.info(record_info["filenames"])
logging.info("Total number of batches: %d", record_info["num_batch"])
logging.info("Total number of files: %d", len(record_info["filenames"]))
logging.info(record_info["filenames"])
def input_fn(params):
"""docs."""
@@ -952,8 +956,8 @@ def get_input_fn(
return input_fn, record_info
if __name__ == "__main__":
FLAGS = flags.FLAGS
def define_flags():
"""Defines relevant flags."""
flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
@@ -991,5 +995,8 @@ if __name__ == "__main__":
flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
"using multiple workers to identify each worker.")
tf.logging.set_verbosity(tf.logging.INFO)
if __name__ == "__main__":
define_flags()
logging.set_verbosity(logging.INFO)
app.run(create_data)
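
For context, the reworked entry point at the bottom of the file follows the standard absl pattern: flags are defined in define_flags(), verbosity is set via absl logging, and app.run() is gated under __main__. A stripped-down sketch of that structure, keeping only one of the script's flags for illustration:

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS


def define_flags():
  """Defines relevant flags."""
  flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                       "using multiple workers to identify each worker.")


def create_data(_):
  """Creates pretrain data."""
  logging.info("Running task %d", FLAGS.task)


if __name__ == "__main__":
  define_flags()
  logging.set_verbosity(logging.INFO)
  app.run(create_data)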