Unverified Commit a552e76a authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Change to experimental_run_tf_function. (#7344)

parent 480d2630
......@@ -370,9 +370,10 @@ def run_ncf(_):
else:
with distribution_utils.get_strategy_scope(strategy):
keras_model.compile(optimizer=optimizer,
keras_model.compile(
optimizer=optimizer,
run_eagerly=FLAGS.run_eagerly,
run_distributed=FLAGS.force_v2_in_keras_compile)
experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
history = keras_model.fit(
train_input_dataset,
......
......@@ -177,12 +177,13 @@ def run(flags_obj):
optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
model.compile(
loss='categorical_crossentropy',
optimizer=optimizer,
metrics=(['categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly,
run_distributed=flags_obj.force_v2_in_keras_compile)
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
......@@ -201,12 +201,13 @@ def run(flags_obj):
num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
model.compile(loss='sparse_categorical_crossentropy',
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=optimizer,
metrics=(['sparse_categorical_accuracy']
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly,
run_distributed=flags_obj.force_v2_in_keras_compile)
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
......
......@@ -162,13 +162,13 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
with strategy_scope:
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size)
model.compile(optimizer=tf.keras.optimizers.Adam(),
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[
tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
metrics=[tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
tf.keras.metrics.Recall(top_k=5, name='RecallAt5')],
run_eagerly=flags_obj.run_eagerly,
run_distributed=flags_obj.force_v2_in_keras_compile)
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = []
if checkpoint_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment