"...stereo/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "10dafd9b2704a1ce7bcac8244e100ac8e2620351"
Commit 6726c5e0 authored by Taylor Robie
Browse files

address more PR comments

parent 1bb074b0
...@@ -18,34 +18,21 @@ from __future__ import absolute_import ...@@ -18,34 +18,21 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import atexit
import contextlib
import gc
import hashlib
import json
import os import os
import pickle import pickle
import signal
import socket
import subprocess
import threading
import time import time
import timeit import timeit
import typing import typing
# pylint: disable=wrong-import-order # pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import six
import tensorflow as tf import tensorflow as tf
# pylint: enable=wrong-import-order # pylint: enable=wrong-import-order
from official.datasets import movielens from official.datasets import movielens
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_pipeline from official.recommendation import data_pipeline
from official.recommendation import stat_utils
from official.utils.logs import mlperf_helper from official.utils.logs import mlperf_helper
...@@ -60,7 +47,7 @@ _EXPECTED_CACHE_KEYS = ( ...@@ -60,7 +47,7 @@ _EXPECTED_CACHE_KEYS = (
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP, "match_mlperf") rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP, "match_mlperf")
def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): def _filter_index_sort(raw_rating_path, cache_path):
# type: (str, str, bool) -> (dict, bool) # type: (str, str, bool) -> (dict, bool)
"""Read in data CSV, and output structured data. """Read in data CSV, and output structured data.
...@@ -87,8 +74,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): ...@@ -87,8 +74,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
Args: Args:
raw_rating_path: The path to the CSV which contains the raw dataset. raw_rating_path: The path to the CSV which contains the raw dataset.
cache_path: The path to the file where results of this function are saved. cache_path: The path to the file where results of this function are saved.
match_mlperf: If True, change the sorting algorithm to match the MLPerf
reference implementation.
Returns: Returns:
A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
...@@ -104,9 +89,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): ...@@ -104,9 +89,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
if cache_age > rconst.CACHE_INVALIDATION_SEC: if cache_age > rconst.CACHE_INVALIDATION_SEC:
valid_cache = False valid_cache = False
if cached_data["match_mlperf"] != match_mlperf:
valid_cache = False
for key in _EXPECTED_CACHE_KEYS: for key in _EXPECTED_CACHE_KEYS:
if key not in cached_data: if key not in cached_data:
valid_cache = False valid_cache = False
...@@ -144,9 +126,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): ...@@ -144,9 +126,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_NUM_EVAL, mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_NUM_EVAL,
value=rconst.NUM_EVAL_NEGATIVES) value=rconst.NUM_EVAL_NEGATIVES)
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT,
value=match_mlperf)
assert num_users <= np.iinfo(rconst.USER_DTYPE).max assert num_users <= np.iinfo(rconst.USER_DTYPE).max
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
...@@ -186,7 +165,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf): ...@@ -186,7 +165,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
rconst.USER_MAP: user_map, rconst.USER_MAP: user_map,
rconst.ITEM_MAP: item_map, rconst.ITEM_MAP: item_map,
"create_time": time.time(), "create_time": time.time(),
"match_mlperf": match_mlperf,
} }
tf.logging.info("Writing raw data cache.") tf.logging.info("Writing raw data cache.")
...@@ -216,8 +194,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None, ...@@ -216,8 +194,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE) raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE) cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)
raw_data, _ = _filter_index_sort(raw_rating_path, cache_path, raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
params["match_mlperf"])
user_map, item_map = raw_data["user_map"], raw_data["item_map"] user_map, item_map = raw_data["user_map"], raw_data["item_map"]
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset] num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment