Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6726c5e0
Commit
6726c5e0
authored
Jan 07, 2019
by
Taylor Robie
Browse files
address more PR comments
parent
1bb074b0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
25 deletions
+2
-25
official/recommendation/data_preprocessing.py
official/recommendation/data_preprocessing.py
+2
-25
No files found.
official/recommendation/data_preprocessing.py
View file @
6726c5e0
...
...
@@ -18,34 +18,21 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
atexit
import
contextlib
import
gc
import
hashlib
import
json
import
os
import
pickle
import
signal
import
socket
import
subprocess
import
threading
import
time
import
timeit
import
typing
# pylint: disable=wrong-import-order
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
pandas
as
pd
import
six
import
tensorflow
as
tf
# pylint: enable=wrong-import-order
from
official.datasets
import
movielens
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
data_pipeline
from
official.recommendation
import
stat_utils
from
official.utils.logs
import
mlperf_helper
...
...
@@ -60,7 +47,7 @@ _EXPECTED_CACHE_KEYS = (
rconst
.
EVAL_ITEM_KEY
,
rconst
.
USER_MAP
,
rconst
.
ITEM_MAP
,
"match_mlperf"
)
def
_filter_index_sort
(
raw_rating_path
,
cache_path
,
match_mlperf
):
def
_filter_index_sort
(
raw_rating_path
,
cache_path
):
# type: (str, str) -> (dict, bool)
"""Read in data CSV, and output structured data.
...
...
@@ -87,8 +74,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
Args:
raw_rating_path: The path to the CSV which contains the raw dataset.
cache_path: The path to the file where results of this function are saved.
match_mlperf: If True, change the sorting algorithm to match the MLPerf
reference implementation.
Returns:
A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
...
...
@@ -104,9 +89,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
if
cache_age
>
rconst
.
CACHE_INVALIDATION_SEC
:
valid_cache
=
False
if
cached_data
[
"match_mlperf"
]
!=
match_mlperf
:
valid_cache
=
False
for
key
in
_EXPECTED_CACHE_KEYS
:
if
key
not
in
cached_data
:
valid_cache
=
False
...
...
@@ -144,9 +126,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
mlperf_helper
.
ncf_print
(
key
=
mlperf_helper
.
TAGS
.
PREPROC_HP_NUM_EVAL
,
value
=
rconst
.
NUM_EVAL_NEGATIVES
)
mlperf_helper
.
ncf_print
(
key
=
mlperf_helper
.
TAGS
.
PREPROC_HP_SAMPLE_EVAL_REPLACEMENT
,
value
=
match_mlperf
)
assert
num_users
<=
np
.
iinfo
(
rconst
.
USER_DTYPE
).
max
assert
num_items
<=
np
.
iinfo
(
rconst
.
ITEM_DTYPE
).
max
...
...
@@ -186,7 +165,6 @@ def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
rconst
.
USER_MAP
:
user_map
,
rconst
.
ITEM_MAP
:
item_map
,
"create_time"
:
time
.
time
(),
"match_mlperf"
:
match_mlperf
,
}
tf
.
logging
.
info
(
"Writing raw data cache."
)
...
...
@@ -216,8 +194,7 @@ def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
raw_rating_path
=
os
.
path
.
join
(
data_dir
,
dataset
,
movielens
.
RATINGS_FILE
)
cache_path
=
os
.
path
.
join
(
data_dir
,
dataset
,
rconst
.
RAW_CACHE_FILE
)
raw_data
,
_
=
_filter_index_sort
(
raw_rating_path
,
cache_path
,
params
[
"match_mlperf"
])
raw_data
,
_
=
_filter_index_sort
(
raw_rating_path
,
cache_path
)
user_map
,
item_map
=
raw_data
[
"user_map"
],
raw_data
[
"item_map"
]
num_users
,
num_items
=
DATASET_TO_NUM_USERS_AND_ITEMS
[
dataset
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment