Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ec0d43ba
"src/graph/vscode:/vscode.git/clone" did not exist on "18eaad17cce0ccb358df343cd8b1479582a2712c"
Commit
ec0d43ba
authored
Dec 21, 2018
by
Taylor Robie
Browse files
address PR comments
parent
c556dad9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
55 deletions
+65
-55
official/recommendation/data_pipeline.py
official/recommendation/data_pipeline.py
+44
-36
official/recommendation/ncf_main.py
official/recommendation/ncf_main.py
+6
-9
official/recommendation/stat_utils.py
official/recommendation/stat_utils.py
+15
-10
No files found.
official/recommendation/data_pipeline.py
View file @
ec0d43ba
...
@@ -19,17 +19,15 @@ from __future__ import division
...
@@ -19,17 +19,15 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
atexit
import
atexit
import
collections
import
functools
import
functools
import
os
import
os
import
pickle
import
struct
import
sys
import
sys
import
tempfile
import
tempfile
import
threading
import
threading
import
time
import
time
import
timeit
import
timeit
import
traceback
import
traceback
import
typing
import
numpy
as
np
import
numpy
as
np
import
six
import
six
...
@@ -82,6 +80,18 @@ class DatasetManager(object):
...
@@ -82,6 +80,18 @@ class DatasetManager(object):
"""
"""
def
__init__
(
self
,
is_training
,
stream_files
,
batches_per_epoch
,
def
__init__
(
self
,
is_training
,
stream_files
,
batches_per_epoch
,
shard_root
=
None
):
shard_root
=
None
):
# type: (bool, bool, int, typing.Optional[str]) -> None
"""Constructs a `DatasetManager` instance.
Args:
is_training: Boolean of whether the data provided is training or
evaluation data. This determines whether to reuse the data
(if is_training=False) and the exact structure to use when storing and
yielding data.
stream_files: Boolean indicating whether data should be serialized and
written to file shards.
batches_per_epoch: The number of batches in a single epoch.
shard_root: The base directory to be used when stream_files=True.
"""
self
.
_is_training
=
is_training
self
.
_is_training
=
is_training
self
.
_stream_files
=
stream_files
self
.
_stream_files
=
stream_files
self
.
_writers
=
[]
self
.
_writers
=
[]
...
@@ -183,9 +193,8 @@ class DatasetManager(object):
...
@@ -183,9 +193,8 @@ class DatasetManager(object):
batch_size
=
data
[
movielens
.
ITEM_COLUMN
].
shape
[
0
]
batch_size
=
data
[
movielens
.
ITEM_COLUMN
].
shape
[
0
]
data
[
rconst
.
VALID_POINT_MASK
]
=
np
.
less
(
np
.
arange
(
batch_size
),
data
[
rconst
.
VALID_POINT_MASK
]
=
np
.
less
(
np
.
arange
(
batch_size
),
mask_start_index
)
mask_start_index
)
self
.
_result_queue
.
put
((
data
,
data
.
pop
(
"labels"
)))
data
=
(
data
,
data
.
pop
(
"labels"
))
else
:
self
.
_result_queue
.
put
(
data
)
self
.
_result_reuse
.
append
(
data
)
def
start_construction
(
self
):
def
start_construction
(
self
):
if
self
.
_stream_files
:
if
self
.
_stream_files
:
...
@@ -199,26 +208,31 @@ class DatasetManager(object):
...
@@ -199,26 +208,31 @@ class DatasetManager(object):
[
writer
.
close
()
for
writer
in
self
.
_writers
]
[
writer
.
close
()
for
writer
in
self
.
_writers
]
self
.
_writers
=
[]
self
.
_writers
=
[]
self
.
_result_queue
.
put
(
self
.
current_data_root
)
self
.
_result_queue
.
put
(
self
.
current_data_root
)
elif
not
self
.
_is_training
:
self
.
_result_queue
.
put
(
True
)
# data is ready.
self
.
_epochs_completed
+=
1
self
.
_epochs_completed
+=
1
def
data_generator
(
self
,
epochs_between_evals
):
def
data_generator
(
self
,
epochs_between_evals
):
"""Yields examples during local training."""
"""Yields examples during local training."""
assert
not
self
.
_stream_files
assert
not
self
.
_stream_files
assert
self
.
_is_training
or
epochs_between_evals
==
1
if
self
.
_is_training
:
if
self
.
_is_training
:
for
_
in
range
(
self
.
_batches_per_epoch
*
epochs_between_evals
):
for
_
in
range
(
self
.
_batches_per_epoch
*
epochs_between_evals
):
yield
self
.
_result_queue
.
get
(
timeout
=
300
)
yield
self
.
_result_queue
.
get
(
timeout
=
300
)
else
:
else
:
# Evaluation waits for all data to be ready.
if
self
.
_result_reuse
:
self
.
_result_queue
.
put
(
self
.
_result_queue
.
get
(
timeout
=
300
))
assert
len
(
self
.
_result_reuse
)
==
self
.
_batches_per_epoch
assert
len
(
self
.
_result_reuse
)
==
self
.
_batches_per_epoch
assert
epochs_between_evals
==
1
for
i
in
self
.
_result_reuse
:
for
i
in
self
.
_result_reuse
:
yield
i
yield
i
else
:
# First epoch.
for
_
in
range
(
self
.
_batches_per_epoch
*
epochs_between_evals
):
result
=
self
.
_result_queue
.
get
(
timeout
=
300
)
self
.
_result_reuse
.
append
(
result
)
yield
result
def
get_dataset
(
self
,
batch_size
,
epochs_between_evals
):
def
get_dataset
(
self
,
batch_size
,
epochs_between_evals
):
"""Construct the dataset to be used for training and eval.
"""Construct the dataset to be used for training and eval.
...
@@ -341,7 +355,7 @@ class BaseDataConstructor(threading.Thread):
...
@@ -341,7 +355,7 @@ class BaseDataConstructor(threading.Thread):
"User positives ({}) is different from item positives ({})"
.
format
(
"User positives ({}) is different from item positives ({})"
.
format
(
self
.
_train_pos_users
.
shape
,
self
.
_train_pos_items
.
shape
))
self
.
_train_pos_users
.
shape
,
self
.
_train_pos_items
.
shape
))
self
.
_train_pos_count
=
self
.
_train_pos_users
.
shape
[
0
]
(
self
.
_train_pos_count
,)
=
self
.
_train_pos_users
.
shape
self
.
_elements_in_epoch
=
(
1
+
num_train_negatives
)
*
self
.
_train_pos_count
self
.
_elements_in_epoch
=
(
1
+
num_train_negatives
)
*
self
.
_train_pos_count
self
.
train_batches_per_epoch
=
self
.
_count_batches
(
self
.
train_batches_per_epoch
=
self
.
_count_batches
(
self
.
_elements_in_epoch
,
train_batch_size
,
batches_per_train_step
)
self
.
_elements_in_epoch
,
train_batch_size
,
batches_per_train_step
)
...
@@ -372,13 +386,12 @@ class BaseDataConstructor(threading.Thread):
...
@@ -372,13 +386,12 @@ class BaseDataConstructor(threading.Thread):
False
,
stream_files
,
self
.
eval_batches_per_epoch
,
self
.
_shard_root
)
False
,
stream_files
,
self
.
eval_batches_per_epoch
,
self
.
_shard_root
)
# Threading details
# Threading details
self
.
_current_epoch_order_lock
=
threading
.
RLock
()
super
(
BaseDataConstructor
,
self
).
__init__
()
super
(
BaseDataConstructor
,
self
).
__init__
()
self
.
daemon
=
True
self
.
daemon
=
True
self
.
_stop_loop
=
False
self
.
_stop_loop
=
False
self
.
_fatal_exception
=
None
self
.
_fatal_exception
=
None
def
__
rep
r__
(
self
):
def
__
st
r__
(
self
):
multiplier
=
(
"(x{} devices)"
.
format
(
self
.
_batches_per_train_step
)
multiplier
=
(
"(x{} devices)"
.
format
(
self
.
_batches_per_train_step
)
if
self
.
_batches_per_train_step
>
1
else
""
)
if
self
.
_batches_per_train_step
>
1
else
""
)
summary
=
SUMMARY_TEMPLATE
.
format
(
summary
=
SUMMARY_TEMPLATE
.
format
(
...
@@ -388,24 +401,17 @@ class BaseDataConstructor(threading.Thread):
...
@@ -388,24 +401,17 @@ class BaseDataConstructor(threading.Thread):
train_batch_ct
=
self
.
train_batches_per_epoch
,
train_batch_ct
=
self
.
train_batches_per_epoch
,
eval_pos_ct
=
self
.
_num_users
,
eval_batch_size
=
self
.
eval_batch_size
,
eval_pos_ct
=
self
.
_num_users
,
eval_batch_size
=
self
.
eval_batch_size
,
eval_batch_ct
=
self
.
eval_batches_per_epoch
,
multiplier
=
multiplier
)
eval_batch_ct
=
self
.
eval_batches_per_epoch
,
multiplier
=
multiplier
)
return
super
(
BaseDataConstructor
,
self
).
__
rep
r__
()
+
"
\n
"
+
summary
return
super
(
BaseDataConstructor
,
self
).
__
st
r__
()
+
"
\n
"
+
summary
@
staticmethod
@
staticmethod
def
_count_batches
(
example_count
,
batch_size
,
batches_per_step
):
def
_count_batches
(
example_count
,
batch_size
,
batches_per_step
):
"""Determine the number of batches, rounding up to fill all devices."""
x
=
(
example_count
+
batch_size
-
1
)
//
batch_size
x
=
(
example_count
+
batch_size
-
1
)
//
batch_size
return
(
x
+
batches_per_step
-
1
)
//
batches_per_step
*
batches_per_step
return
(
x
+
batches_per_step
-
1
)
//
batches_per_step
*
batches_per_step
def
stop_loop
(
self
):
def
stop_loop
(
self
):
self
.
_stop_loop
=
True
self
.
_stop_loop
=
True
def
_get_order_chunk
(
self
):
with
self
.
_current_epoch_order_lock
:
batch_indices
,
self
.
_current_epoch_order
=
(
self
.
_current_epoch_order
[:
self
.
train_batch_size
],
self
.
_current_epoch_order
[
self
.
train_batch_size
:])
return
batch_indices
def
construct_lookup_variables
(
self
):
def
construct_lookup_variables
(
self
):
"""Perform any one time pre-compute work."""
"""Perform any one time pre-compute work."""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -429,7 +435,7 @@ class BaseDataConstructor(threading.Thread):
...
@@ -429,7 +435,7 @@ class BaseDataConstructor(threading.Thread):
except
Exception
as
e
:
except
Exception
as
e
:
# The Thread base class swallows stack traces, so unfortunately it is
# The Thread base class swallows stack traces, so unfortunately it is
# necessary to catch and re-raise to get debug output
# necessary to catch and re-raise to get debug output
print
(
traceback
.
forma
t_exc
()
,
file
=
sys
.
stderr
)
traceback
.
prin
t_exc
()
self
.
_fatal_exception
=
e
self
.
_fatal_exception
=
e
sys
.
stderr
.
flush
()
sys
.
stderr
.
flush
()
raise
raise
...
@@ -448,8 +454,9 @@ class BaseDataConstructor(threading.Thread):
...
@@ -448,8 +454,9 @@ class BaseDataConstructor(threading.Thread):
i: The index of the batch. This is used when stream_files=True to assign
i: The index of the batch. This is used when stream_files=True to assign
data to file shards.
data to file shards.
"""
"""
batch_indices
=
self
.
_get_order_chunk
()
batch_indices
=
self
.
_current_epoch_order
[
i
*
self
.
train_batch_size
:
mask_start_index
=
batch_indices
.
shape
[
0
]
(
i
+
1
)
*
self
.
train_batch_size
]
(
mask_start_index
,)
=
batch_indices
.
shape
batch_ind_mod
=
np
.
mod
(
batch_indices
,
self
.
_train_pos_count
)
batch_ind_mod
=
np
.
mod
(
batch_indices
,
self
.
_train_pos_count
)
users
=
self
.
_train_pos_users
[
batch_ind_mod
]
users
=
self
.
_train_pos_users
[
batch_ind_mod
]
...
@@ -462,7 +469,7 @@ class BaseDataConstructor(threading.Thread):
...
@@ -462,7 +469,7 @@ class BaseDataConstructor(threading.Thread):
items
=
self
.
_train_pos_items
[
batch_ind_mod
]
items
=
self
.
_train_pos_items
[
batch_ind_mod
]
items
[
negative_indices
]
=
negative_items
items
[
negative_indices
]
=
negative_items
labels
=
np
.
logical_not
(
negative_indices
)
.
astype
(
np
.
bool
)
labels
=
np
.
logical_not
(
negative_indices
)
# Pad last partial batch
# Pad last partial batch
pad_length
=
self
.
train_batch_size
-
mask_start_index
pad_length
=
self
.
train_batch_size
-
mask_start_index
...
@@ -502,8 +509,7 @@ class BaseDataConstructor(threading.Thread):
...
@@ -502,8 +509,7 @@ class BaseDataConstructor(threading.Thread):
self
.
_train_dataset
.
start_construction
()
self
.
_train_dataset
.
start_construction
()
map_args
=
list
(
range
(
self
.
train_batches_per_epoch
))
map_args
=
list
(
range
(
self
.
train_batches_per_epoch
))
assert
not
self
.
_current_epoch_order
.
shape
[
0
]
self
.
_current_epoch_order
=
next
(
self
.
_shuffle_iterator
)
self
.
_current_epoch_order
=
six
.
next
(
self
.
_shuffle_iterator
)
with
popen_helper
.
get_threadpool
(
6
)
as
pool
:
with
popen_helper
.
get_threadpool
(
6
)
as
pool
:
pool
.
map
(
self
.
_get_training_batch
,
map_args
)
pool
.
map
(
self
.
_get_training_batch
,
map_args
)
...
@@ -536,7 +542,7 @@ class BaseDataConstructor(threading.Thread):
...
@@ -536,7 +542,7 @@ class BaseDataConstructor(threading.Thread):
items
=
np
.
concatenate
([
positive_items
,
negative_items
],
axis
=
1
)
items
=
np
.
concatenate
([
positive_items
,
negative_items
],
axis
=
1
)
# We pad the users and items here so that the duplicate mask calculation
# We pad the users and items here so that the duplicate mask calculation
# will include the padding. The metric function relies on every element
# will include padding. The metric function relies on all padded elements
# except the positive being marked as duplicate to mask out padded points.
# except the positive being marked as duplicate to mask out padded points.
if
users
.
shape
[
0
]
<
users_per_batch
:
if
users
.
shape
[
0
]
<
users_per_batch
:
pad_rows
=
users_per_batch
-
users
.
shape
[
0
]
pad_rows
=
users_per_batch
-
users
.
shape
[
0
]
...
@@ -592,6 +598,8 @@ class BaseDataConstructor(threading.Thread):
...
@@ -592,6 +598,8 @@ class BaseDataConstructor(threading.Thread):
timeit
.
default_timer
()
-
start_time
))
timeit
.
default_timer
()
-
start_time
))
def
make_input_fn
(
self
,
is_training
):
def
make_input_fn
(
self
,
is_training
):
# It isn't feasible to provide a foolproof check, so this is designed to
# catch most failures rather than provide an exhaustive guard.
if
self
.
_fatal_exception
is
not
None
:
if
self
.
_fatal_exception
is
not
None
:
raise
ValueError
(
"Fatal exception in the data production loop: {}"
raise
ValueError
(
"Fatal exception in the data production loop: {}"
.
format
(
self
.
_fatal_exception
))
.
format
(
self
.
_fatal_exception
))
...
@@ -616,7 +624,7 @@ class DummyConstructor(threading.Thread):
...
@@ -616,7 +624,7 @@ class DummyConstructor(threading.Thread):
def
input_fn
(
params
):
def
input_fn
(
params
):
"""Generated input_fn for the given epoch."""
"""Generated input_fn for the given epoch."""
batch_size
=
(
params
[
"batch_size"
]
if
is_training
else
batch_size
=
(
params
[
"batch_size"
]
if
is_training
else
params
[
"eval_batch_size"
]
or
params
[
"batch_size"
]
)
params
[
"eval_batch_size"
])
num_users
=
params
[
"num_users"
]
num_users
=
params
[
"num_users"
]
num_items
=
params
[
"num_items"
]
num_items
=
params
[
"num_items"
]
...
@@ -657,7 +665,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
...
@@ -657,7 +665,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
This class creates a table (num_users x num_items) containing all of the
This class creates a table (num_users x num_items) containing all of the
negative examples for each user. This table is conceptually ragged; that is to
negative examples for each user. This table is conceptually ragged; that is to
say the items dimension will have elements at the end which are not used equal
say the items dimension will have a number of unused elements at the end equal
to the number of positive elements for a given user. For instance:
to the number of positive elements for a given user. For instance:
num_users = 3
num_users = 3
...
@@ -693,7 +701,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
...
@@ -693,7 +701,7 @@ class MaterializedDataConstructor(BaseDataConstructor):
start_time
=
timeit
.
default_timer
()
start_time
=
timeit
.
default_timer
()
inner_bounds
=
np
.
argwhere
(
self
.
_train_pos_users
[
1
:]
-
inner_bounds
=
np
.
argwhere
(
self
.
_train_pos_users
[
1
:]
-
self
.
_train_pos_users
[:
-
1
])[:,
0
]
+
1
self
.
_train_pos_users
[:
-
1
])[:,
0
]
+
1
upper_bound
=
self
.
_train_pos_users
.
shape
[
0
]
(
upper_bound
,)
=
self
.
_train_pos_users
.
shape
index_bounds
=
[
0
]
+
inner_bounds
.
tolist
()
+
[
upper_bound
]
index_bounds
=
[
0
]
+
inner_bounds
.
tolist
()
+
[
upper_bound
]
self
.
_negative_table
=
np
.
zeros
(
shape
=
(
self
.
_num_users
,
self
.
_num_items
),
self
.
_negative_table
=
np
.
zeros
(
shape
=
(
self
.
_num_users
,
self
.
_num_items
),
dtype
=
rconst
.
ITEM_DTYPE
)
dtype
=
rconst
.
ITEM_DTYPE
)
...
...
official/recommendation/ncf_main.py
View file @
ec0d43ba
...
@@ -114,7 +114,7 @@ def construct_estimator(model_dir, params):
...
@@ -114,7 +114,7 @@ def construct_estimator(model_dir, params):
def
log_and_get_hooks
(
eval_batch_size
):
def
log_and_get_hooks
(
eval_batch_size
):
"""Convenience method for hook and logger creation."""
"""Convenience function for hook and logger creation."""
# Create hooks that log information about the training and metric values
# Create hooks that log information about the training and metric values
train_hooks
=
hooks_helper
.
get_train_hooks
(
train_hooks
=
hooks_helper
.
get_train_hooks
(
FLAGS
.
hooks
,
FLAGS
.
hooks
,
...
@@ -140,19 +140,16 @@ def log_and_get_hooks(eval_batch_size):
...
@@ -140,19 +140,16 @@ def log_and_get_hooks(eval_batch_size):
def
parse_flags
(
flags_obj
):
def
parse_flags
(
flags_obj
):
"""Convenience method to turn flags into params."""
"""Convenience function to turn flags into params."""
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
)
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
)
num_devices
=
FLAGS
.
num_tpu_shards
if
FLAGS
.
tpu
else
num_gpus
or
1
num_devices
=
FLAGS
.
num_tpu_shards
if
FLAGS
.
tpu
else
num_gpus
or
1
batch_size
=
distribution_utils
.
per_device_batch_size
(
batch_size
=
(
flags_obj
.
batch_size
+
num_devices
-
1
)
//
num_devices
(
int
(
flags_obj
.
batch_size
)
+
num_devices
-
1
)
//
num_devices
*
num_devices
,
num_devices
)
eval_divisor
=
(
rconst
.
NUM_EVAL_NEGATIVES
+
1
)
*
num_devices
eval_divisor
=
(
rconst
.
NUM_EVAL_NEGATIVES
+
1
)
*
num_devices
eval_batch_size
=
int
(
flags_obj
.
eval_batch_size
or
flags_obj
.
batch_size
or
1
)
eval_batch_size
=
flags_obj
.
eval_batch_size
or
flags_obj
.
batch_size
eval_batch_size
=
distribution_utils
.
per_device_batch_size
(
eval_batch_size
=
((
eval_batch_size
+
eval_divisor
-
1
)
//
(
eval_batch_size
+
eval_divisor
-
1
)
//
eval_divisor
*
eval_divisor
//
num_devices
)
eval_divisor
*
eval_divisor
,
num_devices
)
return
{
return
{
"train_epochs"
:
flags_obj
.
train_epochs
,
"train_epochs"
:
flags_obj
.
train_epochs
,
...
...
official/recommendation/stat_utils.py
View file @
ec0d43ba
...
@@ -18,27 +18,32 @@ from __future__ import absolute_import
...
@@ -18,27 +18,32 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
atexit
from
collections
import
deque
import
multiprocessing
import
os
import
os
import
struct
import
sys
import
threading
import
time
import
numpy
as
np
import
numpy
as
np
from
official.recommendation
import
popen_helper
def
random_int32
():
def
random_int32
():
return
np
.
random
.
randint
(
low
=
0
,
high
=
np
.
iinfo
(
np
.
int32
).
max
,
dtype
=
np
.
int32
)
return
np
.
random
.
randint
(
low
=
0
,
high
=
np
.
iinfo
(
np
.
int32
).
max
,
dtype
=
np
.
int32
)
def
permutation
(
args
):
def
permutation
(
args
):
"""Fork safe permutation function.
This function can be called within a multiprocessing worker and give
appropriately random results.
Args:
args: A size two tuple that will be unpacked into the size of the permutation
and the random seed. This form is used because starmap is not universally
available.
Returns:
A NumPy array containing a random permutation.
"""
x
,
seed
=
args
x
,
seed
=
args
seed
=
seed
or
struct
.
unpack
(
"<L"
,
os
.
urandom
(
4
))[
0
]
# If seed is None NumPy will seed randomly.
state
=
np
.
random
.
RandomState
(
seed
=
seed
)
# pylint: disable=no-member
state
=
np
.
random
.
RandomState
(
seed
=
seed
)
# pylint: disable=no-member
output
=
np
.
arange
(
x
,
dtype
=
np
.
int32
)
output
=
np
.
arange
(
x
,
dtype
=
np
.
int32
)
state
.
shuffle
(
output
)
state
.
shuffle
(
output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment