Commit dfcca061 authored by Reed's avatar Reed
Browse files

Use tf.float32 instead of "float32"

parent acdf24c5
...@@ -50,7 +50,7 @@ def create_model(params, is_train): ...@@ -50,7 +50,7 @@ def create_model(params, is_train):
if params["enable_metrics_in_training"]: if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets]) logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(lambda x: x, name="logits", logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
dtype="float32")(logits) dtype=tf.float32)(logits)
model = tf.keras.Model([inputs, targets], logits) model = tf.keras.Model([inputs, targets], logits)
# TODO(reedwm): Can we do this loss in float16 instead of float32? # TODO(reedwm): Can we do this loss in float16 instead of float32?
loss = metrics.transformer_loss( loss = metrics.transformer_loss(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment