Commit 513fdbb2 authored by guptapriya, committed by Toby Boyd

Transformer 2.0: Make metrics layer optional (#7075)

* trying fake merge call

* make metrics optional

* Remove extra print

parent 0f6845ce
@@ -102,6 +102,9 @@ def define_transformer_flags():
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
       help='Whether to enable Tensorboard callback.')
+  flags.DEFINE_boolean(
+      name='enable_metrics_in_training', default=False,
+      help='Whether to enable metrics during training.')
   flags.DEFINE_string(
       name='profile_steps', default=None,
       help='Save profiling data to model dir at given range of steps. The '
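The first hunk registers the new absl flag next to the existing TensorBoard one, defaulting to off. As a minimal sketch of how the flag behaves once parsed (the argv list below is illustrative, not part of this commit):

```python
from absl import flags

flags.DEFINE_boolean(
    name='enable_metrics_in_training', default=False,
    help='Whether to enable metrics during training.')

FLAGS = flags.FLAGS
# Parsing argv populates FLAGS; with default=False, metric computation
# stays off unless the user opts in on the command line.
FLAGS(['transformer_main.py', '--enable_metrics_in_training=true'])
assert FLAGS.enable_metrics_in_training
```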
@@ -42,7 +42,8 @@ def create_model(params, is_train):
     logits = internal_model([inputs, targets], training=is_train)
     vocab_size = params["vocab_size"]
     label_smoothing = params["label_smoothing"]
-    logits = metrics.MetricLayer(vocab_size)([logits, targets])
+    if params["enable_metrics_in_training"]:
+      logits = metrics.MetricLayer(vocab_size)([logits, targets])
     logits = metrics.LossLayer(vocab_size, label_smoothing)([logits, targets])
     logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
     return tf.keras.Model([inputs, targets], logits)
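The second hunk is the core change: metrics.MetricLayer is only inserted into the Keras graph when the new params key is set, so with metrics disabled its ops never exist and add no per-step cost. A self-contained sketch of the same gating pattern (FakeMetricLayer and build are stand-ins invented for illustration, not the real classes):

```python
import tensorflow as tf

# Stand-in for metrics.MetricLayer: an identity layer that, in the real
# code, registers training metrics as a side effect while passing the
# logits through unchanged.
class FakeMetricLayer(tf.keras.layers.Layer):
  def call(self, inputs):
    logits, targets = inputs
    return logits

def build(enable_metrics_in_training):
  logits = tf.keras.layers.Input((None, 8), name="logits_in")
  targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
  out = logits
  if enable_metrics_in_training:  # the same gate this commit introduces
    out = FakeMetricLayer()([out, targets])
  return tf.keras.Model([logits, targets], out)

# Disabling the flag leaves the metric layer (and its ops) out of the graph.
assert len(build(True).layers) == len(build(False).layers) + 1
```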
@@ -119,6 +119,7 @@ class TransformerTask(object):
     params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     if params["dtype"] == tf.float16:
       # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
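The last hunk completes the plumbing: TransformerTask copies the parsed flag into the params dict that it later hands to create_model. Schematically, the call path looks like this (the params values are assumptions for illustration; the real dict carries many more keys from the model defaults):

```python
# Schematic only: create_model needs far more keys than shown here.
params = {
    "vocab_size": 33708,                  # illustrative value
    "label_smoothing": 0.1,               # illustrative value
    "enable_metrics_in_training": False,  # the new key; MetricLayer is skipped
}
model = create_model(params, is_train=True)
```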