Commit dbdf712e authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

make training input handling in keras fit case the same as CTL case

parent 5b81bb59
......@@ -93,8 +93,6 @@ def _get_train_and_eval_data(producer, params):
train_input_fn = producer.make_input_fn(is_training=True)
train_input_dataset = train_input_fn(params).map(
preprocess_train_input)
if not params["keras_use_ctl"]:
train_input_dataset = train_input_dataset.repeat(FLAGS.train_epochs)
def preprocess_eval_input(features):
"""Pre-process the eval data.
......@@ -286,8 +284,7 @@ def run_ncf(_):
eval_input_dataset = eval_input_dataset.batch(batches_per_step)
time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
per_epoch_callback = IncrementEpochCallback(producer)
callbacks = [per_epoch_callback, time_callback]
callbacks = [time_callback]
if FLAGS.early_stopping:
early_stopping_callback = CustomEarlyStopping(
......@@ -388,7 +385,6 @@ def run_ncf(_):
keras_model.compile(optimizer=optimizer)
history = keras_model.fit(train_input_dataset,
steps_per_epoch=num_train_steps,
epochs=FLAGS.train_epochs,
callbacks=callbacks,
validation_data=eval_input_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment