Commit 63d754ec authored by Allen Wang, committed by A. Unique TensorFlower

Update transformer's data_download.py to use TF 1.x compatibility mode.

PiperOrigin-RevId: 280455983
parent c59cf48d
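The change is mechanical, as the diff below shows: the file now imports TensorFlow through its `compat.v1` alias so that 1.x-style calls such as `tf.gfile` keep working on a TF 2.x install, and every `tf.logging` call (removed in TF 2.x) becomes an `absl.logging` call. A minimal sketch of the pattern, using a hypothetical `path` variable rather than the script's real flags:

```python
# Sketch of the migration pattern this commit applies; not the full script.
from absl import logging
import tensorflow.compat.v1 as tf  # 1.x API surface on a 2.x install

logging.set_verbosity(logging.INFO)

path = "/tmp/translate_ende"  # hypothetical directory, for illustration only

# tf.gfile calls are unchanged; tf.logging calls become absl logging calls.
if not tf.gfile.Exists(path):
  logging.info("Creating directory %s", path)
  tf.gfile.MakeDirs(path)
```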
@@ -27,7 +27,8 @@ import six
 from six.moves import urllib
 from absl import app as absl_app
 from absl import flags
-import tensorflow as tf
+from absl import logging
+import tensorflow.compat.v1 as tf
 # pylint: enable=g-bad-import-order

 from official.transformer.utils import tokenizer
@@ -164,7 +165,7 @@ def download_from_url(path, url):
   found_file = find_file(path, filename, max_depth=0)
   if found_file is None:
     filename = os.path.join(path, filename)
-    tf.logging.info("Downloading from %s to %s." % (url, filename))
+    logging.info("Downloading from %s to %s." % (url, filename))
     inprogress_filepath = filename + ".incomplete"
     inprogress_filepath, _ = urllib.request.urlretrieve(
         url, inprogress_filepath, reporthook=download_report_hook)
@@ -173,7 +174,7 @@ def download_from_url(path, url):
     tf.gfile.Rename(inprogress_filepath, filename)
     return filename
   else:
-    tf.logging.info("Already downloaded: %s (at %s)." % (url, found_file))
+    logging.info("Already downloaded: %s (at %s)." % (url, found_file))
     return found_file


@@ -196,14 +197,14 @@ def download_and_extract(path, url, input_filename, target_filename):
   input_file = find_file(path, input_filename)
   target_file = find_file(path, target_filename)
   if input_file and target_file:
-    tf.logging.info("Already downloaded and extracted %s." % url)
+    logging.info("Already downloaded and extracted %s." % url)
     return input_file, target_file

   # Download archive file if it doesn't already exist.
   compressed_file = download_from_url(path, url)

   # Extract compressed files
-  tf.logging.info("Extracting %s." % compressed_file)
+  logging.info("Extracting %s." % compressed_file)
   with tarfile.open(compressed_file, "r:gz") as corpus_tar:
     corpus_tar.extractall(path)

@@ -239,7 +240,7 @@ def compile_files(raw_dir, raw_files, tag):
   Returns:
     Full path of compiled input and target files.
   """
-  tf.logging.info("Compiling files with tag %s." % tag)
+  logging.info("Compiling files with tag %s." % tag)
   filename = "%s-%s" % (_PREFIX, tag)
   input_compiled_file = os.path.join(raw_dir, filename + ".lang1")
   target_compiled_file = os.path.join(raw_dir, filename + ".lang2")
@@ -250,7 +251,7 @@ def compile_files(raw_dir, raw_files, tag):
       input_file = raw_files["inputs"][i]
       target_file = raw_files["targets"][i]

-      tf.logging.info("Reading files %s and %s." % (input_file, target_file))
+      logging.info("Reading files %s and %s." % (input_file, target_file))
       write_file(input_writer, input_file)
       write_file(target_writer, target_file)

   return input_compiled_file, target_compiled_file
@@ -286,10 +287,10 @@ def encode_and_save_files(
                for n in range(total_shards)]

   if all_exist(filepaths):
-    tf.logging.info("Files with tag %s already exist." % tag)
+    logging.info("Files with tag %s already exist." % tag)
     return filepaths

-  tf.logging.info("Saving files with tag %s." % tag)
+  logging.info("Saving files with tag %s." % tag)

   input_file = raw_files[0]
   target_file = raw_files[1]
@@ -300,7 +301,7 @@ def encode_and_save_files(
   for counter, (input_line, target_line) in enumerate(zip(
       txt_line_iterator(input_file), txt_line_iterator(target_file))):
     if counter > 0 and counter % 100000 == 0:
-      tf.logging.info("\tSaving case %d." % counter)
+      logging.info("\tSaving case %d." % counter)
     example = dict_to_example(
         {"inputs": subtokenizer.encode(input_line, add_eos=True),
          "targets": subtokenizer.encode(target_line, add_eos=True)})
@@ -312,7 +313,7 @@ def encode_and_save_files(
   for tmp_name, final_name in zip(tmp_filepaths, filepaths):
     tf.gfile.Rename(tmp_name, final_name)

-  tf.logging.info("Saved %d Examples", counter + 1)
+  logging.info("Saved %d Examples", counter + 1)
   return filepaths

@@ -324,7 +325,7 @@ def shard_filename(path, tag, shard_num, total_shards):

 def shuffle_records(fname):
   """Shuffle records in a single file."""
-  tf.logging.info("Shuffling records in file %s" % fname)
+  logging.info("Shuffling records in file %s" % fname)

   # Rename file prior to shuffling
   tmp_fname = fname + ".unshuffled"
@@ -335,7 +336,7 @@ def shuffle_records(fname):
   for record in reader:
     records.append(record)
     if len(records) % 100000 == 0:
-      tf.logging.info("\tRead: %d", len(records))
+      logging.info("\tRead: %d", len(records))

   random.shuffle(records)
@@ -344,7 +345,7 @@ def shuffle_records(fname):
     for count, record in enumerate(records):
       w.write(record)
       if count > 0 and count % 100000 == 0:
-        tf.logging.info("\tWriting record: %d" % count)
+        logging.info("\tWriting record: %d" % count)

   tf.gfile.Remove(tmp_fname)
@@ -367,7 +368,7 @@ def all_exist(filepaths):


 def make_dir(path):
   if not tf.gfile.Exists(path):
-    tf.logging.info("Creating directory %s" % path)
+    logging.info("Creating directory %s" % path)
     tf.gfile.MakeDirs(path)
@@ -377,28 +378,28 @@ def main(unused_argv):
   make_dir(FLAGS.data_dir)

   # Download test_data
-  tf.logging.info("Step 1/5: Downloading test data")
+  logging.info("Step 1/5: Downloading test data")
   train_files = get_raw_files(FLAGS.data_dir, _TEST_DATA_SOURCES)

   # Get paths of download/extracted training and evaluation files.
-  tf.logging.info("Step 2/5: Downloading data from source")
+  logging.info("Step 2/5: Downloading data from source")
   train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
   eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)

   # Create subtokenizer based on the training files.
-  tf.logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
+  logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
   train_files_flat = train_files["inputs"] + train_files["targets"]
   vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
   subtokenizer = tokenizer.Subtokenizer.init_from_files(
       vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
       min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)

-  tf.logging.info("Step 4/5: Compiling training and evaluation data")
+  logging.info("Step 4/5: Compiling training and evaluation data")
   compiled_train_files = compile_files(FLAGS.raw_dir, train_files, _TRAIN_TAG)
   compiled_eval_files = compile_files(FLAGS.raw_dir, eval_files, _EVAL_TAG)

   # Tokenize and save data as Examples in the TFRecord format.
-  tf.logging.info("Step 5/5: Preprocessing and saving data")
+  logging.info("Step 5/5: Preprocessing and saving data")
   train_tfrecord_files = encode_and_save_files(
       subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG,
       _TRAIN_SHARDS)
@@ -428,7 +429,7 @@ def define_data_download_flags():


 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_data_download_flags()
   FLAGS = flags.FLAGS
   absl_app.run(main)
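One detail worth noting for reviewers: most call sites in this file pre-format the message with `%`, while a few (e.g. `logging.info("Saved %d Examples", counter + 1)`) pass the arguments separately for printf-style formatting. Both styles work with `absl.logging`; a quick illustration with a stand-in value:

```python
from absl import logging

count = 42  # stand-in value for illustration
logging.info("Saved %d Examples" % count)  # formatted eagerly by Python
logging.info("Saved %d Examples", count)   # formatted lazily by the logger
```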