ModelZoo / ResNet50_tensorflow · Commits

Commit 63d754ec, authored Nov 14, 2019 by Allen Wang, committed by A. Unique TensorFlower on Nov 14, 2019.

Update transformer's data_download.py to use TF 1.x compatibility mode.

PiperOrigin-RevId: 280455983
parent c59cf48d
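In short, the commit swaps the plain TensorFlow import for the TF 1.x compatibility module and routes log messages through absl instead of tf.logging. A minimal sketch of the resulting pattern (the message and path below are illustrative, not from the diff):

from absl import logging
import tensorflow.compat.v1 as tf  # TF 1.x API surface, usable under TF 2.x

logging.set_verbosity(logging.INFO)
logging.info("absl logging replaces tf.logging, which is absent from the TF 2.x API")
tf.gfile.MakeDirs("/tmp/compat_demo")  # tf.gfile remains available via compat.v1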
Showing 1 changed file with 22 additions and 21 deletions (+22 −21):

official/transformer/data_download.py (view file @ 63d754ec)
@@ -27,7 +27,8 @@ import six
 from six.moves import urllib
 from absl import app as absl_app
 from absl import flags
-import tensorflow as tf
+from absl import logging
+import tensorflow.compat.v1 as tf
 # pylint: enable=g-bad-import-order

 from official.transformer.utils import tokenizer
@@ -164,7 +165,7 @@ def download_from_url(path, url):
   found_file = find_file(path, filename, max_depth=0)
   if found_file is None:
     filename = os.path.join(path, filename)
-    tf.logging.info("Downloading from %s to %s." % (url, filename))
+    logging.info("Downloading from %s to %s." % (url, filename))
     inprogress_filepath = filename + ".incomplete"
     inprogress_filepath, _ = urllib.request.urlretrieve(
         url, inprogress_filepath, reporthook=download_report_hook)
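The retained urlretrieve call passes a reporthook callback. Its three-argument signature is fixed by urllib.request.urlretrieve; the body below is a plausible sketch of a progress printer like download_report_hook, not the repository's exact code:

import sys

def download_report_hook(count, block_size, total_size):
  # urllib.request.urlretrieve invokes this after each block, passing the
  # number of blocks transferred, the block size, and the total size in bytes.
  percent = int(count * block_size * 100 / total_size)
  sys.stdout.write("\r%d%%" % percent)
  sys.stdout.flush()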
@@ -173,7 +174,7 @@ def download_from_url(path, url):
     tf.gfile.Rename(inprogress_filepath, filename)
     return filename
   else:
-    tf.logging.info("Already downloaded: %s (at %s)." % (url, found_file))
+    logging.info("Already downloaded: %s (at %s)." % (url, found_file))
     return found_file
@@ -196,14 +197,14 @@ def download_and_extract(path, url, input_filename, target_filename):
   input_file = find_file(path, input_filename)
   target_file = find_file(path, target_filename)
   if input_file and target_file:
-    tf.logging.info("Already downloaded and extracted %s." % url)
+    logging.info("Already downloaded and extracted %s." % url)
     return input_file, target_file

   # Download archive file if it doesn't already exist.
   compressed_file = download_from_url(path, url)

   # Extract compressed files
-  tf.logging.info("Extracting %s." % compressed_file)
+  logging.info("Extracting %s." % compressed_file)
   with tarfile.open(compressed_file, "r:gz") as corpus_tar:
     corpus_tar.extractall(path)
@@ -239,7 +240,7 @@ def compile_files(raw_dir, raw_files, tag):
   Returns:
     Full path of compiled input and target files.
   """
-  tf.logging.info("Compiling files with tag %s." % tag)
+  logging.info("Compiling files with tag %s." % tag)
   filename = "%s-%s" % (_PREFIX, tag)
   input_compiled_file = os.path.join(raw_dir, filename + ".lang1")
   target_compiled_file = os.path.join(raw_dir, filename + ".lang2")
@@ -250,7 +251,7 @@ def compile_files(raw_dir, raw_files, tag):
       input_file = raw_files["inputs"][i]
       target_file = raw_files["targets"][i]
-      tf.logging.info("Reading files %s and %s." % (input_file, target_file))
+      logging.info("Reading files %s and %s." % (input_file, target_file))
       write_file(input_writer, input_file)
       write_file(target_writer, target_file)

   return input_compiled_file, target_compiled_file
@@ -286,10 +287,10 @@ def encode_and_save_files(
       for n in range(total_shards)]
   if all_exist(filepaths):
-    tf.logging.info("Files with tag %s already exist." % tag)
+    logging.info("Files with tag %s already exist." % tag)
     return filepaths

-  tf.logging.info("Saving files with tag %s." % tag)
+  logging.info("Saving files with tag %s." % tag)
   input_file = raw_files[0]
   target_file = raw_files[1]
@@ -300,7 +301,7 @@ def encode_and_save_files(
   for counter, (input_line, target_line) in enumerate(zip(
       txt_line_iterator(input_file), txt_line_iterator(target_file))):
     if counter > 0 and counter % 100000 == 0:
-      tf.logging.info("\tSaving case %d." % counter)
+      logging.info("\tSaving case %d." % counter)
     example = dict_to_example(
         {"inputs": subtokenizer.encode(input_line, add_eos=True),
          "targets": subtokenizer.encode(target_line, add_eos=True)})
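This hunk calls dict_to_example, a helper defined elsewhere in data_download.py and unchanged by this commit. A sketch of how such a helper typically builds a tf.train.Example proto from lists of integer token ids (the body is an assumption about the helper, not taken from the diff):

import tensorflow.compat.v1 as tf

def dict_to_example(dictionary):
  # Wrap each list of token ids in an Int64List feature and assemble a
  # tf.train.Example keyed by field name ("inputs", "targets").
  features = {}
  for name, ids in dictionary.items():
    features[name] = tf.train.Feature(int64_list=tf.train.Int64List(value=ids))
  return tf.train.Example(features=tf.train.Features(feature=features))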
@@ -312,7 +313,7 @@ def encode_and_save_files(
   for tmp_name, final_name in zip(tmp_filepaths, filepaths):
     tf.gfile.Rename(tmp_name, final_name)

-  tf.logging.info("Saved %d Examples", counter + 1)
+  logging.info("Saved %d Examples", counter + 1)
   return filepaths
@@ -324,7 +325,7 @@ def shard_filename(path, tag, shard_num, total_shards):

 def shuffle_records(fname):
   """Shuffle records in a single file."""
-  tf.logging.info("Shuffling records in file %s" % fname)
+  logging.info("Shuffling records in file %s" % fname)

   # Rename file prior to shuffling
   tmp_fname = fname + ".unshuffled"
@@ -335,7 +336,7 @@ def shuffle_records(fname):
   for record in reader:
     records.append(record)
     if len(records) % 100000 == 0:
-      tf.logging.info("\tRead: %d", len(records))
+      logging.info("\tRead: %d", len(records))

   random.shuffle(records)
@@ -344,7 +345,7 @@ def shuffle_records(fname):
     for count, record in enumerate(records):
       w.write(record)
       if count > 0 and count % 100000 == 0:
-        tf.logging.info("\tWriting record: %d" % count)
+        logging.info("\tWriting record: %d" % count)

   tf.gfile.Remove(tmp_fname)
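Taken together, the shuffle_records hunks implement a rename, read, shuffle, rewrite cycle over a TFRecord file. A condensed sketch of the whole pattern under compat.v1, with the progress logging omitted:

import random
import tensorflow.compat.v1 as tf

def shuffle_records_sketch(fname):
  # Move the file aside, load every serialized record into memory,
  # shuffle in place, then rewrite under the original name and clean up.
  tmp_fname = fname + ".unshuffled"
  tf.gfile.Rename(fname, tmp_fname)
  records = list(tf.python_io.tf_record_iterator(tmp_fname))
  random.shuffle(records)
  with tf.python_io.TFRecordWriter(fname) as w:
    for record in records:
      w.write(record)
  tf.gfile.Remove(tmp_fname)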
@@ -367,7 +368,7 @@ def all_exist(filepaths):

 def make_dir(path):
   if not tf.gfile.Exists(path):
-    tf.logging.info("Creating directory %s" % path)
+    logging.info("Creating directory %s" % path)
     tf.gfile.MakeDirs(path)
@@ -377,28 +378,28 @@ def main(unused_argv):
   make_dir(FLAGS.data_dir)

   # Download test_data
-  tf.logging.info("Step 1/5: Downloading test data")
+  logging.info("Step 1/5: Downloading test data")
   train_files = get_raw_files(FLAGS.data_dir, _TEST_DATA_SOURCES)

   # Get paths of download/extracted training and evaluation files.
-  tf.logging.info("Step 2/5: Downloading data from source")
+  logging.info("Step 2/5: Downloading data from source")
   train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
   eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)

   # Create subtokenizer based on the training files.
-  tf.logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
+  logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
   train_files_flat = train_files["inputs"] + train_files["targets"]
   vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
   subtokenizer = tokenizer.Subtokenizer.init_from_files(
       vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
       min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)

-  tf.logging.info("Step 4/5: Compiling training and evaluation data")
+  logging.info("Step 4/5: Compiling training and evaluation data")
   compiled_train_files = compile_files(FLAGS.raw_dir, train_files, _TRAIN_TAG)
   compiled_eval_files = compile_files(FLAGS.raw_dir, eval_files, _EVAL_TAG)

   # Tokenize and save data as Examples in the TFRecord format.
-  tf.logging.info("Step 5/5: Preprocessing and saving data")
+  logging.info("Step 5/5: Preprocessing and saving data")
   train_tfrecord_files = encode_and_save_files(
       subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG,
       _TRAIN_SHARDS)
@@ -428,7 +429,7 @@ def define_data_download_flags():

 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_data_download_flags()
   FLAGS = flags.FLAGS
   absl_app.run(main)
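After this change the script's entry point is a standard absl app with absl-managed verbosity. A self-contained sketch of the same pattern (the data_dir flag and its default here are illustrative, not copied from the diff):

from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf

flags.DEFINE_string("data_dir", "/tmp/translate_ende", "Where to write data.")
FLAGS = flags.FLAGS

def main(unused_argv):
  # Mirror the make_dir helper above: create the output directory if missing.
  if not tf.gfile.Exists(FLAGS.data_dir):
    logging.info("Creating directory %s", FLAGS.data_dir)
    tf.gfile.MakeDirs(FLAGS.data_dir)

if __name__ == "__main__":
  logging.set_verbosity(logging.INFO)
  absl_app.run(main)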