"test/vscode:/vscode.git/clone" did not exist on "60f29ca00c55486cc7126f7d8bfc91bcb5157127"
Commit 12271d7c authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 299149155
parent ab6d40ca
......@@ -23,16 +23,18 @@ import random
import tarfile
# pylint: disable=g-bad-import-order
import six
from six.moves import urllib
from absl import app as absl_app
from absl import flags
from absl import logging
import six
from six.moves import range
from six.moves import urllib
from six.moves import zip
import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
# pylint: enable=g-bad-import-order
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
......@@ -148,7 +150,7 @@ def download_report_hook(count, block_size, total_size):
total_size: total size
"""
percent = int(count * block_size * 100 / total_size)
print("\r%d%%" % percent + " completed", end="\r")
print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
def download_from_url(path, url):
......@@ -161,12 +163,12 @@ def download_from_url(path, url):
Returns:
Full path to downloaded file
"""
filename = url.split("/")[-1]
filename = six.ensure_str(url).split("/")[-1]
found_file = find_file(path, filename, max_depth=0)
if found_file is None:
filename = os.path.join(path, filename)
logging.info("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete"
inprogress_filepath = six.ensure_str(filename) + ".incomplete"
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
......@@ -242,8 +244,10 @@ def compile_files(raw_dir, raw_files, tag):
"""
logging.info("Compiling files with tag %s." % tag)
filename = "%s-%s" % (_PREFIX, tag)
input_compiled_file = os.path.join(raw_dir, filename + ".lang1")
target_compiled_file = os.path.join(raw_dir, filename + ".lang2")
input_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang1")
target_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang2")
with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:
......@@ -295,7 +299,7 @@ def encode_and_save_files(
target_file = raw_files[1]
# Write examples to each shard in round robin order.
tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
counter, shard = 0, 0
for counter, (input_line, target_line) in enumerate(zip(
......@@ -328,7 +332,7 @@ def shuffle_records(fname):
logging.info("Shuffling records in file %s" % fname)
# Rename file prior to shuffling
tmp_fname = fname + ".unshuffled"
tmp_fname = six.ensure_str(fname) + ".unshuffled"
tf.gfile.Rename(fname, tmp_fname)
reader = tf.io.tf_record_iterator(tmp_fname)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment