Commit 02cc984e authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 278754959
parent 0203278a
@@ -57,7 +57,8 @@ def define_flags():
       intra_op=False,
       synthetic_data=False,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
+      loss_scale=True,
       enable_xla=True,
       force_v2_in_keras_compile=True)
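Note (not part of the commit): this hunk turns on the `--dtype` and `--loss_scale` flags for the script, presumably via `flags_core.define_performance` given the keyword names. A rough sketch of what those two flags amount to in `absl` terms; the defaults and help strings below are paraphrased assumptions, not the exact definitions in `official/utils/flags`:

```python
# Approximate equivalents of the flags enabled by dtype=True and
# loss_scale=True; names/defaults are illustrative assumptions.
from absl import flags

flags.DEFINE_enum(
    name='dtype', short_name='dt', default='fp32',
    enum_values=['fp16', 'fp32'],
    help='The TensorFlow datatype used for calculations.')
flags.DEFINE_string(
    name='loss_scale', short_name='ls', default=None,
    help='Loss scale for fp16 training: a positive float, or "dynamic" '
         'to adjust the scale automatically.')
```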
@@ -156,7 +157,8 @@ def build_model(vocab_size,
           return_sequences=True,
           stateful=stateful,
           recurrent_initializer='glorot_uniform'),
-      tf.keras.layers.Dense(vocab_size, activation='softmax')])
+      tf.keras.layers.Dense(vocab_size),
+      tf.keras.layers.Softmax(dtype=tf.float32)])


 def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
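Note (not part of the commit): the reason for splitting `Dense(..., activation='softmax')` into a linear `Dense` plus a separate `Softmax(dtype=tf.float32)` is numerical stability under mixed precision. With a `mixed_float16` policy, layers compute in float16 by default, and a float16 softmax feeding cross-entropy loses precision; pinning the final activation to float32 keeps the probabilities and the loss stable while the rest of the model still runs in fp16. A minimal sketch of the pattern (the vocabulary and embedding sizes are illustrative, not from this file):

```python
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

vocab_size = 65  # illustrative; the real script derives this from the text

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, 256, batch_input_shape=[64, None]),
    # Computed in float16 under the mixed_float16 policy.
    tf.keras.layers.Dense(vocab_size),
    # Forced to float32 so the output probabilities, and the cross-entropy
    # computed from them, do not lose precision.
    tf.keras.layers.Softmax(dtype=tf.float32),
])
```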
@@ -178,6 +180,9 @@ def train_model(flags_obj, dataset, vocab_size, strategy, checkpoint_dir=None):
   with strategy_scope:
     model = build_model(vocab_size=vocab_size, batch_size=flags_obj.batch_size,
                         use_cudnn=flags_obj.cudnn)
+    # When keras_use_ctl is False, Model.fit() automatically applies
+    # loss scaling so we don't need to create a LossScaleOptimizer.
     model.compile(
         optimizer=tf.keras.optimizers.Adam(),
         loss=tf.keras.losses.CategoricalCrossentropy(),
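Note (not part of the commit): the new comment is worth unpacking. When a `mixed_float16` policy carrying a loss scale is active, `Model.fit()` wraps the optimizer so gradient loss scaling happens transparently; only the custom-training-loop path (`keras_use_ctl=True`) would have to do it by hand. A hedged sketch of that manual version, using the TF 2.x experimental API:

```python
import tensorflow as tf

# What a custom training loop would need to set up explicitly; Model.fit()
# does the equivalent automatically under a mixed_float16 policy.
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.Adam(), loss_scale='dynamic')

@tf.function
def train_step(model, loss_fn, x, y):
  with tf.GradientTape() as tape:
    predictions = model(x, training=True)
    loss = loss_fn(y, predictions)
    # Scale the loss up before differentiating so small float16
    # gradients do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the update so step sizes are unaffected.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss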
@@ -260,6 +265,13 @@ def run(flags_obj):
       'https://storage.googleapis.com/download.tensorflow.org/data/'
       'shakespeare.txt')

+  if flags_obj.dtype == 'fp16':
+    policy = tf.keras.mixed_precision.experimental.Policy(
+        'mixed_float16',
+        loss_scale=flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16='dynamic'))
+    tf.keras.mixed_precision.experimental.set_policy(policy)
+
   keras_utils.set_session_config(
       enable_eager=flags_obj.enable_eager,
       enable_xla=flags_obj.enable_xla)
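Note (not part of the commit): taken together with the flag changes above, this hunk is the whole fp16 path: resolve `--loss_scale` (defaulting to 'dynamic'), build a `mixed_float16` policy carrying that loss scale, and install it globally before any layers are constructed. Stripped of the flag plumbing, the sequence reduces to roughly the following (literals substituted for the flag values):

```python
import tensorflow as tf

# 'dynamic' mirrors the default_for_fp16 passed to get_loss_scale above.
policy = tf.keras.mixed_precision.experimental.Policy(
    'mixed_float16', loss_scale='dynamic')
tf.keras.mixed_precision.experimental.set_policy(policy)

# From here on, new layers compute in float16 but keep float32 variables,
# unless a layer overrides its dtype (as the final Softmax does above).
print(tf.keras.mixed_precision.experimental.global_policy().name)
# -> 'mixed_float16'
```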