Commit a2a1b66f authored by A. Unique TensorFlower

Move distribution strategy init to beginning of run

PiperOrigin-RevId: 281641189
parent 92384c60
@@ -160,22 +160,6 @@ class TransformerTask(object):
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
-    if params["dtype"] == tf.float16:
-      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
-      # like this. What if multiple instances of TransformerTask are created?
-      # We should have a better way in the tf.keras.mixed_precision API of doing
-      # this.
-      loss_scale = flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16="dynamic")
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_float16", loss_scale=loss_scale)
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    if params["dtype"] == tf.bfloat16:
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_bfloat16")
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
         num_gpus=num_gpus,
@@ -193,6 +177,22 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")
+    if params["dtype"] == tf.float16:
+      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
+      # like this. What if multiple instances of TransformerTask are created?
+      # We should have a better way in the tf.keras.mixed_precision API of doing
+      # this.
+      loss_scale = flags_core.get_loss_scale(flags_obj,
+                                             default_for_fp16="dynamic")
+      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+          "mixed_float16", loss_scale=loss_scale)
+      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+    if params["dtype"] == tf.bfloat16:
+      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
+          "mixed_bfloat16")
+      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
   @property
   def use_tpu(self):
     if self.distribution_strategy:
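The effect of the change, in short: the distribution strategy is now created before the global Keras mixed-precision policy is configured. Below is a minimal standalone sketch of that ordering. It reuses the calls shown in the diff, but hard-codes the dtype and substitutes tf.distribute.MirroredStrategy for the repository's distribution_utils.get_distribution_strategy helper, so those details are illustrative assumptions rather than the actual TransformerTask code.

import tensorflow as tf

# 1) Initialize the distribution strategy at the beginning of the run
#    (stand-in for distribution_utils.get_distribution_strategy).
strategy = tf.distribute.MirroredStrategy()

# 2) Only afterwards set the global mixed-precision policy, mirroring the
#    block this commit moves below the strategy initialization.
dtype = tf.float16  # illustrative; normally flags_core.get_tf_dtype(flags_obj)
if dtype == tf.float16:
  policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
      "mixed_float16", loss_scale="dynamic")
  tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
  policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
      "mixed_bfloat16")
  tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

with strategy.scope():
  pass  # model construction would happen here, under the strategy scope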