Commit e16594d1 authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 281473612
parent 5c15ce77
...@@ -34,13 +34,13 @@ from official.nlp.bert import run_squad ...@@ -34,13 +34,13 @@ from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
# pylint: disable=line-too-long # pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt' PRETRAINED_CHECKPOINT_PATH = '/placer/prod/home/tensorflow-performance-data/datasets/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record' SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json' SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt' SQUAD_VOCAB_FILE = '/placer/prod/home/tensorflow-performance-data/datasets/bert/keras_bert/uncased_L-24_H-1024_A-16/vocab.txt'
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data' SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data' SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config' MODEL_CONFIG_FILE_PATH = '/placer/prod/home/tensorflow-performance-data/datasets/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long # pylint: enable=line-too-long
TMP_DIR = os.getenv('TMPDIR') TMP_DIR = os.getenv('TMPDIR')
...@@ -340,6 +340,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase): ...@@ -340,6 +340,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.num_train_epochs = 2 FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 1 FLAGS.steps_per_loop = 1
FLAGS.use_keras_bert_for_squad = True
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
use_ds=True, use_ds=True,
......
...@@ -125,11 +125,13 @@ class Transformer(tf.keras.layers.Layer): ...@@ -125,11 +125,13 @@ class Transformer(tf.keras.layers.Layer):
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._attention_layer_norm = ( self._attention_layer_norm = (
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12, name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum( self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size, output_shape=self._intermediate_size,
activation=self._intermediate_activation, activation=None,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -137,8 +139,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -137,8 +139,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
dtype=tf.float32, # This layer is always float32 for numeric stability.
name="intermediate") name="intermediate")
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size, output_shape=hidden_size,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -215,7 +218,14 @@ class Transformer(tf.keras.layers.Layer): ...@@ -215,7 +218,14 @@ class Transformer(tf.keras.layers.Layer):
attention_output) attention_output)
intermediate_output = self._intermediate_dense(attention_output) intermediate_output = self._intermediate_dense(attention_output)
if self.dtype == tf.float16: if self.dtype == tf.float16:
# Casts to float32 so that activation is done in float32.
intermediate_output = tf.cast(intermediate_output, tf.float32)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
intermediate_output = tf.cast(intermediate_output, tf.float16) intermediate_output = tf.cast(intermediate_output, tf.float16)
else:
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output) layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output) layer_output = self._output_dropout(layer_output)
# Use float32 in keras layer norm for numeric stability # Use float32 in keras layer norm for numeric stability
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment