"...git@developer.sourcefind.cn:OpenDAS/EasyR1.git" did not exist on "f92481f0d8a32874630ff6a91ed5ad84fae1d798"
Commit fb35d6be authored by Hongkun Yu, committed by A. Unique TensorFlower

Creates modeling/performance.py to include mixed precision related stuff

PiperOrigin-RevId: 297002741
parent 02af9bb5
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
import tensorflow as tf


def configure_optimizer(optimizer,
                        use_float16=False,
                        use_graph_rewrite=False,
                        loss_scale="dynamic"):
  """Configures optimizer object with performance options."""
  if use_float16:
    # Wraps optimizer with a LossScaleOptimizer. This is done automatically
    # in compile() with the "mixed_float16" policy, but since we do not call
    # compile(), we must wrap the optimizer manually.
    optimizer = (
        tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, loss_scale=loss_scale))
  if use_graph_rewrite:
    # Note: the model dtype must be 'float32', which will ensure
    # tf.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
    # up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer


def set_mixed_precision_policy(dtype, loss_scale=None):
  """Sets the global mixed precision policy."""
  if dtype == tf.float16:
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.float32:
    tf.keras.mixed_precision.experimental.set_policy('float32')
  else:
    raise ValueError("Unexpected dtype: %s" % dtype)
@@ -69,6 +69,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
            sentence_labels):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
+    lm_output = tf.cast(lm_output, tf.float32)
+    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
...
@@ -88,9 +88,17 @@ def define_common_bert_flags():
   )


+def dtype():
+  return flags_core.get_tf_dtype(flags.FLAGS)
+
+
 def use_float16():
   return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16


+def use_graph_rewrite():
+  return flags.FLAGS.fp16_implementation == 'graph_rewrite'
+
+
 def get_loss_scale():
   return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
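
For context, the new flag helpers are consumed together roughly like this. This is only a sketch: `model` and `optimizer` are placeholders for the trainer's own objects, and it mirrors the run_classifier.py and run_pretraining.py changes shown below.

# Sketch only: `model` and `optimizer` stand in for the trainer's objects.
performance.set_mixed_precision_policy(common_flags.dtype())
model.optimizer = performance.configure_optimizer(
    optimizer,
    use_float16=common_flags.use_float16(),
    use_graph_rewrite=common_flags.use_graph_rewrite())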
@@ -27,6 +27,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
@@ -126,16 +127,12 @@ def run_bert_classifier(strategy,
             max_seq_length,
             hub_module_url=FLAGS.hub_module_url,
             hub_module_trainable=FLAGS.hub_module_trainable))
-    classifier_model.optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          classifier_model.optimizer)
+    optimizer = optimization.create_optimizer(
+        initial_lr, steps_per_epoch * epochs, warmup_steps)
+    classifier_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

   # During distributed training, loss used for gradient computation is
@@ -302,6 +299,7 @@ def run_bert(strategy,
     raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_config_v2(FLAGS.enable_xla)
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs
   train_data_size = input_meta_data['train_data_size']
...
@@ -23,6 +23,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
@@ -102,16 +103,12 @@ def run_customized_training(strategy,
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
-    pretrain_model.optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          pretrain_model.optimizer)
+    optimizer = optimization.create_optimizer(
+        initial_lr, steps_per_epoch * epochs, warmup_steps)
+    pretrain_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
@@ -141,6 +138,8 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')

+  performance.set_mixed_precision_policy(common_flags.dtype())
+
   return run_customized_training(
       strategy,
       bert_config,
...
@@ -17,7 +17,6 @@
 See README for description of setting the training schedule and evaluating the
 BLEU score.
 """
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -30,19 +29,19 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

-# pylint: disable=g-bad-import-order
+from official.modeling import performance
 from official.nlp.transformer import compute_bleu
-from official.nlp.transformer.utils import tokenizer
 from official.nlp.transformer import data_pipeline
 from official.nlp.transformer import metrics
 from official.nlp.transformer import misc
 from official.nlp.transformer import optimizer
 from official.nlp.transformer import transformer
 from official.nlp.transformer import translate
+from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
-from official.utils.misc import keras_utils
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 INF = int(1e9)
 BLEU_DIR = "bleu"
@@ -180,21 +179,9 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")

-    if params["dtype"] == tf.float16:
-      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
-      # like this. What if multiple instances of TransformerTask are created?
-      # We should have a better way in the tf.keras.mixed_precision API of doing
-      # this.
-      loss_scale = flags_core.get_loss_scale(
-          flags_obj, default_for_fp16="dynamic")
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_float16", loss_scale=loss_scale)
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    elif params["dtype"] == tf.bfloat16:
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_bfloat16")
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+    performance.set_mixed_precision_policy(
+        params["dtype"],
+        flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))

   @property
   def use_tpu(self):
@@ -434,8 +421,6 @@ class TransformerTask(object):
   def _create_optimizer(self):
     """Creates optimizer."""
     params = self.params
-    # TODO(b/139414679): Explore the difference between using
-    # LearningRateSchedule and callback for GPU runs, and try to merge them.
     lr_schedule = optimizer.LearningRateSchedule(
         params["learning_rate"], params["hidden_size"],
         params["learning_rate_warmup_steps"])
@@ -445,18 +430,12 @@ class TransformerTask(object):
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])

-    if params["dtype"] == tf.float16:
-      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt,
-          loss_scale=flags_core.get_loss_scale(
-              self.flags_obj, default_for_fp16="dynamic"))
-    if self.flags_obj.fp16_implementation == "graph_rewrite":
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
+    opt = performance.configure_optimizer(
+        opt,
+        use_float16=params["dtype"] == tf.float16,
+        use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
+        loss_scale=flags_core.get_loss_scale(
+            self.flags_obj, default_for_fp16="dynamic"))

     return opt
...
@@ -43,14 +43,15 @@ def get_tf_dtype(flags_obj):

 def get_loss_scale(flags_obj, default_for_fp16):
+  dtype = get_tf_dtype(flags_obj)
   if flags_obj.loss_scale == "dynamic":
     return flags_obj.loss_scale
   elif flags_obj.loss_scale is not None:
     return float(flags_obj.loss_scale)
-  elif flags_obj.dtype == "fp32":
+  elif dtype == tf.float32 or dtype == tf.bfloat16:
     return 1  # No loss scaling is needed for fp32
   else:
-    assert flags_obj.dtype == "fp16"
+    assert dtype == tf.float16
     return default_for_fp16
...
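
The net effect of the dtype-based dispatch, with illustrative flag spellings (assumed, not taken from the commit):

# Behavior sketch for get_loss_scale(flags_obj, default_for_fp16='dynamic'):
#   --loss_scale=dynamic                   -> 'dynamic'
#   --loss_scale=512                       -> 512.0
#   dtype is fp32 or bfloat16              -> 1   (no loss scaling needed)
#   dtype is fp16 with --loss_scale unset  -> 'dynamic' (the default_for_fp16)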
@@ -23,6 +23,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

+from official.modeling import performance
 from official.staging.training import controller
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
@@ -110,16 +111,7 @@ def run(flags_obj):
   keras_utils.set_session_config(
       enable_eager=flags_obj.enable_eager,
       enable_xla=flags_obj.enable_xla)
-
-  dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == tf.float16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_float16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-  elif dtype == tf.bfloat16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

   # This only affects GPU.
   common.set_cudnn_batchnorm_mode()
...
@@ -28,6 +28,7 @@ import tensorflow as tf
 import tensorflow_model_optimization as tfmot

 from official.benchmark.models import trivial_model
+from official.modeling import performance
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -65,17 +66,9 @@ def run(flags_obj):
   common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == tf.float16:
-    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_float16', loss_scale=loss_scale)
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    if not keras_utils.is_v2_0():
-      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
-  elif dtype == tf.bfloat16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+  performance.set_mixed_precision_policy(
+      flags_core.get_tf_dtype(flags_obj),
+      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

   data_format = flags_obj.data_format
   if data_format is None:
...
@@ -20,6 +20,7 @@ from __future__ import print_function

 import tensorflow.compat.v2 as tf

+from official.modeling import performance
 from official.staging.training import standard_runnable
 from official.staging.training import utils
 from official.utils.flags import core as flags_core
@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
     # Make sure iterations variable is created inside scope.
     self.global_step = self.optimizer.iterations

-    if self.dtype == tf.float16:
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              self.optimizer, loss_scale))
-    elif flags_obj.fp16_implementation == 'graph_rewrite':
-      # `dtype` is still float32 in this case. We built the graph in float32
-      # and let the graph rewrite change parts of it float16.
-      if not flags_obj.use_tf_function:
-        raise ValueError('--fp16_implementation=graph_rewrite requires '
-                         '--use_tf_function to be true')
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.train.experimental.enable_mixed_precision_graph_rewrite(
-              self.optimizer, loss_scale))
+    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
+    if use_graph_rewrite and not flags_obj.use_tf_function:
+      raise ValueError('--fp16_implementation=graph_rewrite requires '
+                       '--use_tf_function to be true')
+    self.optimizer = performance.configure_optimizer(
+        self.optimizer,
+        use_float16=self.dtype == tf.float16,
+        use_graph_rewrite=use_graph_rewrite,
+        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

     self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
...