ModelZoo / ResNet50_tensorflow · Commit 89031e1a

tf.compat.v1 for preprocess_pretrain_data.

PiperOrigin-RevId: 359103541

Authored Feb 23, 2021 by Allen Wang; committed by A. Unique TensorFlower, Feb 23, 2021
Parent: d2d32f46

Showing 1 changed file (official/nlp/xlnet/preprocess_pretrain_data.py) with 63 additions and 56 deletions.
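Beyond the import swap named in the title, the diff below migrates every `tf.logging.*` call to absl's `logging.*` and moves the flag definitions out of the `if __name__ == "__main__":` block into a `define_flags()` function. A minimal sketch of the target pattern; the flag and message here are illustrative, not taken from this file:

```python
from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS


def define_flags():
  """Defines relevant flags."""
  flags.DEFINE_integer("task", 0, help="Worker task ID.")


def main(_):
  logging.info("Running task %d", FLAGS.task)


if __name__ == "__main__":
  define_flags()
  logging.set_verbosity(logging.INFO)
  app.run(main)
```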
@@ -22,14 +22,15 @@ import random

 # Import libraries

 from absl import app
 from absl import flags
-import absl.logging as _logging  # pylint: disable=unused-import
+from absl import logging
 import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow.google as tf

-from official.nlp.xlnet import preprocess_utils
 import sentencepiece as spm
+from official.nlp.xlnet import preprocess_utils
+
+FLAGS = flags.FLAGS

 special_symbols = {
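For readers outside Google, `tensorflow.google` is the internal TensorFlow package; on a public TensorFlow install, the TF1-style surface this file depends on (`tf.gfile`, `tf.python_io`, `tf.logging`) is reached through `tensorflow.compat.v1`, as on the removed line. A small sanity check of that surface, assuming a TensorFlow 2.x install that still ships the compat module:

```python
import tensorflow.compat.v1 as tf

# TF1-style file I/O, matching the tf.gfile.Open / tf.gfile.Glob calls
# that appear unchanged in the hunks below.
with tf.gfile.Open("/tmp/compat_check.txt", "w") as f:
  f.write("hello\n")

for line in tf.gfile.Open("/tmp/compat_check.txt"):
  print(line.strip())

print(tf.gfile.Glob("/tmp/compat_check*.txt"))
```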
@@ -89,6 +90,7 @@ def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,

 def _create_data(idx, input_paths):
+  """Creates data."""
   # Load sentence-piece model
   sp = spm.SentencePieceProcessor()
   sp.Load(FLAGS.sp_path)
@@ -98,10 +100,10 @@ def _create_data(idx, input_paths):
   for input_path in input_paths:
     input_data, sent_ids = [], []
     sent_id, line_cnt = True, 0
-    tf.logging.info("Processing %s", input_path)
+    logging.info("Processing %s", input_path)
     for line in tf.gfile.Open(input_path):
       if line_cnt % 100000 == 0:
-        tf.logging.info("Loading line %d", line_cnt)
+        logging.info("Loading line %d", line_cnt)
       line_cnt += 1

       if not line.strip():
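Context for the loop above: each input line is encoded and its tokens are tagged with an alternating boolean `sent_id`, so sentence boundaries survive concatenation into one flat stream. A self-contained sketch of that bookkeeping, with a toy whitespace tokenizer standing in for SentencePiece:

```python
input_data, sent_ids = [], []
sent_id = True
for line in ["First sentence .", "Second one .", "Third ."]:
  cur_sent = line.split()  # toy stand-in for SentencePiece encoding
  input_data.extend(cur_sent)
  sent_ids.extend([sent_id] * len(cur_sent))
  sent_id = not sent_id

# A sentence boundary exists wherever the flag flips between neighbours.
print(list(zip(input_data, sent_ids)))
```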
@@ -122,7 +124,7 @@ def _create_data(idx, input_paths):
       sent_ids.extend([sent_id] * len(cur_sent))
       sent_id = not sent_id

-    tf.logging.info("Finish with line %d", line_cnt)
+    logging.info("Finish with line %d", line_cnt)
     if line_cnt == 0:
       continue
@@ -132,7 +134,7 @@ def _create_data(idx, input_paths):
     total_line_cnt += line_cnt
     input_shards.append((input_data, sent_ids))

-  tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
+  logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

   tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
@@ -142,7 +144,7 @@ def _create_data(idx, input_paths):
   np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
   perm_indices = np.random.permutation(len(input_shards))
-  tf.logging.info("Using perm indices %s for pass %d",
-                  perm_indices.tolist(), FLAGS.pass_id)
+  logging.info("Using perm indices %s for pass %d",
+               perm_indices.tolist(), FLAGS.pass_id)

   input_data_list, sent_ids_list = [], []
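Seeding NumPy with `100 * FLAGS.task + FLAGS.pass_id` pins a distinct but reproducible shard order to each (task, pass) pair; as long as `pass_id < 100`, no two pairs share a seed. A quick illustration:

```python
import numpy as np


def shard_permutation(num_shards, task, pass_id):
  # Deterministic per (task, pass_id); re-running gives the same order.
  np.random.seed(100 * task + pass_id)
  return np.random.permutation(num_shards)


print(shard_permutation(6, task=0, pass_id=0))
print(shard_permutation(6, task=0, pass_id=1))  # different pass, different order
print(shard_permutation(6, task=0, pass_id=0))  # same pair, same order again
```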
@@ -185,6 +187,7 @@ def _create_data(idx, input_paths):

 def create_data(_):
+  """Creates pretrain data."""
   # Validate FLAGS
   assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
   if not FLAGS.use_tpu:
@@ -221,15 +224,15 @@ def create_data(_):
   # Interleavely split the work into FLAGS.num_task splits
   file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
-  tf.logging.info("Use glob: %s", FLAGS.input_glob)
-  tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
+  logging.info("Use glob: %s", FLAGS.input_glob)
+  logging.info("Find %d files: %s", len(file_paths), file_paths)

   task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
   if not task_file_paths:
-    tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
+    logging.info("Exit: task %d has no file to process.", FLAGS.task)
     return

-  tf.logging.info("Task %d process %d files: %s",
-                  FLAGS.task, len(task_file_paths), task_file_paths)
+  logging.info("Task %d process %d files: %s",
+               FLAGS.task, len(task_file_paths), task_file_paths)
   record_info = _create_data(FLAGS.task, task_file_paths)
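The `file_paths[FLAGS.task::FLAGS.num_task]` slice is what the comment calls an interleaved split: worker `k` takes every `num_task`-th file starting at offset `k`, so sorted inputs spread evenly across workers. For example:

```python
file_paths = sorted(["a", "b", "c", "d", "e", "f", "g"])
num_task = 3

# Each worker strides through the sorted list with step num_task.
for task in range(num_task):
  print(task, file_paths[task::num_task])
# 0 ['a', 'd', 'g']
# 1 ['b', 'e']
# 2 ['c', 'f']
```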
@@ -253,6 +256,7 @@ def create_data(_):

 def batchify(data, bsz_per_host, sent_ids=None):
+  """Creates batches."""
   num_step = len(data) // bsz_per_host
   data = data[:bsz_per_host * num_step]
   data = data.reshape(bsz_per_host, num_step)
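`batchify` trims the flat token stream to a multiple of `bsz_per_host` and reshapes it so each of the `bsz_per_host` rows holds one contiguous slice of the corpus, read sequentially at training time. A NumPy sketch of just the reshaping shown above:

```python
import numpy as np


def batchify_sketch(data, bsz_per_host):
  num_step = len(data) // bsz_per_host
  data = data[:bsz_per_host * num_step]  # drop the uneven tail
  return data.reshape(bsz_per_host, num_step)


print(batchify_sketch(np.arange(10), bsz_per_host=3))
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]   element 9 is discarded
```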
@@ -270,7 +274,7 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
   data_len = data.shape[0]
   if begin_idx + tot_len >= data_len:
-    tf.logging.info("[_split_a_and_b] returns None: "
-                    "begin_idx %d + tot_len %d >= data_len %d",
-                    begin_idx, tot_len, data_len)
+    logging.info("[_split_a_and_b] returns None: "
+                 "begin_idx %d + tot_len %d >= data_len %d",
+                 begin_idx, tot_len, data_len)
     return None
@@ -284,9 +288,9 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
       end_idx += 1

   a_begin = begin_idx
-  if len(cut_points) == 0 or random.random() < 0.5:
+  if len(cut_points) == 0 or random.random() < 0.5:  # pylint:disable=g-explicit-length-test
     label = 0
-    if len(cut_points) == 0:
+    if len(cut_points) == 0:  # pylint:disable=g-explicit-length-test
       a_end = end_idx
     else:
       a_end = random.choice(cut_points)
@@ -321,7 +325,7 @@ def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
   if extend_target:
     if a_end >= data_len or b_end >= data_len:
-      tf.logging.info("[_split_a_and_b] returns None: "
-                      "a_end %d or b_end %d >= data_len %d",
-                      a_end, b_end, data_len)
+      logging.info("[_split_a_and_b] returns None: "
+                   "a_end %d or b_end %d >= data_len %d",
+                   a_end, b_end, data_len)
       return None
@@ -342,9 +346,7 @@ def _is_start_piece(piece):

 def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
-  """Sample `goal_num_predict` tokens for partial prediction.
-
-  About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+  """Samples `goal_num_predict` tokens for partial prediction."""

   seg_len = len(seg)
   mask = np.array([False] * seg_len, dtype=np.bool)
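Two notes on the unchanged lines here. First, `np.bool` was only ever an alias for the builtin `bool` and is removed in NumPy >= 1.24, where this line needs `dtype=bool`. Second, the masking itself repeatedly samples n-gram spans until roughly `goal_num_predict` positions are marked. A much-simplified sketch of that sampling (uniform span lengths, no SentencePiece boundary handling, illustrative only):

```python
import random

import numpy as np


def sample_mask_sketch(seg_len, goal_num_predict, max_gram=5, seed=0):
  rng = random.Random(seed)
  mask = np.array([False] * seg_len, dtype=bool)  # np.bool is gone in NumPy >= 1.24
  num_predict = 0
  while num_predict < goal_num_predict:
    n = rng.randint(1, max_gram)        # n-gram length
    beg = rng.randint(0, seg_len - n)   # n-gram start
    num_predict += int((~mask[beg:beg + n]).sum())  # count newly masked positions
    mask[beg:beg + n] = True
  return mask


print(sample_mask_sketch(seg_len=20, goal_num_predict=5))
```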
@@ -406,8 +408,7 @@ def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):

 def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
                        goal_num_predict=None):
-  """Sample `goal_num_predict` tokens for partial prediction.
-  About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
+  """Sample `goal_num_predict` tokens for partial prediction."""

   seg_len = len(seg)
   mask = np.array([False] * seg_len, dtype=np.bool)
@@ -474,6 +475,7 @@ def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,

 def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                      bi_data, sp):
+  """Creates TFRecords."""
   data, sent_ids = data[0], data[1]

   num_core = FLAGS.num_core_per_host
@@ -496,7 +498,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   else:
     data, sent_ids = batchify(data, bsz_per_host, sent_ids)

-  tf.logging.info("Raw data shape %s.", data.shape)
+  logging.info("Raw data shape %s.", data.shape)

   file_name = format_filename(
       prefix=basename,
@@ -512,7 +514,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   )

   save_path = os.path.join(save_dir, file_name)
   record_writer = tf.python_io.TFRecordWriter(save_path)
-  tf.logging.info("Start writing %s.", save_path)
+  logging.info("Start writing %s.", save_path)

   num_batch = 0
   reuse_len = FLAGS.reuse_len
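`tf.python_io.TFRecordWriter` is the TF1 writer kept here; under `compat.v1` it is the same class as `tf.io.TFRecordWriter`. A minimal write-and-read round trip, with an illustrative path and feature name:

```python
import tensorflow.compat.v1 as tf

save_path = "/tmp/example.tfrecord"
record_writer = tf.python_io.TFRecordWriter(save_path)
example = tf.train.Example(features=tf.train.Features(feature={
    "input": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
}))
record_writer.write(example.SerializeToString())
record_writer.close()

# Read the records back to confirm the round trip.
for record in tf.python_io.tf_record_iterator(save_path):
  print(tf.train.Example.FromString(record))
```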
@@ -527,7 +529,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
   i = 0
   while i + seq_len <= data_len:
     if num_batch % 500 == 0:
-      tf.logging.info("Processing batch %d", num_batch)
+      logging.info("Processing batch %d", num_batch)

     all_ok = True
     features = []
@@ -542,7 +544,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
         tot_len=seq_len - reuse_len - 3,
         extend_target=True)
     if results is None:
-      tf.logging.info("Break out with seq idx %d", i)
+      logging.info("Break out with seq idx %d", i)
       all_ok = False
       break
@@ -600,7 +602,7 @@ def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
     i += reuse_len

   record_writer.close()
-  tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
+  logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)

   return save_path, num_batch
@@ -624,6 +626,7 @@ def _convert_example(example, use_bfloat16):

 def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                            host_id, num_core_per_host, bsz_per_core):
+  """Parses files to a dataset."""
   # list of file pathes
   num_files = len(file_names)
   num_files_per_host = num_files // num_hosts
@@ -632,7 +635,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
   if host_id == num_hosts - 1:
     my_end_file_id = num_files

   file_paths = file_names[my_start_file_id: my_end_file_id]
-  tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
+  logging.info("Host %d handles %d files", host_id, len(file_paths))

   assert split == "train"
   dataset = tf.data.Dataset.from_tensor_slices(file_paths)
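Unlike the interleaved per-task file split earlier, hosts get contiguous slices here: `num_files // num_hosts` files each, with the last host absorbing the remainder. Extracted into a standalone function for illustration:

```python
def host_file_slice(file_names, num_hosts, host_id):
  num_files = len(file_names)
  num_files_per_host = num_files // num_hosts
  my_start_file_id = host_id * num_files_per_host
  my_end_file_id = (host_id + 1) * num_files_per_host
  if host_id == num_hosts - 1:
    my_end_file_id = num_files  # the last host takes the leftovers
  return file_names[my_start_file_id:my_end_file_id]


for host_id in range(3):
  print(host_id, host_file_slice([f"f{i}" for i in range(7)], 3, host_id))
# 0 ['f0', 'f1']
# 1 ['f2', 'f3']
# 2 ['f4', 'f5', 'f6']
```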
@@ -657,9 +660,7 @@ def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,

 def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
-  """
-  Sample a permutation of the factorization order, and create an
-  attention mask accordingly.
+  """Samples a permutation of the factorization order, and create a mask.

   Args:
     inputs: int64 Tensor in shape [seq_len], input ids.

@@ -669,6 +670,10 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
     perm_size: the length of longest permutation. Could be set to be reuse_len.
       Should not be larger than reuse_len or there will be data leaks.
     seq_len: int, sequence length.
+
+  Returns:
+    The permutation mask, new targets, target mask, and new inputs.
+
   """
   # Generate permutation indices
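The `perm_size` caveat in the docstring is the key constraint: permutation indices are generated block-wise, so a block must not straddle the `reuse_len` boundary or future-token information can leak across it. Assuming the elided body matches the reference XLNet implementation, the index trick can be mirrored in NumPy like this, with one shared shuffle applied within every block of `perm_size`:

```python
import numpy as np

seq_len, perm_size = 12, 4
np.random.seed(0)

index = np.arange(seq_len).reshape(-1, perm_size).T  # [perm_size, num_blocks]
index = np.random.permutation(index)                 # shuffle along axis 0
index = index.T.reshape(-1)
print(index)  # the same within-block reordering, repeated across blocks
```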
@@ -726,6 +731,7 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):

 def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                 num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                 mask_beta, use_bfloat16=False, num_predict=None):
+  """Gets the dataset."""
   bsz_per_core = params["batch_size"]
   if num_hosts > 1:
@@ -821,7 +827,7 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
       _convert_example(example, use_bfloat16)

       for k, v in example.items():
-        tf.logging.info("%s: %s", k, v)
+        logging.info("%s: %s", k, v)

       return example
@@ -855,6 +861,7 @@ def get_input_fn(
     num_passes=None,
     use_bfloat16=False,
     num_predict=None):
+  """Gets the input function."""
   # Merge all record infos into a single one
   record_glob_base = format_filename(
@@ -872,15 +879,14 @@ def get_input_fn(
   record_info = {"num_batch": 0, "filenames": []}

   tfrecord_dirs = tfrecord_dir.split(",")
-  tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

   for idx, record_dir in enumerate(tfrecord_dirs):
     record_glob = os.path.join(record_dir, record_glob_base)
-    tf.logging.info("[%d] Record glob: %s", idx, record_glob)
+    logging.info("[%d] Record glob: %s", idx, record_glob)

     record_paths = sorted(tf.gfile.Glob(record_glob))
-    tf.logging.info("[%d] Num of record info path: %d",
-                    idx, len(record_paths))
+    logging.info("[%d] Num of record info path: %d",
+                 idx, len(record_paths))

     cur_record_info = {"num_batch": 0, "filenames": []}
@@ -890,7 +896,7 @@ def get_input_fn(
       fields = record_info_name.split(".")[0].split("-")
       pass_id = int(fields[-1])
       if len(fields) == 5 and pass_id >= num_passes:
-        tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
+        logging.info("Skip pass %d: %s", pass_id, record_info_name)
         continue

       with tf.gfile.Open(record_info_path, "r") as fp:
@@ -912,21 +918,19 @@ def get_input_fn(
         new_filenames.append(new_filename)
       cur_record_info["filenames"] = new_filenames

-    tf.logging.info("[Dir %d] Number of chosen batches: %s",
-                    idx, cur_record_info["num_batch"])
-    tf.logging.info("[Dir %d] Number of chosen files: %s",
-                    idx, len(cur_record_info["filenames"]))
-    tf.logging.info(cur_record_info["filenames"])
+    logging.info("[Dir %d] Number of chosen batches: %s",
+                 idx, cur_record_info["num_batch"])
+    logging.info("[Dir %d] Number of chosen files: %s",
+                 idx, len(cur_record_info["filenames"]))
+    logging.info(cur_record_info["filenames"])

     # add `cur_record_info` to global `record_info`
     record_info["num_batch"] += cur_record_info["num_batch"]
     record_info["filenames"] += cur_record_info["filenames"]

-  tf.logging.info("Total number of batches: %d",
-                  record_info["num_batch"])
-  tf.logging.info("Total number of files: %d",
-                  len(record_info["filenames"]))
-  tf.logging.info(record_info["filenames"])
+  logging.info("Total number of batches: %d",
+               record_info["num_batch"])
+  logging.info("Total number of files: %d", len(record_info["filenames"]))
+  logging.info(record_info["filenames"])

   def input_fn(params):
     """docs."""
@@ -952,8 +956,8 @@ def get_input_fn(
   return input_fn, record_info


-if __name__ == "__main__":
-  FLAGS = flags.FLAGS
+def define_flags():
+  """Defines relevant flags."""
   flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
   flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
   flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
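Wrapping the `DEFINE_*` calls in `define_flags()` instead of running them under the `__main__` guard lets tests or other binaries that import this module register the flags explicitly before parsing. A sketch of that usage; the importer shown here is hypothetical:

```python
from absl import flags

from official.nlp.xlnet import preprocess_pretrain_data

preprocess_pretrain_data.define_flags()
flags.FLAGS(["prog", "--bsz_per_host=16"])  # parse explicitly, e.g. in a test
print(flags.FLAGS.bsz_per_host)  # 16
```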
@@ -991,5 +995,8 @@ if __name__ == "__main__":
   flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                        "using multiple workers to identify each worker.")

-  tf.logging.set_verbosity(tf.logging.INFO)
+
+if __name__ == "__main__":
+  define_flags()
+  logging.set_verbosity(logging.INFO)
   app.run(create_data)