Unverified Commit c6bef65a authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Add indirection file to NCF async process. (#4958)

* add indirection file

* remove unused imports

* fix import
parent abc62005
...@@ -43,10 +43,7 @@ import tensorflow as tf ...@@ -43,10 +43,7 @@ import tensorflow as tf
from official.datasets import movielens from official.datasets import movielens
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import stat_utils from official.recommendation import stat_utils
from official.recommendation import popen_helper
_ASYNC_GEN_PATH = os.path.join(os.path.dirname(__file__),
"data_async_generation.py")
class NCFDataset(object): class NCFDataset(object):
...@@ -391,14 +388,11 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size, ...@@ -391,14 +388,11 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
# contention with the main training process. # contention with the main training process.
subproc_env["CUDA_VISIBLE_DEVICES"] = "" subproc_env["CUDA_VISIBLE_DEVICES"] = ""
python = "python3" if six.PY3 else "python2"
# By limiting the number of workers we guarantee that the worker # By limiting the number of workers we guarantee that the worker
# pool underlying the training generation doesn't starve other processes. # pool underlying the training generation doesn't starve other processes.
num_workers = int(multiprocessing.cpu_count() * 0.75) num_workers = int(multiprocessing.cpu_count() * 0.75)
subproc_args = [ subproc_args = popen_helper.INVOCATION + [
python, _ASYNC_GEN_PATH,
"--data_dir", data_dir, "--data_dir", data_dir,
"--cache_id", str(ncf_dataset.cache_paths.cache_id), "--cache_id", str(ncf_dataset.cache_paths.cache_id),
"--num_neg", str(num_neg), "--num_neg", str(num_neg),
......
...@@ -38,10 +38,8 @@ import typing ...@@ -38,10 +38,8 @@ import typing
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils
from official.datasets import movielens # pylint: disable=g-bad-import-order from official.datasets import movielens # pylint: disable=g-bad-import-order
from official.utils.accelerator import tpu as tpu_utils
def neumf_model_fn(features, labels, mode, params): def neumf_model_fn(features, labels, mode, params):
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper file for running the async data generation process in OSS."""
import os
import six
_PYTHON = "python3" if six.PY3 else "python2"
_ASYNC_GEN_PATH = os.path.join(os.path.dirname(__file__),
"data_async_generation.py")
INVOCATION = [_PYTHON, _ASYNC_GEN_PATH]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment