"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "43dbfd2e397930a9e4595a8914eb0221b34a55d5"
Commit 1bb074b0 authored by Taylor Robie

address PR comments

parent 444f5993
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import atexit
-import collections
 import functools
 import os
 import sys
...
@@ -20,12 +20,10 @@ from __future__ import print_function
 from collections import defaultdict
 import hashlib
+import mock
 import os
-import pickle
-import time
 
 import numpy as np
-import pandas as pd
 import scipy.stats
 import tensorflow as tf
@@ -33,7 +31,6 @@ from official.datasets import movielens
 from official.recommendation import constants as rconst
 from official.recommendation import data_preprocessing
 from official.recommendation import popen_helper
-from official.recommendation import stat_utils
 
 
 DATASET = "ml-test"
@@ -53,15 +50,12 @@ FRESH_RANDOMNESS_MD5 = "63d0dff73c0e5f1048fbdc8c65021e22"
 def mock_download(*args, **kwargs):
   return
 
 
+# The forkpool used by data producers interacts badly with the threading
+# used by TestCase. Without this patch tests will hang, and no amount
+# of diligent closing and joining within the producer will prevent it.
+@mock.patch.object(popen_helper, "get_forkpool", popen_helper.get_fauxpool)
 class BaseTest(tf.test.TestCase):
 
   def setUp(self):
-    # The forkpool used by data producers interacts badly with the threading
-    # used by TestCase. Without this monkey patch tests will hang, and no amount
-    # of diligent closing and joining within the producer will prevent it.
-    self._get_forkpool = popen_helper.get_forkpool
-    popen_helper.get_forkpool = popen_helper.get_fauxpool
     self.temp_data_dir = self.get_temp_dir()
     ratings_folder = os.path.join(self.temp_data_dir, DATASET)
     tf.gfile.MakeDirs(ratings_folder)
@@ -99,9 +93,6 @@ class BaseTest(tf.test.TestCase):
     data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS,
                                                                    NUM_ITEMS)
 
-  def tearDown(self):
-    popen_helper.get_forkpool = self._get_forkpool
-
   def make_params(self, train_epochs=1):
     return {
         "train_epochs": train_epochs,
@@ -176,7 +167,9 @@ class BaseTest(tf.test.TestCase):
       data_list = [
           features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
           features[rconst.VALID_POINT_MASK], labels]
-      [md5.update(i.tobytes()) for i in data_list]
+      for i in data_list:
+        md5.update(i.tobytes())
 
       for u, i, v, l in zip(*data_list):
         if not v:
           continue  # ignore padding
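The three hunks that touch `md5.update` (this one and the two below) make the same change: a list comprehension run purely for its side effects becomes an explicit loop. The underlying checksum idiom is to stream each array's raw bytes into one running MD5 and compare the hex digest against a stored golden value; a small illustration with made-up arrays, not the test's actual data:

```python
# Hashing several numpy arrays into a single MD5 fingerprint. The digest
# depends on dtype, element order, and contents, so golden values are only
# stable for a fixed data pipeline.
import hashlib

import numpy as np

users = np.array([0, 1, 2, 3], dtype=np.int32)
items = np.array([7, 8, 9, 10], dtype=np.uint16)

md5 = hashlib.md5()
for arr in (users, items):
  md5.update(arr.tobytes())  # raw buffer of the array

print(md5.hexdigest())
```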
@@ -222,7 +215,9 @@ class BaseTest(tf.test.TestCase):
       data_list = [
           features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
           features[rconst.DUPLICATE_MASK]]
-      [md5.update(i.tobytes()) for i in data_list]
+      for i in data_list:
+        md5.update(i.tobytes())
 
       for idx, (u, i, d) in enumerate(zip(*data_list)):
         u_raw = user_inv_map[u]
         i_raw = item_inv_map[i]
@@ -280,7 +275,9 @@ class BaseTest(tf.test.TestCase):
       data_list = [
           features[movielens.USER_COLUMN], features[movielens.ITEM_COLUMN],
           features[rconst.VALID_POINT_MASK], labels]
-      [md5.update(i.tobytes()) for i in data_list]
+      for i in data_list:
+        md5.update(i.tobytes())
 
       for u, i, v, l in zip(*data_list):
         if not v:
           continue  # ignore padding
...
@@ -24,14 +24,11 @@ import mock
 import numpy as np
 import tensorflow as tf
 
-from absl import flags
 from absl.testing import flagsaver
 
 from official.recommendation import constants as rconst
 from official.recommendation import data_pipeline
-from official.recommendation import data_preprocessing
 from official.recommendation import neumf_model
 from official.recommendation import ncf_main
-from official.recommendation import stat_utils
 
 NUM_TRAIN_NEG = 4
...
 #!/bin/bash
 set -e
 
-if ! which unbuffer > /dev/null; then
-  echo "Could not find unbuffer command. Make sure the expect package is installed."
-  exit 1
-fi
-
 if [ `id -u` != 0 ]; then
   echo "Calling sudo to gain root for this shell. (Needed to clear caches.)"
   sudo echo "Success"
...
@@ -60,7 +55,7 @@ do
   # To reduce variation set the seed flag:
   #   --seed ${i}
 
-  unbuffer python ncf_main.py \
+  python -u ncf_main.py \
     --model_dir ${MODEL_DIR} \
     --data_dir ${DATA_DIR} \
     --dataset ${DATASET} --hooks "" \
...
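The script change drops the dependency on the expect package's `unbuffer` wrapper in favor of the interpreter's own `-u` flag, which makes Python's stdout/stderr unbuffered so benchmark output still streams promptly when piped. A rough illustration of the same effect from inside a script (explicit flushes are an alternative workaround, not what this commit does):

```python
# Why unbuffered output matters when stdout is a pipe: without `python -u`
# (or PYTHONUNBUFFERED=1), CPython block-buffers stdout when it is not a TTY,
# so progress lines may not appear until the buffer fills or the process exits.
import sys
import time

for step in range(3):
  print("finished step %d" % step)
  sys.stdout.flush()  # manual alternative to running with `python -u`
  time.sleep(1)
```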