Commit 12271d7c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 299149155
parent ab6d40ca
...@@ -23,16 +23,18 @@ import random ...@@ -23,16 +23,18 @@ import random
import tarfile import tarfile
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
import six
from six.moves import urllib
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
from absl import logging from absl import logging
import six
from six.moves import range
from six.moves import urllib
from six.moves import zip
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order
from official.nlp.transformer.utils import tokenizer from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
# pylint: enable=g-bad-import-order
# Data sources for training/evaluating the transformer translation model. # Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either: # If any of the training sources are changed, then either:
...@@ -148,7 +150,7 @@ def download_report_hook(count, block_size, total_size): ...@@ -148,7 +150,7 @@ def download_report_hook(count, block_size, total_size):
total_size: total size total_size: total size
""" """
percent = int(count * block_size * 100 / total_size) percent = int(count * block_size * 100 / total_size)
print("\r%d%%" % percent + " completed", end="\r") print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
def download_from_url(path, url): def download_from_url(path, url):
...@@ -161,12 +163,12 @@ def download_from_url(path, url): ...@@ -161,12 +163,12 @@ def download_from_url(path, url):
Returns: Returns:
Full path to downloaded file Full path to downloaded file
""" """
filename = url.split("/")[-1] filename = six.ensure_str(url).split("/")[-1]
found_file = find_file(path, filename, max_depth=0) found_file = find_file(path, filename, max_depth=0)
if found_file is None: if found_file is None:
filename = os.path.join(path, filename) filename = os.path.join(path, filename)
logging.info("Downloading from %s to %s." % (url, filename)) logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete" inprogress_filepath = six.ensure_str(filename) + ".incomplete"
inprogress_filepath, _ = urllib.request.urlretrieve( inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook) url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress. # Print newline to clear the carriage return from the download progress.
...@@ -242,8 +244,10 @@ def compile_files(raw_dir, raw_files, tag): ...@@ -242,8 +244,10 @@ def compile_files(raw_dir, raw_files, tag):
""" """
logging.info("Compiling files with tag %s." % tag) logging.info("Compiling files with tag %s." % tag)
filename = "%s-%s" % (_PREFIX, tag) filename = "%s-%s" % (_PREFIX, tag)
input_compiled_file = os.path.join(raw_dir, filename + ".lang1") input_compiled_file = os.path.join(raw_dir,
target_compiled_file = os.path.join(raw_dir, filename + ".lang2") six.ensure_str(filename) + ".lang1")
target_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang2")
with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer: with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer: with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:
...@@ -295,7 +299,7 @@ def encode_and_save_files( ...@@ -295,7 +299,7 @@ def encode_and_save_files(
target_file = raw_files[1] target_file = raw_files[1]
# Write examples to each shard in round robin order. # Write examples to each shard in round robin order.
tmp_filepaths = [fname + ".incomplete" for fname in filepaths] tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths] writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
counter, shard = 0, 0 counter, shard = 0, 0
for counter, (input_line, target_line) in enumerate(zip( for counter, (input_line, target_line) in enumerate(zip(
...@@ -328,7 +332,7 @@ def shuffle_records(fname): ...@@ -328,7 +332,7 @@ def shuffle_records(fname):
logging.info("Shuffling records in file %s" % fname) logging.info("Shuffling records in file %s" % fname)
# Rename file prior to shuffling # Rename file prior to shuffling
tmp_fname = fname + ".unshuffled" tmp_fname = six.ensure_str(fname) + ".unshuffled"
tf.gfile.Rename(fname, tmp_fname) tf.gfile.Rename(fname, tmp_fname)
reader = tf.io.tf_record_iterator(tmp_fname) reader = tf.io.tf_record_iterator(tmp_fname)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment