Unverified Commit a552e76a authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Change to experimental_run_tf_function. (#7344)

parent 480d2630
...@@ -370,9 +370,10 @@ def run_ncf(_): ...@@ -370,9 +370,10 @@ def run_ncf(_):
else: else:
with distribution_utils.get_strategy_scope(strategy): with distribution_utils.get_strategy_scope(strategy):
keras_model.compile(optimizer=optimizer, keras_model.compile(
run_eagerly=FLAGS.run_eagerly, optimizer=optimizer,
run_distributed=FLAGS.force_v2_in_keras_compile) run_eagerly=FLAGS.run_eagerly,
experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
history = keras_model.fit( history = keras_model.fit(
train_input_dataset, train_input_dataset,
......
...@@ -177,12 +177,13 @@ def run(flags_obj): ...@@ -177,12 +177,13 @@ def run(flags_obj):
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES) model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(
optimizer=optimizer, loss='categorical_crossentropy',
metrics=(['categorical_accuracy'] optimizer=optimizer,
if flags_obj.report_accuracy_metrics else None), metrics=(['categorical_accuracy']
run_eagerly=flags_obj.run_eagerly, if flags_obj.report_accuracy_metrics else None),
run_distributed=flags_obj.force_v2_in_keras_compile) run_eagerly=flags_obj.run_eagerly,
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train']) learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
......
...@@ -201,12 +201,13 @@ def run(flags_obj): ...@@ -201,12 +201,13 @@ def run(flags_obj):
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype) dtype=dtype)
model.compile(loss='sparse_categorical_crossentropy', model.compile(
optimizer=optimizer, loss='sparse_categorical_crossentropy',
metrics=(['sparse_categorical_accuracy'] optimizer=optimizer,
if flags_obj.report_accuracy_metrics else None), metrics=(['sparse_categorical_accuracy']
run_eagerly=flags_obj.run_eagerly, if flags_obj.report_accuracy_metrics else None),
run_distributed=flags_obj.force_v2_in_keras_compile) run_eagerly=flags_obj.run_eagerly,
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train']) learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
......
...@@ -162,13 +162,13 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None): ...@@ -162,13 +162,13 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
with strategy_scope: with strategy_scope:
model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size) model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size)
model.compile(optimizer=tf.keras.optimizers.Adam(), model.compile(
loss=tf.keras.losses.CategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam(),
metrics=[ loss=tf.keras.losses.CategoricalCrossentropy(),
tf.keras.metrics.Recall(top_k=1, name='RecallAt1'), metrics=[tf.keras.metrics.Recall(top_k=1, name='RecallAt1'),
tf.keras.metrics.Recall(top_k=5, name='RecallAt5')], tf.keras.metrics.Recall(top_k=5, name='RecallAt5')],
run_eagerly=flags_obj.run_eagerly, run_eagerly=flags_obj.run_eagerly,
run_distributed=flags_obj.force_v2_in_keras_compile) experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = [] callbacks = []
if checkpoint_dir: if checkpoint_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment