Commit bba6134c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

[NCF] Updating data preprocessing script.

PiperOrigin-RevId: 386140865
parent 31dc0f62
...@@ -29,17 +29,16 @@ import timeit ...@@ -29,17 +29,16 @@ import timeit
import traceback import traceback
import typing import typing
from absl import logging
import numpy as np import numpy as np
import six
from six.moves import queue from six.moves import queue
import tensorflow as tf import tensorflow as tf
from absl import logging
from tensorflow.python.tpu.datasets import StreamingFilesDataset
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import movielens from official.recommendation import movielens
from official.recommendation import popen_helper from official.recommendation import popen_helper
from official.recommendation import stat_utils from official.recommendation import stat_utils
from tensorflow.python.tpu.datasets import StreamingFilesDataset
SUMMARY_TEMPLATE = """General: SUMMARY_TEMPLATE = """General:
{spacer}Num users: {num_users} {spacer}Num users: {num_users}
...@@ -119,6 +118,7 @@ class DatasetManager(object): ...@@ -119,6 +118,7 @@ class DatasetManager(object):
"""Convert NumPy arrays into a TFRecords entry.""" """Convert NumPy arrays into a TFRecords entry."""
def create_int_feature(values): def create_int_feature(values):
values = np.squeeze(values)
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
feature_dict = { feature_dict = {
......
...@@ -23,21 +23,19 @@ import os ...@@ -23,21 +23,19 @@ import os
import pickle import pickle
import time import time
import timeit import timeit
import typing
# pylint: disable=wrong-import-order from typing import Dict, Text, Tuple
from absl import logging from absl import logging
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import tensorflow as tf import tensorflow as tf
import typing
from typing import Dict, Text, Tuple
# pylint: enable=wrong-import-order
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_pipeline from official.recommendation import data_pipeline
from official.recommendation import movielens from official.recommendation import movielens
_EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, _EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY, rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
rconst.USER_MAP, rconst.ITEM_MAP) rconst.USER_MAP, rconst.ITEM_MAP)
...@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text, ...@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
logging.info("Writing raw data cache.") logging.info("Writing raw data cache.")
with tf.io.gfile.GFile(cache_path, "wb") as f: with tf.io.gfile.GFile(cache_path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(data, f, protocol=4)
# TODO(robieta): MLPerf cache clear. # TODO(robieta): MLPerf cache clear.
return data, valid_cache return data, valid_cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment