ModelZoo / ResNet50_tensorflow · Commits

Commit 9b7e4163
Authored Oct 09, 2018 by Shawn Wang

Allow data async generation to be run as a separate job rather than as a subprocess.

parent 42f98218
Showing 4 changed files with 125 additions and 49 deletions.
official/recommendation/constants.py               +2   -0
official/recommendation/data_async_generation.py   +41  -13
official/recommendation/data_preprocessing.py      +75  -36
official/recommendation/popen_helper.py            +7   -0
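The substantive change is a new use_subprocess argument on instantiate_pipeline in data_preprocessing.py (diff below): when True the pipeline keeps forking data_async_generation.py itself, and when False it only writes a command.json file describing the generation job, which a separately launched copy of data_async_generation.py (started with --use_command_file) picks up. A minimal sketch of the two call modes; the dataset name, paths, and batch sizes are illustrative assumptions, not values from this commit:

    from official.recommendation import data_preprocessing

    # Mode 1: previous behavior -- the pipeline forks data_async_generation.py itself.
    ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
        dataset="ml-20m", data_dir="/tmp/ncf_data",   # illustrative values
        batch_size=2048, eval_batch_size=4096,
        num_neg=4, use_subprocess=True)

    # Mode 2: new behavior -- only command.json is written; the generation script
    # must be started separately (e.g. as its own job) with --use_command_file so
    # it can pick the arguments up from that file.
    ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
        dataset="ml-20m", data_dir="/tmp/ncf_data",
        batch_size=2048, eval_batch_size=4096,
        num_neg=4, use_subprocess=False)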

official/recommendation/constants.py

...
@@ -64,6 +64,8 @@ DUPLICATE_MASK = "duplicate_mask"
 CYCLES_TO_BUFFER = 3  # The number of train cycles worth of data to "run ahead"
                       # of the main training loop.
 
+COMMAND_FILE_TEMP = "command.json.temp"
+COMMAND_FILE = "command.json"
 READY_FILE_TEMP = "ready.json.temp"
 READY_FILE = "ready.json"
 
 TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords"
...

official/recommendation/data_async_generation.py

...
@@ -50,6 +50,10 @@ _log_file = None
 def log_msg(msg):
   """Include timestamp info when logging messages to a file."""
+  if flags.FLAGS.use_command_file:
+    tf.logging.info(msg)
+    return
+
   if flags.FLAGS.redirect_logs:
     timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
     print("[{}] {}".format(timestamp, msg), file=_log_file)
...
@@ -207,8 +211,7 @@ def _construct_training_records(
   map_args = [(shard, num_items, num_neg, process_seeds[i])
               for i, shard in enumerate(training_shards * epochs_per_cycle)]
 
-  with contextlib.closing(multiprocessing.Pool(
-      processes=num_workers, initializer=init_worker)) as pool:
+  with popen_helper.get_pool(num_workers, init_worker) as pool:
     map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
     data_generator = map_fn(_process_shard, map_args)
     data = [
...
@@ -436,8 +439,39 @@ def _generation_loop(num_workers,  # type: int
     gc.collect()
 
 
+def _set_flags_with_command_file():
+  """Use arguments from COMMAND_FILE when use_command_file is True."""
+  command_file = os.path.join(flags.FLAGS.data_dir, rconst.COMMAND_FILE)
+  tf.logging.info("Waiting for command file to appear at {}..."
+                  .format(command_file))
+  while not tf.gfile.Exists(command_file):
+    time.sleep(1)
+  tf.logging.info("Command file found.")
+  with tf.gfile.Open(command_file, "r") as f:
+    command = json.load(f)
+    flags.FLAGS.num_workers = command["num_workers"]
+    assert flags.FLAGS.data_dir == command["data_dir"]
+    flags.FLAGS.cache_id = command["cache_id"]
+    flags.FLAGS.num_readers = command["num_readers"]
+    flags.FLAGS.num_neg = command["num_neg"]
+    flags.FLAGS.num_train_positives = command["num_train_positives"]
+    flags.FLAGS.num_items = command["num_items"]
+    flags.FLAGS.epochs_per_cycle = command["epochs_per_cycle"]
+    flags.FLAGS.train_batch_size = command["train_batch_size"]
+    flags.FLAGS.eval_batch_size = command["eval_batch_size"]
+    flags.FLAGS.spillover = command["spillover"]
+    flags.FLAGS.redirect_logs = command["redirect_logs"]
+    assert flags.FLAGS.redirect_logs is False
+    if "seed" in command:
+      flags.FLAGS.seed = command["seed"]
+
+
 def main(_):
   global _log_file
+  if flags.FLAGS.use_command_file is not None:
+    _set_flags_with_command_file()
+
   redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
...
@@ -489,16 +523,12 @@ def main(_):
 
 def define_flags():
-  """Construct flags for the server.
-
-  This function does not use offical.utils.flags, as these flags are not meant
-  to be used by humans. Rather, they should be passed as part of a subprocess
-  call.
-  """
+  """Construct flags for the server."""
   flags.DEFINE_integer(name="num_workers", default=multiprocessing.cpu_count(),
                        help="Size of the negative generation worker pool.")
   flags.DEFINE_string(name="data_dir", default=None,
                       help="The data root. (used to construct cache paths.)")
+  flags.mark_flags_as_required(["data_dir"])
   flags.DEFINE_string(name="cache_id", default=None,
                       help="The cache_id generated in the main process.")
   flags.DEFINE_integer(name="num_readers", default=4,
...
@@ -531,11 +561,9 @@ def define_flags():
   flags.DEFINE_integer(name="seed", default=None,
                        help="NumPy random seed to set at startup. If not "
                             "specified, a seed will not be set.")
 
-  flags.mark_flags_as_required(
-      ["data_dir", "cache_id", "num_neg", "num_train_positives", "num_items",
-       "train_batch_size", "eval_batch_size"])
+  flags.DEFINE_boolean(name="use_command_file", default=False,
+                       help="Use command arguments from json at command_path. "
+                            "All arguments other than data_dir will be ignored.")
 
 
 if __name__ == "__main__":
...
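With the new use_command_file flag, this script can be started on its own (for example as a separately scheduled job) before the trainer has produced command.json; it simply polls the data directory until the file appears and then reads its flags from it. A minimal standalone sketch of that polling pattern, using plain os/json/time in place of tf.gfile and a hypothetical data_dir path:

    import json
    import os
    import time

    def wait_for_command(data_dir, poll_interval=1):
      """Block until command.json exists, then return its parsed contents."""
      command_file = os.path.join(data_dir, "command.json")
      while not os.path.exists(command_file):
        time.sleep(poll_interval)
      with open(command_file, "r") as f:
        return json.load(f)

    # Hypothetical usage: the trainer writes /tmp/ncf_data/command.json
    # (atomically, via a rename) and this job picks it up.
    command = wait_for_command("/tmp/ncf_data")
    print(command["num_workers"], command["train_batch_size"])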

official/recommendation/data_preprocessing.py

...
@@ -416,7 +416,8 @@ def _shutdown(proc):
 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                          num_data_readers=None, num_neg=4, epochs_per_cycle=1,
-                         match_mlperf=False, deterministic=False):
+                         match_mlperf=False, deterministic=False,
+                         use_subprocess=True):
   # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""
...
@@ -425,7 +426,11 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
       num_data_readers=num_data_readers,
       match_mlperf=match_mlperf,
       deterministic=deterministic)
 
+  # By limiting the number of workers we guarantee that the worker
+  # pool underlying the training generation doesn't starve other processes.
+  num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
+
+  if use_subprocess:
     tf.logging.info("Creating training file subprocess.")
 
     subproc_env = os.environ.copy()
...
@@ -435,10 +440,6 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
     # contention with the main training process.
     subproc_env["CUDA_VISIBLE_DEVICES"] = ""
 
-  # By limiting the number of workers we guarantee that the worker
-  # pool underlying the training generation doesn't starve other processes.
-  num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
-
     subproc_args = popen_helper.INVOCATION + [
         "--data_dir", data_dir,
         "--cache_id", str(ncf_dataset.cache_paths.cache_id),
...
@@ -450,10 +451,10 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
         "--train_batch_size", str(batch_size),
         "--eval_batch_size", str(eval_batch_size),
         "--num_workers", str(num_workers),
-        "--spillover", "True",  # This allows the training input function to
-                                # guarantee batch size and significantly improves
-                                # performance. (~5% increase in examples/sec on
+        # This allows the training input function to guarantee batch size and
+        # significantly improves performance. (~5% increase in examples/sec on
         # GPU, and needed for TPU XLA.)
+        "--spillover", "True",
         "--redirect_logs", "True"
     ]
     if ncf_dataset.deterministic:
...
@@ -464,6 +465,42 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
     proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
+  else:
+    # We write to a temp file then atomically rename it to the final file,
+    # because writing directly to the final file can cause the data generation
+    # async process to read a partially written JSON file.
+    command_file_temp = os.path.join(data_dir, rconst.COMMAND_FILE_TEMP)
+    tf.logging.info("Generation subprocess command at {} ..."
+                    .format(command_file_temp))
+    with tf.gfile.Open(command_file_temp, "w") as f:
+      command = {
+          "data_dir": data_dir,
+          "cache_id": ncf_dataset.cache_paths.cache_id,
+          "num_neg": num_neg,
+          "num_train_positives": ncf_dataset.num_train_positives,
+          "num_items": ncf_dataset.num_items,
+          "num_readers": ncf_dataset.num_data_readers,
+          "epochs_per_cycle": epochs_per_cycle,
+          "train_batch_size": batch_size,
+          "eval_batch_size": eval_batch_size,
+          "num_workers": num_workers,
+          # This allows the training input function to guarantee batch size and
+          # significantly improves performance. (~5% increase in examples/sec on
+          # GPU, and needed for TPU XLA.)
+          "spillover": True,
+          "redirect_logs": False
+      }
+      if ncf_dataset.deterministic:
+        command["seed"] = stat_utils.random_int32()
+      json.dump(command, f)
+    command_file = os.path.join(data_dir, rconst.COMMAND_FILE)
+    tf.gfile.Rename(command_file_temp, command_file)
+    tf.logging.info("Generation subprocess command saved to: {}"
+                    .format(command_file))
 
   cleanup_called = {"finished": False}
   @atexit.register
   def cleanup():
...
@@ -471,7 +508,9 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
     if cleanup_called["finished"]:
       return
 
+    if use_subprocess:
+      _shutdown(proc)
 
     try:
       tf.gfile.DeleteRecursively(ncf_dataset.cache_paths.cache_root)
     except tf.errors.NotFoundError:
...
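The temp-file-plus-rename dance above is what keeps the polling consumer in data_async_generation.py from ever reading a half-written JSON file: the rename is atomic on a single filesystem, so command.json either does not exist yet or is complete. A minimal sketch of that handoff with plain os/json; the data_dir and payload are hypothetical, not taken from this commit:

    import json
    import os

    def publish_command(data_dir, command):
      """Write command.json atomically: write a temp file, then rename it."""
      temp_path = os.path.join(data_dir, "command.json.temp")
      final_path = os.path.join(data_dir, "command.json")
      with open(temp_path, "w") as f:
        json.dump(command, f)
      # os.rename is atomic within a filesystem, so readers polling for
      # command.json never observe a partially written file.
      os.rename(temp_path, final_path)

    publish_command("/tmp/ncf_data", {"num_workers": 8, "train_batch_size": 2048})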

official/recommendation/popen_helper.py

...
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Helper file for running the async data generation process in OSS."""
 
+import contextlib
+import multiprocessing
 import os
 import sys
...
@@ -27,3 +29,8 @@ _ASYNC_GEN_PATH = os.path.join(os.path.dirname(__file__),
                                "data_async_generation.py")
 INVOCATION = [_PYTHON, _ASYNC_GEN_PATH]
+
+
+def get_pool(num_workers, init_worker=None):
+  return contextlib.closing(multiprocessing.Pool(
+      processes=num_workers, initializer=init_worker))
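get_pool wraps multiprocessing.Pool in contextlib.closing, so callers can use the pool in a with statement and have it close()d (rather than terminate()d) when the block exits; plain Pool objects are not context managers under Python 2, which this code base still supported at the time. A hedged usage sketch with a hypothetical worker function, not taken from this commit:

    from official.recommendation import popen_helper

    def square(x):  # hypothetical worker function
      return x * x

    with popen_helper.get_pool(num_workers=4) as pool:
      results = list(pool.imap(square, range(10)))
    print(results)  # [0, 1, 4, 9, ...]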