ModelZoo / ResNet50_tensorflow

Commit 98dd890b, authored May 08, 2020 by A. Unique TensorFlower

refactors some preprocessing code.

PiperOrigin-RevId: 310658964
parent 0fc994b6
Showing 4 changed files with 101 additions and 69 deletions.
official/recommendation/data_preprocessing.py   +82 -55
official/recommendation/data_test.py             +1  -2
official/recommendation/movielens.py            +14  -5
official/recommendation/ncf_common.py            +4  -7
official/recommendation/data_preprocessing.py
@@ -16,18 +16,21 @@
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

 import os
 import pickle
 import time
 import timeit

 # pylint: disable=wrong-import-order
 from absl import logging
 import numpy as np
 import pandas as pd
 import tensorflow as tf
 import typing
 from typing import Dict, Text, Tuple
 # pylint: enable=wrong-import-order

 from official.recommendation import constants as rconst
@@ -35,20 +38,15 @@ from official.recommendation import data_pipeline
 from official.recommendation import movielens

-DATASET_TO_NUM_USERS_AND_ITEMS = {"ml-1m": (6040, 3706),
-                                  "ml-20m": (138493, 26744)}
-
 _EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
                         rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
                         rconst.USER_MAP, rconst.ITEM_MAP)


-def _filter_index_sort(raw_rating_path, cache_path):
-  # type: (str, str, bool) -> (dict, bool)
-  """Read in data CSV, and output structured data.
+def read_dataframe(
+    raw_rating_path: Text
+) -> Tuple[Dict[int, int], Dict[int, int], pd.DataFrame]:
+  """Read in data CSV, and output DataFrame for downstream processing.

   This function reads in the raw CSV of positive items, and performs three
   preprocessing transformations:
@@ -63,43 +61,14 @@ def _filter_index_sort(raw_rating_path, cache_path):
   This allows the dataframe to be sliced by user in-place, and for the last
   item to be selected simply by calling the `-1` index of a user's slice.

-  While all of these transformations are performed by Pandas (and are therefore
-  single-threaded), they only take ~2 minutes, and the overhead to apply a
-  MapReduce pattern to parallel process the dataset adds significant complexity
-  for no computational gain. For a larger dataset parallelizing this
-  preprocessing could yield speedups. (Also, this preprocessing step is only
-  performed once for an entire run.
-
   Args:
     raw_rating_path: The path to the CSV which contains the raw dataset.
-    cache_path: The path to the file where results of this function are saved.

   Returns:
-    A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
-    IDs to regularized user IDs, and a dict mapping raw item IDs to regularized
-    item IDs.
+    A dict mapping raw user IDs to regularized user IDs, a dict mapping raw
+    item IDs to regularized item IDs, and a filtered, zero-index remapped,
+    sorted dataframe.
   """
-  valid_cache = tf.io.gfile.exists(cache_path)
-  if valid_cache:
-    with tf.io.gfile.GFile(cache_path, "rb") as f:
-      cached_data = pickle.load(f)
-
-    # (nnigania)disabled this check as the dataset is not expected to change
-    # cache_age = time.time() - cached_data.get("create_time", 0)
-    # if cache_age > rconst.CACHE_INVALIDATION_SEC:
-    #   valid_cache = False
-
-    for key in _EXPECTED_CACHE_KEYS:
-      if key not in cached_data:
-        valid_cache = False
-
-    if not valid_cache:
-      logging.info("Removing stale raw data cache file.")
-      tf.io.gfile.remove(cache_path)
-
-  if valid_cache:
-    data = cached_data
-  else:
-    with tf.io.gfile.GFile(raw_rating_path) as f:
-      df = pd.read_csv(f)
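The reordered Returns section above is the core of the new contract: read_dataframe hands back the two ID-remapping dicts first and the cleaned DataFrame last. As a self-contained illustration (toy data, with literal column names standing in for movielens.USER_COLUMN and movielens.ITEM_COLUMN; not code from this commit), the "zero-index remapped" maps it describes amount to:

    # Illustrative sketch only: build contiguous ID maps and remap the frame.
    import pandas as pd

    ratings = pd.DataFrame({
        "user_id": [17, 17, 42, 42, 42],
        "item_id": [1001, 1005, 1001, 1003, 1007],
        "timestamp": [3, 1, 9, 4, 6],
    })

    # Raw ID -> contiguous ID, so the largest user/item index is count - 1.
    user_map = {raw: i for i, raw in enumerate(sorted(ratings["user_id"].unique()))}
    item_map = {raw: i for i, raw in enumerate(sorted(ratings["item_id"].unique()))}

    ratings["user_id"] = ratings["user_id"].map(user_map)
    ratings["item_id"] = ratings["item_id"].map(item_map)

    print(user_map)  # {17: 0, 42: 1}
    print(item_map)  # {1001: 0, 1003: 1, 1005: 2, 1007: 3}

Returning the maps alongside the frame lets the caller, the new _filter_index_sort shown in the next hunk, cache all three together.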
@@ -142,10 +111,68 @@ def _filter_index_sort(raw_rating_path, cache_path):
   # reference implementation.
   df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
   df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
-                 inplace=True, kind="mergesort")
+                 inplace=True,
+                 kind="mergesort")

   # The dataframe does not reconstruct indices in the sort or filter steps.
-  df = df.reset_index()
+  return user_map, item_map, df.reset_index()
+
+
+def _filter_index_sort(raw_rating_path: Text,
+                       cache_path: Text) -> Tuple[pd.DataFrame, bool]:
+  """Read in data CSV, and output structured data.
+
+  This function reads in the raw CSV of positive items, and performs three
+  preprocessing transformations:
+  1)  Filter out all users who have not rated at least a certain number
+      of items. (Typically 20 items)
+  2)  Zero index the users and items such that the largest user_id is
+      `num_users - 1` and the largest item_id is `num_items - 1`
+  3)  Sort the dataframe by user_id, with timestamp as a secondary sort key.
+      This allows the dataframe to be sliced by user in-place, and for the last
+      item to be selected simply by calling the `-1` index of a user's slice.
+
+  While all of these transformations are performed by Pandas (and are therefore
+  single-threaded), they only take ~2 minutes, and the overhead to apply a
+  MapReduce pattern to parallel process the dataset adds significant complexity
+  for no computational gain. For a larger dataset parallelizing this
+  preprocessing could yield speedups. (Also, this preprocessing step is only
+  performed once for an entire run.
+
+  Args:
+    raw_rating_path: The path to the CSV which contains the raw dataset.
+    cache_path: The path to the file where results of this function are saved.
+
+  Returns:
+    A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
+    IDs to regularized user IDs, and a dict mapping raw item IDs to regularized
+    item IDs.
+  """
+  valid_cache = tf.io.gfile.exists(cache_path)
+  if valid_cache:
+    with tf.io.gfile.GFile(cache_path, "rb") as f:
+      cached_data = pickle.load(f)
+
+    # (nnigania)disabled this check as the dataset is not expected to change
+    # cache_age = time.time() - cached_data.get("create_time", 0)
+    # if cache_age > rconst.CACHE_INVALIDATION_SEC:
+    #   valid_cache = False
+
+    for key in _EXPECTED_CACHE_KEYS:
+      if key not in cached_data:
+        valid_cache = False
+
+    if not valid_cache:
+      logging.info("Removing stale raw data cache file.")
+      tf.io.gfile.remove(cache_path)
+
+  if valid_cache:
+    data = cached_data
+  else:
+    user_map, item_map, df = read_dataframe(raw_rating_path)

     grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
     eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
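The last two lines of the hunk are the leave-one-out split the docstring's sort requirement exists for: after a stable (mergesort) sort by user and then timestamp, each user's newest rating is the final row of their group, so tail(1) is the evaluation set and iloc[:-1] keeps the rest for training. A minimal, self-contained sketch of that behaviour (toy data and literal column names rather than the movielens constants):

    import pandas as pd

    df = pd.DataFrame({
        "user_id":   [0, 0, 0, 1, 1],
        "item_id":   [5, 2, 9, 4, 7],
        "timestamp": [1, 2, 3, 1, 2],
    })
    # Stable sort: rows with equal user_id keep their timestamp order.
    df.sort_values(["user_id", "timestamp"], inplace=True, kind="mergesort")

    grouped = df.groupby("user_id", group_keys=False)
    eval_df = grouped.tail(1)                        # newest rating per user
    train_df = grouped.apply(lambda x: x.iloc[:-1])  # all earlier ratings

    print(eval_df["item_id"].tolist())   # [9, 7]
    print(train_df["item_id"].tolist())  # [5, 2, 4]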
@@ -201,7 +228,7 @@ def instantiate_pipeline(dataset,
   raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
   user_map, item_map = raw_data["user_map"], raw_data["item_map"]
-  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
+  num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[dataset]

   if num_users != len(user_map):
     raise ValueError("Expected to find {} users, but found {}".format(
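instantiate_pipeline now resolves the expected corpus size through movielens.DATASET_TO_NUM_USERS_AND_ITEMS instead of a module-local copy, and still cross-checks it against the maps built from the raw CSV. A hedged sketch of that check pattern (check_dataset_sizes is a hypothetical helper, not a function from the repo; the dict values are the ones this commit adds to movielens.py):

    # Sketch only: fail loudly if the canned sizes and the parsed data disagree.
    DATASET_TO_NUM_USERS_AND_ITEMS = {"ml-1m": (6040, 3706),
                                      "ml-20m": (138493, 26744)}

    def check_dataset_sizes(dataset, user_map, item_map):
      num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
      if num_users != len(user_map):
        raise ValueError("Expected to find {} users, but found {}".format(
            num_users, len(user_map)))
      if num_items != len(item_map):
        raise ValueError("Expected to find {} items, but found {}".format(
            num_items, len(item_map)))

    # A map with too few users trips the check:
    try:
      check_dataset_sizes("ml-1m", user_map={0: 0}, item_map={})
    except ValueError as e:
      print(e)  # Expected to find 6040 users, but found 1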
official/recommendation/data_test.py
@@ -95,8 +95,7 @@ class BaseTest(tf.test.TestCase):
     movielens.download = mock_download
     movielens.NUM_RATINGS[DATASET] = NUM_PTS
-    data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS, NUM_ITEMS)
+    movielens.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS, NUM_ITEMS)

   def make_params(self, train_epochs=1):
     return {
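The one-line change in the test mirrors the constant's move: the size table now has to be patched on movielens, the module that owns it, otherwise the override is invisible to data_preprocessing, which reads the table through that module. The diff patches it by direct assignment in setUp; a scoped variant of the same idea, sketched here against a stand-in namespace (hypothetical names, not the repo's test code), uses mock.patch.dict so the override is undone automatically:

    import types
    from unittest import mock

    # Stand-in for the movielens module and its size table.
    movielens_stub = types.SimpleNamespace(
        DATASET_TO_NUM_USERS_AND_ITEMS={"ml-1m": (6040, 3706)})

    DATASET, NUM_USERS, NUM_ITEMS = "ml-test", 30, 50

    with mock.patch.dict(movielens_stub.DATASET_TO_NUM_USERS_AND_ITEMS,
                         {DATASET: (NUM_USERS, NUM_ITEMS)}):
      # Any code that reads the table through the owning module sees the
      # shrunken test dataset while the context is active.
      assert movielens_stub.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] == (30, 50)

    # The entry is removed on exit, keeping other tests isolated.
    assert DATASET not in movielens_stub.DATASET_TO_NUM_USERS_AND_ITEMS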
official/recommendation/movielens.py
@@ -84,6 +84,8 @@ NUM_RATINGS = {
     ML_20M: 20000263
 }

+DATASET_TO_NUM_USERS_AND_ITEMS = {ML_1M: (6040, 3706),
+                                  ML_20M: (138493, 26744)}

 def _download_and_clean(dataset, data_dir):
   """Download MovieLens dataset in a standard format.
@@ -284,17 +286,24 @@ def integerize_genres(dataframe):
   return dataframe


+def define_flags():
+  """Add flags specifying data usage arguments."""
+  flags.DEFINE_enum(
+      name="dataset",
+      default=None,
+      enum_values=DATASETS,
+      case_sensitive=False,
+      help=flags_core.help_wrap("Dataset to be trained and evaluated."))
+
+
 def define_data_download_flags():
-  """Add flags specifying data download arguments."""
+  """Add flags specifying data download and usage arguments."""
   flags.DEFINE_string(
       name="data_dir", default="/tmp/movielens-data/",
       help=flags_core.help_wrap("Directory to download and extract data."))

-  flags.DEFINE_enum(
-      name="dataset", default=None,
-      enum_values=DATASETS, case_sensitive=False,
-      help=flags_core.help_wrap("Dataset to be trained and evaluated."))
+  define_flags()


 def main(_):
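Pulling the --dataset definition into its own define_flags() lets every MovieLens-consuming binary register the flag from one place (define_data_download_flags above and define_ncf_flags in the next file both just call it), which avoids defining the same absl flag twice in one program and the DuplicateFlagError that would raise. A hedged, standalone sketch of the pattern with plain absl (help text passed directly instead of through the repo's flags_core.help_wrap):

    from absl import app, flags

    FLAGS = flags.FLAGS
    DATASETS = ["ml-1m", "ml-20m"]

    def define_flags():
      """Add flags specifying data usage arguments."""
      flags.DEFINE_enum(
          name="dataset", default=None, enum_values=DATASETS,
          case_sensitive=False,
          help="Dataset to be trained and evaluated.")

    def main(_):
      print("dataset =", FLAGS.dataset)

    if __name__ == "__main__":
      define_flags()  # registering the flag exactly once
      app.run(main)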
official/recommendation/ncf_common.py
@@ -50,7 +50,7 @@ def get_inputs(params):
   if FLAGS.use_synthetic_data:
     producer = data_pipeline.DummyConstructor()

-    num_users, num_items = data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[
+    num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[
         FLAGS.dataset]
     num_train_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
     num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
@@ -163,21 +163,18 @@ def define_ncf_flags():
   flags.adopt_module_key_flags(flags_core)

+  movielens.define_flags()
+
   flags_core.set_defaults(model_dir="/tmp/ncf/",
                           data_dir="/tmp/movielens-data/",
                           dataset=movielens.ML_1M,
                           train_epochs=2,
                           batch_size=99000,
                           tpu=None)

   # Add ncf-specific flags
-  flags.DEFINE_enum(
-      name="dataset", default="ml-1m",
-      enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
-      help=flags_core.help_wrap("Dataset to be trained and evaluated."))
-
   flags.DEFINE_boolean(
       name="download_if_missing", default=True, help=flags_core.help_wrap(
           "Download data to data_dir if it is not already present."))
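define_ncf_flags keeps its NCF-specific flags but now delegates --dataset to movielens.define_flags() and only adjusts its default via flags_core.set_defaults. A minimal sketch of that "define in the library, re-default in the binary" split using plain absl (the repo's flags_core wrapper is not assumed here; names are illustrative):

    from absl import app, flags

    FLAGS = flags.FLAGS

    # Library module: owns the flag definition and its generic default.
    flags.DEFINE_enum("dataset", None, ["ml-1m", "ml-20m"],
                      "Dataset to be trained and evaluated.",
                      case_sensitive=False)

    # Binary: overrides the default without redefining the flag.
    FLAGS.set_default("dataset", "ml-1m")

    def main(_):
      print("dataset =", FLAGS.dataset)  # "ml-1m" unless --dataset is passed

    if __name__ == "__main__":
      app.run(main)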