ModelZoo / ResNet50_tensorflow / Commits

Commit 12271d7c, authored Mar 05, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 299149155

parent ab6d40ca
Changes: 1 changed file with 14 additions and 10 deletions.
official/nlp/transformer/data_download.py (+14, -10)

@@ -23,16 +23,18 @@ import random
 import tarfile
 
 # pylint: disable=g-bad-import-order
-import six
-from six.moves import urllib
 from absl import app as absl_app
 from absl import flags
 from absl import logging
+import six
+from six.moves import range
+from six.moves import urllib
+from six.moves import zip
 import tensorflow.compat.v1 as tf
-# pylint: enable=g-bad-import-order
 
 from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
+# pylint: enable=g-bad-import-order
 
 # Data sources for training/evaluating the transformer translation model.
 # If any of the training sources are changed, then either:

@@ -148,7 +150,7 @@ def download_report_hook(count, block_size, total_size):
     total_size: total size
   """
   percent = int(count * block_size * 100 / total_size)
-  print("\r%d%%" % percent + " completed", end="\r")
+  print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
 
 
 def download_from_url(path, url):

@@ -161,12 +163,12 @@ def download_from_url(path, url):
   Returns:
     Full path to downloaded file
   """
-  filename = url.split("/")[-1]
+  filename = six.ensure_str(url).split("/")[-1]
   found_file = find_file(path, filename, max_depth=0)
   if found_file is None:
     filename = os.path.join(path, filename)
     logging.info("Downloading from %s to %s." % (url, filename))
-    inprogress_filepath = filename + ".incomplete"
+    inprogress_filepath = six.ensure_str(filename) + ".incomplete"
     inprogress_filepath, _ = urllib.request.urlretrieve(
         url, inprogress_filepath, reporthook=download_report_hook)
     # Print newline to clear the carriage return from the download progress.

@@ -242,8 +244,10 @@ def compile_files(raw_dir, raw_files, tag):
   """
   logging.info("Compiling files with tag %s." % tag)
   filename = "%s-%s" % (_PREFIX, tag)
-  input_compiled_file = os.path.join(raw_dir, filename + ".lang1")
-  target_compiled_file = os.path.join(raw_dir, filename + ".lang2")
+  input_compiled_file = os.path.join(raw_dir,
+                                     six.ensure_str(filename) + ".lang1")
+  target_compiled_file = os.path.join(raw_dir,
+                                      six.ensure_str(filename) + ".lang2")
 
   with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
     with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:

@@ -295,7 +299,7 @@ def encode_and_save_files(
   target_file = raw_files[1]
 
   # Write examples to each shard in round robin order.
-  tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
+  tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
   writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
   counter, shard = 0, 0
   for counter, (input_line, target_line) in enumerate(zip(

@@ -328,7 +332,7 @@ def shuffle_records(fname):
   logging.info("Shuffling records in file %s" % fname)
 
   # Rename file prior to shuffling
-  tmp_fname = fname + ".unshuffled"
+  tmp_fname = six.ensure_str(fname) + ".unshuffled"
   tf.gfile.Rename(fname, tmp_fname)
 
   reader = tf.io.tf_record_iterator(tmp_fname)
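For context on the pattern this commit applies: six.ensure_str() coerces its argument to the native str type (decoding bytes on Python 3, encoding text to str on Python 2), so concatenations such as filename + ".incomplete" behave the same whether the value arrives as bytes or text. A minimal sketch, not part of the commit; the helper name and example path below are illustrative only:

import six


def incomplete_path(filename):
  """Illustrative helper mirroring the diff's use of six.ensure_str.

  `filename` may be bytes or text; six.ensure_str returns the native str
  type on both Python 2 and Python 3, so the concatenation below never
  mixes bytes with str.
  """
  return six.ensure_str(filename) + ".incomplete"


print(incomplete_path(b"raw_data/example-train.lang1"))  # bytes input
print(incomplete_path("raw_data/example-train.lang1"))   # text input
# Both print: raw_data/example-train.lang1.incomplete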