Unverified Commit 1f3247f4 authored by Ayushman Kumar, committed by GitHub

Merge pull request #6 from tensorflow/master

Updated
parents 370a4c8d 0265f59c
@@ -80,6 +80,8 @@ installable Official Models package. This is being tracked in
### Natural Language Processing
* [albert](nlp/albert): A Lite BERT for Self-supervised Learning of Language
Representations.
*   [bert](nlp/bert): A powerful pre-trained language representation model:
    BERT, which stands for Bidirectional Encoder Representations from
    Transformers.
......
@@ -212,39 +212,6 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2_gpu_mrpc(self):
"""Test BERT model performance with 2 GPUs."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 8
FLAGS.eval_batch_size = 8
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_4_gpu_mrpc(self):
"""Test BERT model performance with 4 GPUs."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 16
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
  def benchmark_8_gpu_mrpc(self):
    """Test BERT model performance with 8 GPUs."""
......
@@ -24,12 +24,12 @@ import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
@@ -70,18 +70,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      return json.loads(reader.read().decode('utf-8'))
def _read_predictions_dataset_from_file(self):
"""Reads the predictions dataset from a file."""
with tf.io.gfile.GFile(SQUAD_PREDICT_FILE, 'r') as reader:
dataset_json = json.load(reader)
return dataset_json['data']
def _read_predictions_from_file(self):
"""Reads the predictions from a file."""
predictions_file = os.path.join(FLAGS.model_dir, 'predictions.json')
with tf.io.gfile.GFile(predictions_file, 'r') as reader:
return json.load(reader)
  def _get_distribution_strategy(self, ds_type='mirrored'):
    """Gets the distribution strategy.
@@ -135,12 +123,10 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
    input_meta_data = self._read_input_meta_data_from_file()
    strategy = self._get_distribution_strategy(ds_type)
    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
    dataset = self._read_predictions_dataset_from_file()
    predictions = self._read_predictions_from_file()
    eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
    if input_meta_data.get('version_2_with_negative', False):
      logging.error('In memory evaluation result for SQuAD v2 is not accurate')
    eval_metrics = run_squad.eval_squad(strategy=strategy,
                                        input_meta_data=input_meta_data)
    # Use F1 score as reported evaluation metric.
    self.eval_metrics = eval_metrics['f1']
......
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from official.benchmark.models import resnet_cifar_model
@@ -100,7 +101,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
    if lr != self.prev_lr:
      self.model.optimizer.learning_rate = lr  # lr should be a float here
      self.prev_lr = lr
      tf.compat.v1.logging.debug(
      logging.debug(
          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
          'change learning rate to %s.', self.epochs, batch, lr)
@@ -137,8 +138,8 @@ def run(flags_obj):
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
                   else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)
  strategy = distribution_utils.get_distribution_strategy(
@@ -280,6 +281,6 @@ def main(_):
if __name__ == '__main__':
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  logging.set_verbosity(logging.INFO)
  define_cifar_flags()
  app.run(main)
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -22,6 +21,7 @@ import os
import time
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
@@ -51,7 +51,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
  def _setup(self):
    """Sets up and resets flags before each test."""
    assert tf.version.VERSION.startswith('2.')
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    logging.set_verbosity(logging.INFO)
    if NCFKerasBenchmarkBase.local_flags is None:
      ncf_common.define_ncf_flags()
      # Loads flags to get defaults to then override. List cannot be empty.
......
@@ -269,6 +269,7 @@ python run_classifier.py \
  --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
  --train_batch_size=32 \
  --eval_batch_size=32 \
--steps_per_loop=1000 \
  --learning_rate=2e-5 \
  --num_train_epochs=3 \
  --model_dir=${MODEL_DIR} \
@@ -276,6 +277,10 @@ python run_classifier.py \
  --tpu=grpc://${TPU_IP_ADDRESS}:8470
```
Note that we specify `steps_per_loop=1000` for TPU because running a loop of
training steps inside a `tf.function` can significantly increase TPU
utilization; callbacks will not be called inside that loop.
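To illustrate why this helps, here is a minimal sketch of the host-loop pattern the flag controls. The names (`_run_steps`, `train_step`, `steps_per_loop`) are illustrative only, not the actual `run_classifier.py` internals: callbacks can only fire between `tf.function` calls on the host, so a larger `steps_per_loop` means fewer host round trips and fewer callback invocations.

```python
import tensorflow as tf


@tf.function
def _run_steps(iterator, steps, train_step):
  """Runs `steps` training steps on the device without returning to the host."""
  for _ in tf.range(steps):  # AutoGraph compiles this into a device-side loop.
    train_step(next(iterator))

# Host-side outer loop (sketch): callbacks would only run here, once per
# `steps_per_loop` chunk of steps, not after every individual batch.
#
#   while current_step < total_steps:
#     _run_steps(iterator, tf.constant(steps_per_loop), train_step)
#     current_step += steps_per_loop
```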
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
......
@@ -57,7 +57,7 @@ def define_common_bert_flags():
  flags.DEFINE_integer('num_train_epochs', 3,
                       'Total number of training epochs to perform.')
  flags.DEFINE_integer(
      'steps_per_loop', 200,
      'steps_per_loop', 1,
      'Number of steps per graph-mode loop. Only training step '
      'happens inside the loop. Callbacks will not be called '
      'inside.')
......
@@ -415,7 +415,7 @@ def run_customized_training_loop(
        # Runs several steps in the host while loop.
        steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
        if tf.test.is_built_with_cuda():
        if tf.config.list_physical_devices('GPU'):
          # TODO(zongweiz): merge with train_steps once tf.while_loop
          # GPU performance bugs are fixed.
          for _ in range(steps):
......
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from absl import logging
from absl.testing import parameterized
from absl.testing.absltest import mock
import numpy as np
@@ -27,7 +28,7 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling import model_training_utils
from official.nlp.bert import model_training_utils
def eager_strategy_combinations():
@@ -125,7 +126,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          tf.compat.v1.logging.error(event)
          logging.error(event)
          yield event.summary
......
@@ -25,8 +25,6 @@ from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
@@ -34,6 +32,7 @@ from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
@@ -156,6 +155,7 @@ def run_bert_classifier(strategy,
      init_checkpoint,
      epochs,
      steps_per_epoch,
steps_per_loop,
      eval_steps,
      custom_callbacks=custom_callbacks)
@@ -189,6 +189,7 @@ def run_keras_compile_fit(model_dir,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
steps_per_loop,
                          eval_steps,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""
@@ -203,7 +204,11 @@ def run_keras_compile_fit(model_dir,
    checkpoint = tf.train.Checkpoint(model=sub_model)
    checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
  bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
  bert_model.compile(
optimizer=optimizer,
loss=loss_fn,
metrics=[metric_fn()],
experimental_steps_per_execution=steps_per_loop)
  summary_dir = os.path.join(model_dir, 'summaries')
  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
......
@@ -22,14 +22,13 @@ from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
......
@@ -19,9 +19,13 @@ from __future__ import division
from __future__ import print_function
import json
import os
import tempfile
import time
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.bert import configs as bert_configs
@@ -52,12 +56,22 @@ def train_squad(strategy,
def predict_squad(strategy, input_meta_data):
  """Makes predictions for a squad dataset."""
  """Makes predictions for the squad dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
                                 bert_config, squad_lib_wp)
  run_squad_helper.predict_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics
def export_squad(model_export_path, input_meta_data):
@@ -93,7 +107,8 @@ def main(_):
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
if 'train' in FLAGS.mode:
    if FLAGS.log_steps:
      custom_callbacks = [keras_utils.TimeHistory(
          batch_size=FLAGS.train_batch_size,
@@ -109,8 +124,25 @@ def main(_):
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
    )
  if FLAGS.mode in ('predict', 'train_and_predict'):
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['final_f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
if (not strategy) or strategy.extended.should_save_summary:
summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
else:
summary_dir = tempfile.mkdtemp()
summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'eval'))
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()
# Wait for some time, for the depending mldash/tensorboard jobs to finish
# exporting the final F1-score.
time.sleep(60)
if __name__ == '__main__':
......
@@ -18,18 +18,20 @@ from __future__ import division
from __future__ import print_function
import collections
import json
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import model_training_utils
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
@@ -37,11 +39,15 @@ from official.utils.misc import keras_utils
def define_common_squad_flags():
  """Defines common flags used by SQuAD tasks."""
  flags.DEFINE_enum(
      'mode', 'train_and_predict',
      ['train_and_predict', 'train', 'predict', 'export_only'],
      'One of {"train_and_predict", "train", "predict", "export_only"}. '
      '`train_and_predict`: both train and predict to a json file. '
      'mode', 'train_and_eval',
      ['train_and_eval', 'train_and_predict',
       'train', 'eval', 'predict', 'export_only'],
      'One of {"train_and_eval", "train_and_predict", '
      '"train", "eval", "predict", "export_only"}. '
      '`train_and_eval`: train & predict to json files & compute eval metrics. '
      '`train_and_predict`: train & predict to json files. '
      '`train`: only trains the model. '
      '`eval`: predict answers from squad json file & compute eval metrics. '
      '`predict`: predict answers from the squad json file. '
      '`export_only`: will take the latest checkpoint inside '
      'model_dir and export a `SavedModel`.')
@@ -271,7 +277,8 @@ def train_squad(strategy,
      post_allreduce_callbacks=[clip_by_global_norm_callback])
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def prediction_output_squad(
    strategy, input_meta_data, tokenizer, bert_config, squad_lib):
  """Makes predictions for a squad dataset."""
  doc_stride = input_meta_data['doc_stride']
  max_query_length = input_meta_data['max_query_length']
@@ -322,23 +329,61 @@ def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
  all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
                                         eval_writer.filename, num_steps)
  output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
  output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
  output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
  squad_lib.write_predictions(
  all_predictions, all_nbest_json, scores_diff_json = (
      squad_lib.postprocess_output(
          eval_examples,
          eval_features,
          all_results,
          FLAGS.n_best_size,
          FLAGS.max_answer_length,
          FLAGS.do_lower_case,
          output_prediction_file,
          output_nbest_file,
          output_null_log_odds_file,
          version_2_with_negative=version_2_with_negative,
          null_score_diff_threshold=FLAGS.null_score_diff_threshold,
          verbose=FLAGS.verbose_logging)
          verbose=FLAGS.verbose_logging))
return all_predictions, all_nbest_json, scores_diff_json
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative):
"""Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file))
squad_lib.write_to_json_files(all_predictions, output_prediction_file)
squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Get prediction results and evaluate them to hard drive."""
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Get prediction results and evaluate them against ground truth."""
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
dataset_json = json.load(reader)
pred_dataset = dataset_json['data']
if input_meta_data.get('version_2_with_negative', False):
eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
all_predictions,
scores_diff_json)
else:
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
return eval_metrics
def export_squad(model_export_path, input_meta_data, bert_config):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order
def _normalize_answer(s):
"""Lowers text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _f1_score(prediction, ground_truth):
"""Computes F1 score by comparing prediction to ground truth."""
prediction_tokens = _normalize_answer(prediction).split()
ground_truth_tokens = _normalize_answer(ground_truth).split()
prediction_counter = collections.Counter(prediction_tokens)
ground_truth_counter = collections.Counter(ground_truth_tokens)
common = prediction_counter & ground_truth_counter
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def _exact_match_score(prediction, ground_truth):
"""Checks if predicted answer exactly matches ground truth answer."""
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Computes the max over all metric scores."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
"""Evaluates predictions for a dataset."""
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
logging.error(message)
continue
ground_truths = [entry["text"] for entry in qa["answers"]]
prediction = predictions[qa["id"]]
exact_match += _metric_max_over_ground_truths(_exact_match_score,
prediction, ground_truths)
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
ground_truths)
exact_match = exact_match / total
f1 = f1 / total
return {"exact_match": exact_match, "final_f1": f1}
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.
The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
from absl import logging
def _make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid_to_has_ans[qa['id']] = bool(qa['answers'])
return qid_to_has_ans
def _normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _get_tokens(s):
if not s: return []
return _normalize_answer(s).split()
def _compute_exact(a_gold, a_pred):
return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
def _compute_f1(a_gold, a_pred):
"""Compute F1-score."""
gold_toks = _get_tokens(a_gold)
pred_toks = _get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if not gold_toks or not pred_toks:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def _get_raw_scores(dataset, predictions):
"""Compute raw scores."""
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid = qa['id']
gold_answers = [a['text'] for a in qa['answers']
if _normalize_answer(a['text'])]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = ['']
if qid not in predictions:
logging.error('Missing prediction for %s', qid)
continue
a_pred = predictions[qid]
# Take max over all gold answers
exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def _apply_no_ans_threshold(
scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
"""Make evaluation result dictionary."""
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
def _merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
"""Make evaluation dictionary containing average recision recall."""
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i+1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
return {'ap': 100.0 * avg_prec}
def _run_precision_recall_analysis(
main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
"""Run precision recall analysis and return result dictionary."""
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = _make_precision_recall_eval(
exact_raw, na_probs, num_true_pos, qid_to_has_ans)
pr_f1 = _make_precision_recall_eval(
f1_raw, na_probs, num_true_pos, qid_to_has_ans)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = _make_precision_recall_eval(
oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
_merge_eval(main_eval, pr_exact, 'pr_exact')
_merge_eval(main_eval, pr_f1, 'pr_f1')
_merge_eval(main_eval, pr_oracle, 'pr_oracle')
def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
"""Find the best threshold for no answer probability."""
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for qid in qid_list:
if qid not in scores: continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if predictions[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def _find_all_best_thresh(
main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = _find_best_thresh(
predictions, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = _find_best_thresh(
predictions, f1_raw, na_probs, qid_to_has_ans)
main_eval['final_exact'] = best_exact
main_eval['final_exact_thresh'] = exact_thresh
main_eval['final_f1'] = best_f1
main_eval['final_f1_thresh'] = f1_thresh
def evaluate(dataset, predictions, na_probs=None):
"""Evaluate prediction results."""
new_orig_data = []
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
if qa['id'] in predictions:
new_para = {'qas': [qa]}
new_article = {'paragraphs': [new_para]}
new_orig_data.append(new_article)
dataset = new_orig_data
if na_probs is None:
na_probs = {k: 0.0 for k in predictions}
qid_to_has_ans = _make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
out_eval = _make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = _make_eval_dict(
exact_thresh, f1_thresh, qid_list=has_ans_qids)
_merge_eval(out_eval, has_ans_eval, 'HasAns')
if no_ans_qids:
no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
_merge_eval(out_eval, no_ans_eval, 'NoAns')
_find_all_best_thresh(
out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
_run_precision_recall_analysis(
out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
return out_eval
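Similarly, a hypothetical usage sketch for the v2.0 evaluator; `null_odds.json` is the null-odds file written by `dump_to_files` when `version_2_with_negative` is set, and the `final_*` keys come from `_find_all_best_thresh` above.

```python
import json

from official.nlp.bert import squad_evaluate_v2_0

# Assumed inputs: a SQuAD v2.0 dev file plus the prediction and null-odds
# files written by the pipeline above.
with open('dev-v2.0.json') as reader:
  dataset = json.load(reader)['data']
with open('predictions.json') as reader:
  predictions = json.load(reader)
with open('null_odds.json') as reader:
  na_probs = json.load(reader)  # maps question id -> predicted no-answer score

metrics = squad_evaluate_v2_0.evaluate(dataset, predictions, na_probs=na_probs)
print(metrics['final_exact'], metrics['final_f1'])
```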
@@ -506,6 +506,34 @@ def write_predictions(all_examples,
  logging.info("Writing predictions to: %s", (output_prediction_file))
  logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples,
all_features=all_features,
all_results=all_results,
n_best_size=n_best_size,
max_answer_length=max_answer_length,
do_lower_case=do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
write_to_json_files(scores_diff_json, output_null_log_odds_file)
def postprocess_output(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Postprocess model output, to form predicton results."""
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)
@@ -676,15 +704,12 @@ def write_predictions(all_examples,
    all_nbest_json[example.qas_id] = nbest_json
  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")
  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  if version_2_with_negative:
    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
  return all_predictions, all_nbest_json, scores_diff_json
def write_to_json_files(json_records, json_file):
  with tf.io.gfile.GFile(json_file, "w") as writer:
    writer.write(json.dumps(json_records, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
......
@@ -575,10 +575,39 @@ def write_predictions(all_examples,
                      null_score_diff_threshold=0.0,
                      verbose=False):
  """Write final predictions to the json file and log-odds of null if needed."""
del do_lower_case, verbose
logging.info("Writing predictions to: %s", (output_prediction_file)) logging.info("Writing predictions to: %s", (output_prediction_file))
logging.info("Writing nbest to: %s", (output_nbest_file)) logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples,
all_features=all_features,
all_results=all_results,
n_best_size=n_best_size,
max_answer_length=max_answer_length,
do_lower_case=do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
write_to_json_files(scores_diff_json, output_null_log_odds_file)
def postprocess_output(all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
verbose=False):
"""Postprocess model output, to form predicton results."""
del do_lower_case, verbose
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)
@@ -740,15 +769,12 @@ def write_predictions(all_examples,
    all_nbest_json[example.qas_id] = nbest_json
  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")
  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  if version_2_with_negative:
    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
  return all_predictions, all_nbest_json, scores_diff_json
def write_to_json_files(json_records, json_file):
  with tf.io.gfile.GFile(json_file, "w") as writer:
    writer.write(json.dumps(json_records, indent=4) + "\n")
def _get_best_indexes(logits, n_best_size):
......
@@ -140,19 +140,19 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      all_reduce_sum_gradients=True):
                      experimental_aggregate_gradients=True):
    grads, tvars = list(zip(*grads_and_vars))
    if all_reduce_sum_gradients:
      # when all_reduce_sum_gradients = False, apply_gradients() no longer
      # implicitly allreduce gradients, users manually allreduce gradient and
      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
      # will be moved to before the explicit allreduce to keep the math
      # the same as TF 1 and pre TF 2.2 implementation.
    if experimental_aggregate_gradients:
      # when experimental_aggregate_gradients = False, apply_gradients() no
      # longer implicitly allreduce gradients, users manually allreduce gradient
      # and passed the allreduced grads_and_vars. For now, the
      # clip_by_global_norm will be moved to before the explicit allreduce to
      # keep the math the same as TF 1 and pre TF 2.2 implementation.
      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    return super(AdamWeightDecay, self).apply_gradients(
        zip(grads, tvars),
        name=name,
        all_reduce_sum_gradients=all_reduce_sum_gradients)
        experimental_aggregate_gradients=experimental_aggregate_gradients)
  def _get_lr(self, var_device, var_dtype, apply_state):
    """Retrieves the learning rate with the given state."""
......
@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
  """Returns common callbacks."""
  callbacks = []
  if FLAGS.enable_time_history:
    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
    time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
FLAGS.model_dir if FLAGS.enable_tensorboard else None)
    callbacks.append(time_callback)
  if FLAGS.enable_tensorboard:
......
@@ -246,6 +246,11 @@ class TransformerTask(object):
    callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# Only TimeHistory callback is supported for CTL
if params["use_ctl"]:
callbacks = [cb for cb in callbacks
if isinstance(cb, keras_utils.TimeHistory)]
    # TODO(b/139418525): Refactor the custom training loop logic.
    @tf.function
    def train_steps(iterator, steps):
@@ -299,8 +304,13 @@ class TransformerTask(object):
      if not self.use_tpu:
        raise NotImplementedError(
            "Custom training loop on GPUs is not implemented.")
      # Runs training steps.
      with summary_writer.as_default():
for cb in callbacks:
cb.on_epoch_begin(current_iteration)
cb.on_batch_begin(0)
        train_steps(
            train_ds_iterator,
            tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
@@ -309,10 +319,18 @@ class TransformerTask(object):
        logging.info("Train Step: %d/%d / loss = %s", current_step,
                     flags_obj.train_steps, train_loss)
for cb in callbacks:
cb.on_batch_end(train_steps_per_eval - 1)
cb.on_epoch_end(current_iteration)
if params["enable_tensorboard"]: if params["enable_tensorboard"]:
for metric_obj in train_metrics: for metric_obj in train_metrics:
tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(), tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
current_step) current_step)
summary_writer.flush()
for cb in callbacks:
cb.on_train_end()
    if flags_obj.enable_checkpointing:
      # avoid check-pointing when running for benchmarking.
......