"...tests/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "7013e5e2d779d68d2a9cdafa96f8fdb2645618a5"
Commit 13cc0f70 authored by guptapriya's avatar guptapriya
Browse files

Use add_loss in transformer model

parent ab8febd4
...@@ -44,9 +44,12 @@ def create_model(params, is_train): ...@@ -44,9 +44,12 @@ def create_model(params, is_train):
label_smoothing = params["label_smoothing"] label_smoothing = params["label_smoothing"]
if params["enable_metrics_in_training"]: if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets]) logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = metrics.LossLayer(vocab_size, label_smoothing)([logits, targets])
logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits) logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
return tf.keras.Model([inputs, targets], logits) model = tf.keras.Model([inputs, targets], logits)
loss = metrics.transformer_loss(
logits, targets, label_smoothing, vocab_size)
model.add_loss(loss)
return model
else: else:
inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs") inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment