"docs/source/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "5313c2cb8b3bcf7f71c0e6024c59d120efe94d88"
Commit 51cb03b0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Deprecate the graph rewrite path for fp16. It is not a TF2 API and has no remaining usage.

PiperOrigin-RevId: 410629444
parent 743c28aa
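
For context, the TF2 replacement for the graph rewrite is Keras mixed precision: set a global policy so layers compute in float16, then wrap the optimizer for loss scaling. A minimal sketch using only standard tf.keras APIs (illustrative, not part of this commit):

    import tensorflow as tf

    # Run layer computations in float16 while keeping variables in float32.
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # Dynamic loss scaling keeps small float16 gradients from underflowing.
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)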
@@ -14,14 +14,19 @@
 """Functions and classes related to training performance."""

+from absl import logging
 import tensorflow as tf


 def configure_optimizer(optimizer,
                         use_float16=False,
-                        use_graph_rewrite=False,
-                        loss_scale=None):
+                        loss_scale=None,
+                        use_graph_rewrite=None):
   """Configures optimizer object with performance options."""
+  if use_graph_rewrite is not None:
+    logging.warning('`use_graph_rewrite` is deprecated inside '
+                    '`configure_optimizer`. Please remove the usage.')
+  del use_graph_rewrite
   if use_float16:
     if loss_scale in (None, 'dynamic'):
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
@@ -29,13 +34,6 @@ def configure_optimizer(optimizer,
       # loss_scale is a number. We interpret that as a fixed loss scale.
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
           optimizer, dynamic=False, initial_scale=loss_scale)
-  if use_graph_rewrite:
-    # Note: the model dtype must be 'float32', which will ensure
-    # tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
-    # double up.
-    optimizer = (
-        tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
-            optimizer))
   return optimizer
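
With the new signature, call sites pass only use_float16 and, optionally, loss_scale. A short usage sketch (the `official.modeling.performance` module path is an assumption inferred from the call sites below):

    import tensorflow as tf
    from official.modeling import performance  # assumed module path

    # loss_scale=None (or 'dynamic') selects dynamic loss scaling.
    opt_dynamic = performance.configure_optimizer(
        tf.keras.optimizers.Adam(learning_rate=1e-4), use_float16=True)

    # A numeric loss_scale is interpreted as a fixed loss scale.
    opt_fixed = performance.configure_optimizer(
        tf.keras.optimizers.Adam(learning_rate=1e-4),
        use_float16=True,
        loss_scale=128)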
@@ -121,9 +121,5 @@ def use_float16():
   return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16


-def use_graph_rewrite():
-  return flags.FLAGS.fp16_implementation == 'graph_rewrite'
-
-
 def get_loss_scale():
   return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
@@ -150,8 +150,7 @@ def run_bert_classifier(strategy,
         FLAGS.optimizer_type)
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
-        use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_float16=common_flags.use_float16())
     return classifier_model, core_model

 # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
@@ -125,8 +125,7 @@ def run_customized_training(strategy,
                                       end_lr, optimizer_type)
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
-        use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_float16=common_flags.use_float16())
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
@@ -252,8 +252,7 @@ def train_squad(strategy,
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
-        use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_float16=common_flags.use_float16())
     return squad_model, core_model

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
@@ -440,7 +440,6 @@ class TransformerTask(object):
       opt = performance.configure_optimizer(
           opt,
           use_float16=params["dtype"] == tf.float16,
-          use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
           loss_scale=flags_core.get_loss_scale(
               self.flags_obj, default_for_fp16="dynamic"))
@@ -401,7 +401,6 @@ class YoloTask(base_task.Task):
     use_float16 = runtime_config.mixed_precision_dtype == 'float16'
     optimizer = performance.configure_optimizer(
         optimizer,
-        use_graph_rewrite=False,
         use_float16=use_float16,
         loss_scale=runtime_config.loss_scale)
@@ -72,14 +72,9 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
     # Make sure iterations variable is created inside scope.
     self.global_step = self.optimizer.iterations

-    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
-    if use_graph_rewrite and not flags_obj.use_tf_function:
-      raise ValueError('--fp16_implementation=graph_rewrite requires '
-                       '--use_tf_function to be true')
     self.optimizer = performance.configure_optimizer(
         self.optimizer,
         use_float16=self.dtype == tf.float16,
-        use_graph_rewrite=use_graph_rewrite,
         loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

     self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)