Commit 513fdbb2 authored by guptapriya, committed by Toby Boyd

Transformer 2.0: Make metrics layer optional (#7075)

* trying fake merge call

* make metrics optional

* Remove extra print
parent 0f6845ce
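
Computing metrics through a dedicated layer adds per-step overhead during training, so this change gates it behind a new boolean flag that defaults to off. A hypothetical invocation (the --enable_metrics_in_training flag is defined in the diff below; the transformer_main.py entry point and the other flags shown are assumptions about the surrounding script, not part of this commit):

  # Opt back in to training-time metrics (disabled by default):
  python transformer_main.py --enable_metrics_in_training=true \
      --data_dir=/tmp/translate_ende --model_dir=/tmp/transformer_model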
@@ -102,6 +102,9 @@ def define_transformer_flags():
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
       help='Whether to enable Tensorboard callback.')
+  flags.DEFINE_boolean(
+      name='enable_metrics_in_training', default=False,
+      help='Whether to enable metrics during training.')
   flags.DEFINE_string(
       name='profile_steps', default=None,
       help='Save profiling data to model dir at given range of steps. The '
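
The new flag follows the standard absl.flags pattern used throughout this file. A minimal, self-contained sketch of defining and reading such a boolean flag (the main function here is illustrative, not part of the commit):

from absl import app
from absl import flags

flags.DEFINE_boolean(
    name='enable_metrics_in_training', default=False,
    help='Whether to enable metrics during training.')

FLAGS = flags.FLAGS


def main(_):
  # Flag values are populated once absl has parsed argv.
  print('Metrics enabled:', FLAGS.enable_metrics_in_training)


if __name__ == '__main__':
  app.run(main)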
@@ -42,7 +42,8 @@ def create_model(params, is_train):
     logits = internal_model([inputs, targets], training=is_train)
     vocab_size = params["vocab_size"]
     label_smoothing = params["label_smoothing"]
-    logits = metrics.MetricLayer(vocab_size)([logits, targets])
+    if params["enable_metrics_in_training"]:
+      logits = metrics.MetricLayer(vocab_size)([logits, targets])
     logits = metrics.LossLayer(vocab_size, label_smoothing)([logits, targets])
     logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
     return tf.keras.Model([inputs, targets], logits)
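
MetricLayer is defined elsewhere in the model; as a rough sketch of the pattern such a layer follows (an identity layer that records a metric via add_metric and passes logits through unchanged; its actual internals are an assumption here):

import tensorflow as tf


class SketchMetricLayer(tf.keras.layers.Layer):
  """Illustrative stand-in: records token accuracy, returns logits as-is."""

  def __init__(self, vocab_size, **kwargs):
    super(SketchMetricLayer, self).__init__(**kwargs)
    self.vocab_size = vocab_size

  def call(self, inputs):
    logits, targets = inputs
    # Token-level accuracy, ignoring padding positions (id 0).
    predictions = tf.cast(tf.argmax(logits, axis=-1), targets.dtype)
    weights = tf.cast(tf.not_equal(targets, 0), tf.float32)
    correct = tf.cast(tf.equal(predictions, targets), tf.float32) * weights
    accuracy = tf.reduce_sum(correct) / tf.maximum(tf.reduce_sum(weights), 1.0)
    self.add_metric(accuracy, name='accuracy', aggregation='mean')
    return logits

Because the layer is pass-through, wrapping it in the conditional removes only the metric bookkeeping (the extra argmax over the vocabulary dimension and the metric updates); the loss path through LossLayer is unchanged.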
@@ -119,6 +119,7 @@ class TransformerTask(object):
     params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
     if params["dtype"] == tf.float16:
       # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
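
The params dict is the only conduit from parsed flags into create_model, so one assignment is all the wiring the new flag needs. An abbreviated, runnable sketch of that hand-off (FakeFlags stands in for the parsed flags_obj; values are illustrative):

# Stand-in for the absl flags object after parsing.
class FakeFlags(object):
  enable_metrics_in_training = False  # the new flag's default


flags_obj = FakeFlags()

params = {}  # other entries (batch_size, dtype, ...) elided
params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training

# create_model(params, is_train=True) would now skip MetricLayer.
assert params["enable_metrics_in_training"] is False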