"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ab4c10229d43e9d9f94c1e771d19c44c15461281"
Unverified commit 803f833c, authored by Hongjun Choi, committed by GitHub

Merged commit includes the following changes: (#7322)

260228553  by priyag<priyag@google.com>:

    Enable transformer and NCF official model tests. Also fix some minor issues so that all tests pass with TF 1 + enable_v2_behavior.

--
260043210  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add logic to train NCF model using offline generated data.

--
259778607  by priyag<priyag@google.com>:

    Internal change

--
259656389  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 260228553
parent 8c7a0e75
@@ -47,8 +47,14 @@ flags.DEFINE_integer("num_train_epochs", 14,
 flags.DEFINE_integer(
     "num_negative_samples", 4,
     "Number of negative instances to pair with positive instance.")
-flags.DEFINE_integer("prebatch_size", 99000,
-                     "Batch size to be used for prebatching the dataset.")
+flags.DEFINE_integer(
+    "train_prebatch_size", 99000,
+    "Batch size to be used for prebatching the dataset "
+    "for training.")
+flags.DEFINE_integer(
+    "eval_prebatch_size", 99000,
+    "Batch size to be used for prebatching the dataset "
+    "for evaluation.")
 
 FLAGS = flags.FLAGS
@@ -82,7 +88,10 @@ def prepare_raw_data(flag_obj):
       "num_train_elements": producer._elements_in_epoch,
       "num_eval_elements": producer._eval_elements_in_epoch,
       "num_train_epochs": flag_obj.num_train_epochs,
-      "prebatch_size": flag_obj.prebatch_size,
+      "train_prebatch_size": flag_obj.train_prebatch_size,
+      "eval_prebatch_size": flag_obj.eval_prebatch_size,
+      "num_train_steps": producer.train_batches_per_epoch,
+      "num_eval_steps": producer.eval_batches_per_epoch,
   }
   # pylint: enable=protected-access
...
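The metadata dict assembled above is written next to the generated TFRecord files and read back by the training job through the new input_meta_data_path flag (see the run_ncf hunk further down, which calls json.loads on the file contents). A minimal round-trip sketch, assuming JSON serialization; the path and all values are placeholders:

import json

import tensorflow as tf

# Placeholder path and values for illustration; the real dict also carries
# num_users/num_items, which run_ncf reads back below.
meta_path = "/tmp/ncf_raw_data/metadata"
raw_metadata = {
    "num_users": 1000,
    "num_items": 2000,
    "train_prebatch_size": 99000,
    "eval_prebatch_size": 99000,
    "num_train_steps": 50,
    "num_eval_steps": 10,
}

# Write the metadata alongside the generated TFRecord files.
with tf.gfile.GFile(meta_path, "w") as writer:
  writer.write(json.dumps(raw_metadata))

# Later, the Keras main reads it back (matching the json.loads call below).
with tf.gfile.GFile(meta_path, "rb") as reader:
  input_meta_data = json.loads(reader.read().decode("utf-8"))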
@@ -125,13 +125,16 @@ class DatasetManager(object):
     return tf.train.Example(
         features=tf.train.Features(feature=feature_dict)).SerializeToString()
 
-  def deserialize(self, serialized_data, batch_size):
+  @staticmethod
+  def deserialize(serialized_data, batch_size=None, is_training=True):
     """Convert serialized TFRecords into tensors.
 
     Args:
       serialized_data: A tensor containing serialized records.
       batch_size: The data arrives pre-batched, so batch size is needed to
         deserialize the data.
+      is_training: Boolean, whether the data to deserialize is training data
+        or evaluation data.
     """
 
     def _get_feature_map(batch_size, is_training=True):
@@ -159,7 +162,7 @@ class DatasetManager(object):
       }
 
     features = tf.parse_single_example(
-        serialized_data, _get_feature_map(batch_size, self._is_training))
+        serialized_data, _get_feature_map(batch_size, is_training=is_training))
     users = tf.reshape(
         tf.cast(features[movielens.USER_COLUMN], rconst.USER_DTYPE),
         (batch_size,))
@@ -167,25 +170,39 @@ class DatasetManager(object):
         tf.cast(features[movielens.ITEM_COLUMN], rconst.ITEM_DTYPE),
         (batch_size,))
 
-    if self._is_training:
+    if is_training:
       valid_point_mask = tf.reshape(
           tf.cast(features[movielens.ITEM_COLUMN], tf.bool), (batch_size,))
+      fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
       return {
           movielens.USER_COLUMN: users,
           movielens.ITEM_COLUMN: items,
           rconst.VALID_POINT_MASK: valid_point_mask,
-      }, tf.reshape(tf.cast(features["labels"], tf.bool), (batch_size,))
-
-    return {
-        movielens.USER_COLUMN:
-            users,
-        movielens.ITEM_COLUMN:
-            items,
-        rconst.DUPLICATE_MASK:
-            tf.reshape(
-                tf.cast(features[rconst.DUPLICATE_MASK], tf.bool),
-                (batch_size,))
-    }
+          rconst.TRAIN_LABEL_KEY:
+              tf.reshape(tf.cast(features["labels"], tf.bool),
+                         (batch_size, 1)),
+          rconst.DUPLICATE_MASK: fake_dup_mask
+      }
+    else:
+      labels = tf.reshape(
+          tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool),
+          (batch_size, 1))
+      fake_valid_pt_mask = tf.cast(
+          tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
+      return {
+          movielens.USER_COLUMN:
+              users,
+          movielens.ITEM_COLUMN:
+              items,
+          rconst.DUPLICATE_MASK:
+              tf.reshape(
+                  tf.cast(features[rconst.DUPLICATE_MASK], tf.bool),
+                  (batch_size,)),
+          rconst.VALID_POINT_MASK:
+              fake_valid_pt_mask,
+          rconst.TRAIN_LABEL_KEY:
+              labels
+      }
 
   def put(self, index, data):
     # type: (int, dict) -> None
@@ -287,7 +304,10 @@ class DatasetManager(object):
           files=file_pattern, worker_job=popen_helper.worker_job(),
           num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
           sloppy=not self._deterministic)
-      map_fn = functools.partial(self.deserialize, batch_size=batch_size)
+      map_fn = functools.partial(
+          self.deserialize,
+          batch_size=batch_size,
+          is_training=self._is_training)
       dataset = dataset.map(map_fn, num_parallel_calls=16)
 
     else:
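Because deserialize is now a staticmethod, it can also be reused outside DatasetManager, which is exactly what the new ncf_input_pipeline module below does. A hedged standalone sketch; the file pattern is a placeholder, and the batch_size passed here must be the prebatch size used when the records were written, not the training batch size:

import functools

import tensorflow as tf

from official.recommendation import data_pipeline

# Placeholder file pattern pointing at offline-generated, pre-batched records.
files = tf.data.Dataset.list_files("/tmp/ncf_raw_data/training_cycle_0/*")
decode_fn = functools.partial(
    data_pipeline.DatasetManager.deserialize,
    batch_size=99000,
    is_training=True)
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
dataset = dataset.map(decode_fn)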
@@ -672,6 +692,11 @@ class BaseDataConstructor(threading.Thread):
 class DummyConstructor(threading.Thread):
   """Class for running with synthetic data."""
 
+  def __init__(self, *args, **kwargs):
+    super(DummyConstructor, self).__init__(*args, **kwargs)
+    self.train_batches_per_epoch = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+    self.eval_batches_per_epoch = rconst.SYNTHETIC_BATCHES_PER_EPOCH
+
   def run(self):
     pass
...
@@ -109,6 +109,9 @@ def parse_flags(flags_obj):
       "keras_use_ctl": flags_obj.keras_use_ctl,
       "hr_threshold": flags_obj.hr_threshold,
       "stream_files": flags_obj.tpu is not None,
+      "train_dataset_path": flags_obj.train_dataset_path,
+      "eval_dataset_path": flags_obj.eval_dataset_path,
+      "input_meta_data_path": flags_obj.input_meta_data_path,
   }
@@ -261,6 +264,21 @@ def define_ncf_flags():
           "precompute that scales badly, but a faster per-epoch construction"
           "time and can be faster on very large systems."))
 
+  flags.DEFINE_string(
+      name="train_dataset_path",
+      default=None,
+      help=flags_core.help_wrap("Path to training data."))
+
+  flags.DEFINE_string(
+      name="eval_dataset_path",
+      default=None,
+      help=flags_core.help_wrap("Path to evaluation data."))
+
+  flags.DEFINE_string(
+      name="input_meta_data_path",
+      default=None,
+      help=flags_core.help_wrap("Path to input meta data file."))
+
   flags.DEFINE_bool(
       name="ml_perf", default=False,
       help=flags_core.help_wrap(
...
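These three flags select the offline-data path in run_ncf further down; when train_dataset_path is left unset, the online producer is used instead. A hedged sketch of setting them programmatically, for example from a test; every path here is a placeholder, not a value prescribed by this change:

from absl import flags

from official.recommendation import ncf_common

# Define and parse the NCF flags with their defaults first.
ncf_common.define_ncf_flags()
FLAGS = flags.FLAGS
FLAGS(["ncf_main"])

# Placeholder patterns pointing at offline-generated TFRecords and metadata.
FLAGS.train_dataset_path = "/tmp/ncf_raw_data/training_cycle_*/*.tfrecord"
FLAGS.eval_dataset_path = "/tmp/ncf_raw_data/eval_data/*.tfrecord"
FLAGS.input_meta_data_path = "/tmp/ncf_raw_data/metadata"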
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NCF model input pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
# pylint: disable=g-bad-import-order
import numpy as np
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
NUM_SHARDS = 16
def create_dataset_from_tf_record_files(input_file_pattern,
pre_batch_size,
batch_size,
is_training=True):
"""Creates dataset from (tf)records files for training/evaluation."""
files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
def make_dataset(files_dataset, shard_index):
"""Returns dataset for sharded tf record files."""
files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
dataset = files_dataset.interleave(tf.data.TFRecordDataset)
decode_fn = functools.partial(
data_pipeline.DatasetManager.deserialize,
batch_size=pre_batch_size,
is_training=is_training)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.unbatch())
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
dataset = tf.data.Dataset.range(NUM_SHARDS)
map_fn = functools.partial(make_dataset, files)
dataset = dataset.interleave(
map_fn,
cycle_length=NUM_SHARDS,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
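A hedged usage sketch of the helper above; the file pattern and both sizes are placeholders, and pre_batch_size must match the prebatch size recorded in the metadata file when the records were generated:

from official.recommendation import ncf_input_pipeline

# Placeholder pattern and sizes; batch_size is the per-step training batch.
train_dataset = ncf_input_pipeline.create_dataset_from_tf_record_files(
    "/tmp/ncf_raw_data/training_cycle_*/*.tfrecord",
    pre_batch_size=99000,
    batch_size=160000,
    is_training=True)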
def create_dataset_from_data_producer(producer, params):
  """Returns train and eval datasets generated by the data producer."""

  def preprocess_train_input(features, labels):
    """Pre-process the training data.

    This is needed because:
    - The label needs to be extended to be used in the loss fn
    - We need the same inputs for training and eval so adding fake inputs
      for DUPLICATE_MASK in training data.

    Args:
      features: Dictionary of features for training.
      labels: Training labels.

    Returns:
      Processed training features.
    """
    labels = tf.expand_dims(labels, -1)
    fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
    features[rconst.DUPLICATE_MASK] = fake_dup_mask
    features[rconst.TRAIN_LABEL_KEY] = labels
    return features

  train_input_fn = producer.make_input_fn(is_training=True)
  train_input_dataset = train_input_fn(params).map(preprocess_train_input)

  def preprocess_eval_input(features):
    """Pre-process the eval data.

    This is needed because:
    - The label needs to be extended to be used in the loss fn
    - We need the same inputs for training and eval so adding fake inputs
      for VALID_PT_MASK in eval data.

    Args:
      features: Dictionary of features for evaluation.

    Returns:
      Processed evaluation features.
    """
    labels = tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
    labels = tf.expand_dims(labels, -1)
    fake_valid_pt_mask = tf.cast(
        tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
    features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
    features[rconst.TRAIN_LABEL_KEY] = labels
    return features

  eval_input_fn = producer.make_input_fn(is_training=False)
  eval_input_dataset = eval_input_fn(params).map(preprocess_eval_input)

  return train_input_dataset, eval_input_dataset
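For the online path, the producer comes from ncf_common.get_inputs, mirroring what run_ncf does in the hunks below. A hedged sketch, assuming params was built by ncf_common.parse_flags(FLAGS):

from official.recommendation import ncf_common
from official.recommendation import ncf_input_pipeline

# Start the data-producing thread, then build the online datasets from it.
num_users, num_items, num_train_steps, num_eval_steps, producer = (
    ncf_common.get_inputs(params))
producer.start()
train_ds, eval_ds = ncf_input_pipeline.create_dataset_from_data_producer(
    producer, params)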
def create_ncf_input_data(params, producer=None, input_meta_data=None):
  """Creates NCF training/evaluation dataset.

  Args:
    params: Dictionary containing parameters for train/evaluation data.
    producer: Instance of BaseDataConstructor that generates data online. Must
      not be None when params['train_dataset_path'] or
      params['eval_dataset_path'] is not specified.
    input_meta_data: A dictionary of input metadata to be used when reading
      data from tf record files. Must be specified when
      params["train_dataset_path"] is specified.

  Returns:
    (training dataset, evaluation dataset, train steps per epoch,
    eval steps per epoch)
  """
  if params["train_dataset_path"]:
    train_dataset = create_dataset_from_tf_record_files(
        params["train_dataset_path"],
        input_meta_data["train_prebatch_size"],
        params["batch_size"],
        is_training=True)
    eval_dataset = create_dataset_from_tf_record_files(
        params["eval_dataset_path"],
        input_meta_data["eval_prebatch_size"],
        params["eval_batch_size"],
        is_training=False)

    # TODO(b/259377621): Remove number of devices (i.e.
    # params["batches_per_step"]) in input pipeline logic and only use
    # global batch size instead.
    num_train_steps = int(
        np.ceil(input_meta_data["num_train_steps"] /
                params["batches_per_step"]))
    num_eval_steps = (
        input_meta_data["num_eval_steps"] // params["batches_per_step"])
  else:
    assert producer
    # Start retrieving data from producer.
    train_dataset, eval_dataset = create_dataset_from_data_producer(
        producer, params)
    num_train_steps = (
        producer.train_batches_per_epoch // params["batches_per_step"])
    num_eval_steps = (
        producer.eval_batches_per_epoch // params["batches_per_step"])
    assert not producer.train_batches_per_epoch % params["batches_per_step"]
    assert not producer.eval_batches_per_epoch % params["batches_per_step"]

  # It is required that for distributed training, the dataset must call
  # batch(). The parameter of batch() here is the number of replicas involved,
  # such that each replica evenly gets a slice of data.
  # drop_remainder=True makes the batch call return a fixed shape instead of
  # None; this prevents an expensive broadcast during weighted_loss.
  batches_per_step = params["batches_per_step"]
  train_dataset = train_dataset.batch(batches_per_step, drop_remainder=True)
  eval_dataset = eval_dataset.batch(batches_per_step, drop_remainder=True)

  return train_dataset, eval_dataset, num_train_steps, num_eval_steps
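Putting the two paths together, create_ncf_input_data is the single entry point that run_ncf calls in the hunks below. A hedged sketch of the offline branch, assuming params came from ncf_common.parse_flags(FLAGS) and input_meta_data was loaded from the JSON metadata file:

from official.recommendation import ncf_input_pipeline

# Offline branch: no producer thread, metadata supplies prebatch sizes and
# step counts.
train_ds, eval_ds, num_train_steps, num_eval_steps = (
    ncf_input_pipeline.create_ncf_input_data(
        params, producer=None, input_meta_data=input_meta_data))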
@@ -115,7 +115,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     Note: MLPerf like tests are not tuned to hit a specific hr@10 value, but
     we want it recorded.
     """
-    self._run_and_report_benchmark(hr_at_10_min=0.61, hr_at_10_max=0.65)
+    self._run_and_report_benchmark(hr_at_10_min=0.61)
 
   def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
     """Run test and report results.
...
@@ -22,10 +22,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import json
 import os
 
 # pylint: disable=g-bad-import-order
-from absl import app as absl_app
 from absl import flags
 from absl import logging
 import tensorflow as tf
@@ -34,6 +34,7 @@ import tensorflow as tf
 from official.datasets import movielens
 from official.recommendation import constants as rconst
 from official.recommendation import ncf_common
+from official.recommendation import ncf_input_pipeline
 from official.recommendation import neumf_model
 from official.utils.logs import logger
 from official.utils.logs import mlperf_helper
@@ -71,60 +72,6 @@ class MetricLayer(tf.keras.layers.Layer):
     return logits
 
 
-def _get_train_and_eval_data(producer, params):
-  """Returns the datasets for training and evalutating."""
-
-  def preprocess_train_input(features, labels):
-    """Pre-process the training data.
-
-    This is needed because
-    - The label needs to be extended to be used in the loss fn
-    - We need the same inputs for training and eval so adding fake inputs
-      for DUPLICATE_MASK in training data.
-    """
-    labels = tf.expand_dims(labels, -1)
-    fake_dup_mask = tf.zeros_like(features[movielens.USER_COLUMN])
-    features[rconst.DUPLICATE_MASK] = fake_dup_mask
-    features[rconst.TRAIN_LABEL_KEY] = labels
-    if params["distribute_strategy"] or not keras_utils.is_v2_0():
-      return features
-    else:
-      # b/134708104
-      return (features,)
-
-  train_input_fn = producer.make_input_fn(is_training=True)
-  train_input_dataset = train_input_fn(params).map(
-      preprocess_train_input)
-
-  def preprocess_eval_input(features):
-    """Pre-process the eval data.
-
-    This is needed because:
-    - The label needs to be extended to be used in the loss fn
-    - We need the same inputs for training and eval so adding fake inputs
-      for VALID_PT_MASK in eval data.
-    """
-    labels = tf.cast(tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
-    labels = tf.expand_dims(labels, -1)
-
-    fake_valid_pt_mask = tf.cast(
-        tf.zeros_like(features[movielens.USER_COLUMN]), tf.bool)
-    features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
-    features[rconst.TRAIN_LABEL_KEY] = labels
-    if params["distribute_strategy"] or not keras_utils.is_v2_0():
-      return features
-    else:
-      # b/134708104
-      return (features,)
-
-  eval_input_fn = producer.make_input_fn(is_training=False)
-  eval_input_dataset = eval_input_fn(params).map(
-      lambda features: preprocess_eval_input(features))
-
-  return train_input_dataset, eval_input_dataset
-
-
 class IncrementEpochCallback(tf.keras.callbacks.Callback):
   """A callback to increase the requested epoch for the data producer.
@@ -269,6 +216,7 @@ def run_ncf(_):
     FLAGS.eval_batch_size = FLAGS.batch_size
 
   params = ncf_common.parse_flags(FLAGS)
+  model_helpers.apply_clean(flags.FLAGS)
 
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
@@ -291,36 +239,36 @@ def run_ncf(_):
     params["batch_size"] = params["eval_batch_size"]
   batch_size = params["batch_size"]
 
-  num_users, num_items, num_train_steps, num_eval_steps, producer = (
-      ncf_common.get_inputs(params))
+  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
+  callbacks = [time_callback]
+
+  producer, input_meta_data = None, None
+  generate_input_online = params["train_dataset_path"] is None
+
+  if generate_input_online:
+    # Start data producing thread.
+    num_users, num_items, num_train_steps, num_eval_steps, producer = (
+        ncf_common.get_inputs(params))
+    producer.start()
+    per_epoch_callback = IncrementEpochCallback(producer)
+    callbacks.append(per_epoch_callback)
+  else:
+    assert params["eval_dataset_path"] and params["input_meta_data_path"]
+    with tf.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
+      input_meta_data = json.loads(reader.read().decode("utf-8"))
+      num_users = input_meta_data["num_users"]
+      num_items = input_meta_data["num_items"]
 
   params["num_users"], params["num_items"] = num_users, num_items
-  producer.start()
-  model_helpers.apply_clean(flags.FLAGS)
-
-  batches_per_step = params["batches_per_step"]
-  train_input_dataset, eval_input_dataset = _get_train_and_eval_data(producer,
-                                                                     params)
-  # It is required that for distributed training, the dataset must call
-  # batch(). The parameter of batch() here is the number of replicas involed,
-  # such that each replica evenly gets a slice of data.
-  # drop_remainder = True, as we would like batch call to return a fixed shape
-  # vs None, this prevents a expensive broadcast during weighted_loss
-  train_input_dataset = train_input_dataset.batch(batches_per_step,
-                                                  drop_remainder=True)
-  eval_input_dataset = eval_input_dataset.batch(batches_per_step,
-                                                drop_remainder=True)
-
-  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
-  per_epoch_callback = IncrementEpochCallback(producer)
-  callbacks = [per_epoch_callback, time_callback]
+  (train_input_dataset, eval_input_dataset, num_train_steps, num_eval_steps) = \
+      (ncf_input_pipeline.create_ncf_input_data(
+          params, producer, input_meta_data))
+  steps_per_epoch = None if generate_input_online else num_train_steps
 
   if FLAGS.early_stopping:
     early_stopping_callback = CustomEarlyStopping(
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)
 
   with distribution_utils.get_strategy_scope(strategy):
     keras_model = _get_keras_model(params)
     optimizer = tf.keras.optimizers.Adam(
@@ -331,7 +279,7 @@ def run_ncf(_):
     if params["keras_use_ctl"]:
       loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
-          reduction=tf.keras.losses.Reduction.SUM,
+          reduction="sum",
          from_logits=True)
       train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
       eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
@@ -383,8 +331,17 @@ def run_ncf(_):
       time_callback.on_train_begin()
       for epoch in range(FLAGS.train_epochs):
-        per_epoch_callback.on_epoch_begin(epoch)
-        train_input_iterator.initialize()
+        for cb in callbacks:
+          cb.on_epoch_begin(epoch)
+
+        # Because the NCF dataset is sampled with randomness, not repeating
+        # data elements in each epoch has a significant impact on
+        # convergence. The offline-generated TFRecord files therefore
+        # contain data for all epochs, so the dataset does not need to be
+        # re-initialized each epoch when reading from TFRecord files.
+        if generate_input_online:
+          train_input_iterator.initialize()
+
         train_loss = 0
         for step in range(num_train_steps):
           time_callback.on_batch_begin(step+epoch*num_train_steps)
@@ -416,19 +373,19 @@ def run_ncf(_):
         run_eagerly=FLAGS.run_eagerly,
         run_distributed=FLAGS.force_v2_in_keras_compile)
 
-    history = keras_model.fit(train_input_dataset,
-                              epochs=FLAGS.train_epochs,
-                              callbacks=callbacks,
-                              validation_data=eval_input_dataset,
-                              validation_steps=num_eval_steps,
-                              verbose=2)
+    history = keras_model.fit(
+        train_input_dataset,
+        epochs=FLAGS.train_epochs,
+        steps_per_epoch=steps_per_epoch,
+        callbacks=callbacks,
+        validation_data=eval_input_dataset,
+        validation_steps=num_eval_steps,
+        verbose=2)
 
     logging.info("Training done. Start evaluating")
 
     eval_results = keras_model.evaluate(
-        eval_input_dataset,
-        steps=num_eval_steps,
-        verbose=2)
+        eval_input_dataset, steps=num_eval_steps, verbose=2)
 
     logging.info("Keras evaluation is done.")
...
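A toy illustration of the comment in the training-loop hunk above: because the offline-generated TFRecord data already contains every epoch's samples back to back, epochs can be delimited purely by step count, with no iterator re-initialization between them. This sketch only mirrors that idea with a synthetic tf.data pipeline; the numbers are placeholders:

import tensorflow.compat.v2 as tf

num_epochs, steps_per_epoch = 3, 4
# One dataset holding all epochs' worth of (toy) data, back to back.
dataset = tf.data.Dataset.range(num_epochs * steps_per_epoch).batch(1)
iterator = iter(dataset)
for epoch in range(num_epochs):
  for _ in range(steps_per_epoch):
    batch = next(iterator)  # Advances continuously; no per-epoch reset.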