"...resnet50_tensorflow.git" did not exist on "0abef3ece01a166307ad4a65be64a1e2073db444"
Commit dbdf712e authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

make training input handling in keras fit case the same as CTL case

parent 5b81bb59
......@@ -93,8 +93,6 @@ def _get_train_and_eval_data(producer, params):
train_input_fn = producer.make_input_fn(is_training=True)
train_input_dataset = train_input_fn(params).map(
preprocess_train_input)
if not params["keras_use_ctl"]:
train_input_dataset = train_input_dataset.repeat(FLAGS.train_epochs)
def preprocess_eval_input(features):
"""Pre-process the eval data.
......@@ -286,8 +284,7 @@ def run_ncf(_):
eval_input_dataset = eval_input_dataset.batch(batches_per_step)
time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
per_epoch_callback = IncrementEpochCallback(producer)
callbacks = [per_epoch_callback, time_callback]
callbacks = [time_callback]
if FLAGS.early_stopping:
early_stopping_callback = CustomEarlyStopping(
......@@ -388,7 +385,6 @@ def run_ncf(_):
keras_model.compile(optimizer=optimizer)
history = keras_model.fit(train_input_dataset,
steps_per_epoch=num_train_steps,
epochs=FLAGS.train_epochs,
callbacks=callbacks,
validation_data=eval_input_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment