"...resnet50_tensorflow.git" did not exist on "e97979cb72b83ed0e80a37dbc69c4c0bc157e50e"
Commit 6726c5e0 authored by Taylor Robie
Browse files

address more PR comments

parent 1bb074b0
......@@ -18,34 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import atexit
import contextlib
import gc
import hashlib
import json
import os
import pickle
import signal
import socket
import subprocess
import threading
import time
import timeit
import typing
# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
import numpy as np
import pandas as pd
import six
import tensorflow as tf
# pylint: enable=wrong-import-order
from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import stat_utils
from official.utils.logs import mlperf_helper
......@@ -60,7 +47,7 @@ _EXPECTED_CACHE_KEYS = (
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP, "match_mlperf")
def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
def _filter_index_sort(raw_rating_path, cache_path):
# type: (str, str, bool) -> (dict, bool)
"""Read in data CSV, and output structured data.
......@@ -87,8 +74,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
Args:
raw_rating_path: The path to the CSV which contains the raw dataset.
cache_path: The path to the file where results of this function are saved.
match_mlperf: If True, change the sorting algorithm to match the MLPerf
reference implementation.
Returns:
A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
......@@ -104,9 +89,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
if cache_age > rconst.CACHE_INVALIDATION_SEC:
valid_cache = False
if cached_data["match_mlperf"] != match_mlperf:
valid_cache = False
for key in _EXPECTED_CACHE_KEYS:
if key not in cached_data:
valid_cache = False
......@@ -144,9 +126,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_NUM_EVAL,
value=rconst.NUM_EVAL_NEGATIVES)
mlperf_helper.ncf_print(
key=mlperf_helper.TAGS.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT,
value=match_mlperf)
assert num_users <= np.iinfo(rconst.USER_DTYPE).max
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
......@@ -186,7 +165,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
rconst.USER_MAP: user_map,
rconst.ITEM_MAP: item_map,
"create_time": time.time(),
"match_mlperf": match_mlperf,
}
tf.logging.info("Writing raw data cache.")
......@@ -216,8 +194,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)
raw_data, _ = _filter_index_sort(raw_rating_path, cache_path,
params["match_mlperf"])
raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
user_map, item_map = raw_data["user_map"], raw_data["item_map"]
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment