Commit 87292aa4 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal Change

PiperOrigin-RevId: 314674292
parent 49b58967
...@@ -191,7 +191,8 @@ def _get_keras_model(params): ...@@ -191,7 +191,8 @@ def _get_keras_model(params):
# Custom training loop calculates loss and metric as a part of # Custom training loop calculates loss and metric as a part of
# training/evaluation step function. # training/evaluation step function.
if not params["keras_use_ctl"]: if not params["keras_use_ctl"]:
softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input]) softmax_logits = MetricLayer(
params["match_mlperf"])([softmax_logits, dup_mask_input])
# TODO(b/134744680): Use model.add_loss() instead once the API is well # TODO(b/134744680): Use model.add_loss() instead once the API is well
# supported. # supported.
softmax_logits = LossLayer(batch_size)( softmax_logits = LossLayer(batch_size)(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment