Commit dfcca061 authored by Reed's avatar Reed
Browse files

Use tf.float32 instead of "float32"

parent acdf24c5
......@@ -50,7 +50,7 @@ def create_model(params, is_train):
if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
dtype="float32")(logits)
dtype=tf.float32)(logits)
model = tf.keras.Model([inputs, targets], logits)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss = metrics.transformer_loss(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment