Unverified Commit 965cc3ee authored by Ayushman Kumar, committed by GitHub

Merge pull request #7 from tensorflow/master

updated
parents 1f3247f4 1f685c54
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import string
from absl import logging
import tensorflow as tf
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
# dataset.take(1) doesn't have a GPU kernel.
with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor)
variable_data = []
initializers = []
for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t)
variable_data.append(v)
initializers.append(v.initializer)
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self):
return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self):
if tf.executing_eagerly():
return tf.no_op()
else:
return self._initializers
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset):
logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync)
else:
return SyntheticDataset(dataset)
def make_iterator(self, dataset):
dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
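The helpers above are typically driven through `set_up_synthetic_data()`. A minimal usage sketch follows (illustrative only, not part of this module; the tensor shapes and batch size are assumed, and `drop_remainder=True` is needed because `SyntheticDataset` asserts fully defined batch shapes):

# Sketch: patch the strategy classes, build one real batch, then iterate the
# cached synthetic copy of it.
set_up_synthetic_data()
strategy = tf.distribute.MirroredStrategy()
features = tf.zeros([128, 224, 224, 3], dtype=tf.float32)
labels = tf.zeros([128], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(128, drop_remainder=True)
# With the patch installed this returns a SyntheticDataset, so every
# iteration replays the same per-replica batch with no input pipeline cost.
synthetic = strategy.experimental_distribute_dataset(dataset)
batch = next(iter(synthetic))
undo_set_up_synthetic_data()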
@@ -24,11 +24,10 @@ from absl import flags
 from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
+from official.benchmark import benchmark_wrappers
 from official.recommendation import ncf_common
 from official.recommendation import ncf_keras_main
 from official.utils.flags import core
-from official.utils.testing import benchmark_wrappers
 FLAGS = flags.FLAGS
 NCF_DATA_DIR_NAME = 'movielens_data'
@@ -50,7 +49,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
   def _setup(self):
     """Sets up and resets flags before each test."""
-    assert tf.version.VERSION.startswith('2.')
     logging.set_verbosity(logging.INFO)
     if NCFKerasBenchmarkBase.local_flags is None:
       ncf_common.define_ncf_flags()
...
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to set Owner annotations on benchmarks.
@owner_utils.Owner('owner_team/user') can be set at the benchmark class level,
at the benchmark method level, or at both.
Runner frameworks can use owner_utils.GetOwner(benchmark_method) to get the
actual owner. Python inheritance for the owner attribute is respected (e.g. a
method-level owner takes precedence over a class-level owner).
See owner_utils_test for associated tests and more examples.
The decorator can be applied both at the method level and at the class level.
Simple example:
===============
class MLBenchmark:
@Owner('example_id')
def benchmark_method_1_gpu(self):
return True
"""
def Owner(owner_name):
"""Sets the owner attribute on a decorated method or class."""
def _Wrapper(func_or_class):
"""Sets the benchmark owner attribute."""
func_or_class.__benchmark__owner__ = owner_name
return func_or_class
return _Wrapper
def GetOwner(benchmark_method_or_class):
"""Gets the inherited owner attribute for this benchmark.
Checks for existence of __benchmark__owner__. If it's not present, looks for
it in the parent class's attribute list.
Args:
benchmark_method_or_class: A benchmark method or class.
Returns:
The owner string if one is set, otherwise None.
"""
if hasattr(benchmark_method_or_class, '__benchmark__owner__'):
return benchmark_method_or_class.__benchmark__owner__
elif hasattr(benchmark_method_or_class, '__self__'):
if hasattr(benchmark_method_or_class.__self__, '__benchmark__owner__'):
return benchmark_method_or_class.__self__.__benchmark__owner__
return None
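As a small, hypothetical illustration of the runner-framework usage mentioned in the module docstring (the helper name and the benchmark instance are invented for the example), owners can be resolved for each benchmark method like this:

import inspect

def print_benchmark_owners(benchmark_instance):
  """Prints the resolved owner for every benchmark_* method (illustrative)."""
  for name, method in inspect.getmembers(benchmark_instance, inspect.ismethod):
    if name.startswith('benchmark'):
      print(name, GetOwner(method))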
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.benchmark.owner_utils."""
from absl.testing import absltest
from official.benchmark import owner_utils
@owner_utils.Owner('static_owner')
def static_function(foo=5):
return foo
def static_function_without_owner(foo=5):
return foo
class BenchmarkClassWithoutOwner:
def method_without_owner(self):
return 100
@owner_utils.Owner('method_owner')
def method_with_owner(self):
return 200
@owner_utils.Owner('class_owner')
class SomeBenchmarkClass:
def method_inherited_owner(self):
return 123
@owner_utils.Owner('method_owner')
def method_override_owner(self):
return 345
@owner_utils.Owner('new_class_owner')
class InheritedClass(SomeBenchmarkClass):
def method_inherited_owner(self):
return 456
@owner_utils.Owner('new_method_owner')
def method_override_owner(self):
return 567
class OwnerUtilsTest(absltest.TestCase):
"""Tests to assert for owner decorator functionality."""
def test_owner_tag_missing(self):
self.assertEqual(None, owner_utils.GetOwner(static_function_without_owner))
benchmark_class = BenchmarkClassWithoutOwner()
self.assertEqual(None,
owner_utils.GetOwner(benchmark_class.method_without_owner))
self.assertEqual(100, benchmark_class.method_without_owner())
self.assertEqual('method_owner',
owner_utils.GetOwner(benchmark_class.method_with_owner))
self.assertEqual(200, benchmark_class.method_with_owner())
def test_owner_attributes_static(self):
self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
self.assertEqual(5, static_function(5))
def test_owner_attributes_per_class(self):
level1 = SomeBenchmarkClass()
self.assertEqual('class_owner',
owner_utils.GetOwner(level1.method_inherited_owner))
self.assertEqual(123, level1.method_inherited_owner())
self.assertEqual('method_owner',
owner_utils.GetOwner(level1.method_override_owner))
self.assertEqual(345, level1.method_override_owner())
def test_owner_attributes_inherited_class(self):
level2 = InheritedClass()
self.assertEqual('new_class_owner',
owner_utils.GetOwner(level2.method_inherited_owner))
self.assertEqual(456, level2.method_inherited_owner())
self.assertEqual('new_method_owner',
owner_utils.GetOwner(level2.method_override_owner))
self.assertEqual(567, level2.method_override_owner())
if __name__ == '__main__':
absltest.main()
@@ -13,19 +13,19 @@
 # limitations under the License.
 # ==============================================================================
 """Executes CTL benchmarks and accuracy tests."""
+# pylint: disable=line-too-long,g-bad-import-order
 from __future__ import print_function
 import os
 import time
-# pylint: disable=g-bad-import-order
 from absl import flags
 import tensorflow as tf
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
-from official.utils.testing import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
 from official.utils.flags import core as flags_core
 MIN_TOP_1_ACCURACY = 0.76
@@ -53,7 +53,8 @@ class CtlBenchmark(PerfZeroBenchmark):
                          top_1_min=None,
                          total_batch_size=None,
                          log_steps=None,
-                         warmup=1):
+                         warmup=1,
+                         start_time_sec=None):
     """Report benchmark results by writing to local protobuf file.
     Args:
@@ -64,6 +65,7 @@ class CtlBenchmark(PerfZeroBenchmark):
       total_batch_size: Global batch-size.
       log_steps: How often the log was created for stats['step_timestamp_log'].
       warmup: number of entries in stats['step_timestamp_log'] to ignore.
+      start_time_sec: the start time of the program in seconds since epoch.
     """
     metrics = []
@@ -98,6 +100,12 @@ class CtlBenchmark(PerfZeroBenchmark):
           'value': stats['avg_exp_per_second']
       })
+    if start_time_sec and 'step_timestamp_log' in stats:
+      time_log = stats['step_timestamp_log']
+      # time_log[0] is recorded at the beginning of the first step.
+      startup_time = time_log[0].timestamp - start_time_sec
+      metrics.append({'name': 'startup_time', 'value': startup_time})
+
     flags_str = flags_core.get_nondefault_flags_as_str()
     self.report_benchmark(
         iters=-1,
@@ -136,8 +144,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
     FLAGS.dtype = 'fp32'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   def benchmark_8_gpu_fp16(self):
@@ -150,8 +156,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
     FLAGS.dtype = 'fp16'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   def benchmark_8_gpu_amp(self):
@@ -165,8 +169,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
-    # Add some thread tunings to improve performance.
-    FLAGS.datasets_num_private_threads = 14
     self._run_and_report_benchmark()

   @benchmark_wrappers.enable_runtime_flags
@@ -181,7 +183,8 @@ class Resnet50CtlAccuracy(CtlBenchmark):
         top_1_min=MIN_TOP_1_ACCURACY,
         top_1_max=MAX_TOP_1_ACCURACY,
         total_batch_size=FLAGS.batch_size,
-        log_steps=100)
+        log_steps=100,
+        start_time_sec=start_time_sec)

   def _get_model_dir(self, folder_name):
     return os.path.join(self.output_dir, folder_name)
@@ -213,7 +216,8 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
         wall_time_sec,
         total_batch_size=FLAGS.batch_size,
         log_steps=FLAGS.log_steps,
-        warmup=warmup)
+        warmup=warmup,
+        start_time_sec=start_time_sec)

   def benchmark_1_gpu_no_dist_strat(self):
     """Test Keras model with 1 GPU, no distribution strategy."""
@@ -278,7 +282,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
-    FLAGS.batch_size = 128
+    FLAGS.batch_size = 120
     FLAGS.use_tf_function = False
     FLAGS.use_tf_while_loop = False
     FLAGS.single_l2_loss_op = True
@@ -291,7 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_eager')
-    FLAGS.batch_size = 250
+    FLAGS.batch_size = 240
     FLAGS.dtype = 'fp16'
     FLAGS.use_tf_function = False
     FLAGS.use_tf_while_loop = False
...
@@ -32,7 +32,7 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers
 from official.vision.detection import main as detection
 TMP_DIR = os.getenv('TMPDIR')
...
@@ -23,11 +23,11 @@ import time
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from official.staging.shakespeare import shakespeare_main
+from official.benchmark.models.shakespeare import shakespeare_main
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
 SHAKESPEARE_TRAIN_DATA = 'shakespeare/shakespeare.txt'
 TMP_DIR = os.getenv('TMPDIR')
...
# Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and
# Percy Liang. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import string
# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order
def _normalize_answer(s):
"""Lowers text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _f1_score(prediction, ground_truth):
"""Computes F1 score by comparing prediction to ground truth."""
prediction_tokens = _normalize_answer(prediction).split()
ground_truth_tokens = _normalize_answer(ground_truth).split()
prediction_counter = collections.Counter(prediction_tokens)
ground_truth_counter = collections.Counter(ground_truth_tokens)
common = prediction_counter & ground_truth_counter
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
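# Worked example (illustrative): for prediction "the cat sat" and ground truth
# "a cat sat down", normalization strips the articles, leaving 2 and 3 tokens
# with 2 tokens in common, so precision = 2/2, recall = 2/3 and
# F1 = 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.8.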
def _exact_match_score(prediction, ground_truth):
"""Checks if predicted answer exactly matches ground truth answer."""
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Computes the max over all metric scores."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
"""Evaluates predictions for a dataset."""
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
logging.error(message)
continue
ground_truths = [entry["text"] for entry in qa["answers"]]
prediction = predictions[qa["id"]]
exact_match += _metric_max_over_ground_truths(_exact_match_score,
prediction, ground_truths)
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
ground_truths)
exact_match = exact_match / total
f1 = f1 / total
return {"exact_match": exact_match, "f1": f1}
@@ -20,10 +20,10 @@ import functools
 import time
 from absl import flags
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
 import tensorflow_hub as hub
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
 FLAGS = flags.FLAGS
...
@@ -22,12 +22,11 @@ import time
 from absl import flags
 import tensorflow as tf
+from official.benchmark import benchmark_wrappers
+from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
 from official.nlp.transformer import misc
 from official.nlp.transformer import transformer_main as transformer_main
 from official.utils.flags import core as flags_core
-from official.utils.testing import benchmark_wrappers
-from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
 TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
 EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
@@ -44,7 +43,6 @@ class TransformerBenchmark(PerfZeroBenchmark):
   def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
                flag_methods=None):
-    assert tf.version.VERSION.startswith('2.')
     root_data_dir = root_data_dir if root_data_dir else ''
     self.train_data_dir = os.path.join(root_data_dir,
...
@@ -31,7 +31,7 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.nlp.xlnet import run_classifier
 from official.nlp.xlnet import run_squad
-from official.utils.testing import benchmark_wrappers
+from official.benchmark import benchmark_wrappers
 # pylint: disable=line-too-long
...
@@ -257,10 +257,8 @@ class RuntimeConfig(Config):
   Attributes:
     distribution_strategy: e.g. 'mirrored', 'tpu', etc.
-    enable_eager: Whether or not to enable eager mode.
     enable_xla: Whether or not to enable XLA.
     per_gpu_thread_count: thread count per GPU.
-    gpu_threads_enabled: Whether or not GPU threads are enabled.
     gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
     dataset_num_private_threads: Number of threads for a private threadpool
       created for all datasets computation.
@@ -272,11 +270,13 @@ class RuntimeConfig(Config):
     all_reduce_alg: Defines the algorithm for performing all-reduce.
     num_packs: Sets `num_packs` in the cross device ops used in
       MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+    loss_scale: The type of loss scale. This is used when setting the mixed
+      precision policy.
+    run_eagerly: Whether or not to run the experiment eagerly.
   """
   distribution_strategy: str = 'mirrored'
-  enable_eager: bool = False
   enable_xla: bool = False
-  gpu_threads_enabled: bool = False
   gpu_thread_mode: Optional[str] = None
   dataset_num_private_threads: Optional[int] = None
   per_gpu_thread_count: int = 0
@@ -286,6 +286,8 @@ class RuntimeConfig(Config):
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
   num_packs: int = 1
+  loss_scale: Optional[str] = None
+  run_eagerly: bool = False

 @dataclasses.dataclass
@@ -312,7 +314,10 @@ class CallbacksConfig(Config):
       Callback. Defaults to True.
     enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
       Defaults to True.
+    enable_time_history: Whether or not to enable TimeHistory Callbacks.
+      Defaults to True.
   """
   enable_checkpoint_and_export: bool = True
   enable_tensorboard: bool = True
+  enable_time_history: bool = True
@@ -19,7 +19,6 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
-import json
 import os
 from absl import flags
@@ -31,8 +30,9 @@ import tensorflow as tf
 # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
 from official.modeling.hyperparams import params_dict
-from official.utils.misc import distribution_utils
 from official.utils import hyperparams_flags
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
@@ -59,6 +59,45 @@ def _no_metric():
   return None

+def metrics_as_dict(metric):
+  """Puts input metric(s) into a list.
+
+  Args:
+    metric: metric(s) to be put into the list. `metric` could be a object, a
+      list or a dict of tf.keras.metrics.Metric or has the `required_method`.
+
+  Returns:
+    A dictionary of valid metrics.
+  """
+  if isinstance(metric, tf.keras.metrics.Metric):
+    metrics = {metric.name: metric}
+  elif isinstance(metric, list):
+    metrics = {m.name: m for m in metric}
+  elif isinstance(metric, dict):
+    metrics = metric
+  elif not metric:
+    return {}
+  else:
+    metrics = {'metric': metric}
+  return metrics
+
+
+def metric_results(metric):
+  """Collects results from the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  metric_result = {
+      name: m.result().numpy().astype(float) for name, m in metrics.items()
+  }
+  return metric_result
+
+
+def reset_states(metric):
+  """Resets states of the given metric(s)."""
+  metrics = metrics_as_dict(metric)
+  for m in metrics.values():
+    m.reset_states()
+
+
 class SummaryWriter(object):
   """Simple SummaryWriter for writing dictionary of metrics.
@@ -185,6 +224,7 @@ class DistributedExecutor(object):
                          loss_fn,
                          optimizer,
                          metric=None):
+    metrics = metrics_as_dict(metric)
     def _replicated_step(inputs):
       """Replicated training step."""
@@ -195,11 +235,8 @@ class DistributedExecutor(object):
         prediction_loss = loss_fn(labels, outputs)
         loss = tf.reduce_mean(prediction_loss)
         loss = loss / strategy.num_replicas_in_sync
-        if isinstance(metric, tf.keras.metrics.Metric):
-          metric.update_state(labels, outputs)
-        else:
-          logging.error('train metric is not an instance of '
-                        'tf.keras.metrics.Metric.')
+        for m in metrics.values():
+          m.update_state(labels, outputs)
       grads = tape.gradient(loss, model.trainable_variables)
       optimizer.apply_gradients(zip(grads, model.trainable_variables))
@@ -235,6 +272,7 @@ class DistributedExecutor(object):
     Args:
       iterator: an iterator that yields input tensors.
+      num_steps: the number of steps in the loop.
     Returns:
       The loss tensor.
@@ -259,6 +297,7 @@ class DistributedExecutor(object):
   def _create_test_step(self, strategy, model, metric):
     """Creates a distributed test step."""
+    metrics = metrics_as_dict(metric)
     @tf.function
     def test_step(iterator):
@@ -266,22 +305,20 @@ class DistributedExecutor(object):
       if not metric:
         logging.info('Skip test_step because metric is None (%s)', metric)
         return None, None
-      if not isinstance(metric, tf.keras.metrics.Metric):
-        raise ValueError(
-            'Metric must be an instance of tf.keras.metrics.Metric '
-            'for running in test_step. Actual {}'.format(metric))
       def _test_step_fn(inputs):
         """Replicated accuracy calculation."""
         inputs, labels = inputs
         model_outputs = model(inputs, training=False)
-        metric.update_state(labels, model_outputs)
+        for m in metrics.values():
+          m.update_state(labels, model_outputs)
         return labels, model_outputs
       return strategy.run(_test_step_fn, args=(next(iterator),))
     return test_step
   def train(self,
             train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
             eval_input_fn: Callable[[params_dict.ParamsDict],
@@ -330,10 +367,12 @@ class DistributedExecutor(object):
     eval_metric_fn = eval_metric_fn or _no_metric
     if custom_callbacks and iterations_per_loop != 1:
-      logging.error(
+      logging.warning(
          'It is sematically wrong to run callbacks when '
          'iterations_per_loop is not one (%s)', iterations_per_loop)
+    custom_callbacks = custom_callbacks or []
+
     def _run_callbacks_on_batch_begin(batch):
       """Runs custom callbacks at the start of every step."""
       if not custom_callbacks:
@@ -402,6 +441,11 @@ class DistributedExecutor(object):
     test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
     self.eval_summary_writer = test_summary_writer.writer
+    # Use training summary writer in TimeHistory if it's in use
+    for cb in custom_callbacks:
+      if isinstance(cb, keras_utils.TimeHistory):
+        cb.summary_writer = self.train_summary_writer
+
     # Continue training loop.
     train_step = self._create_train_step(
         strategy=strategy,
@@ -414,6 +458,20 @@ class DistributedExecutor(object):
     self.global_train_step = model.optimizer.iterations
     test_step = self._create_test_step(strategy, model, metric=eval_metric)
+    # Step-0 operations
+    if current_step == 0 and not latest_checkpoint_file:
+      _save_checkpoint(
+          checkpoint, model_dir, checkpoint_name.format(step=current_step))
+      if test_step:
+        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+        eval_metric_result = self._run_evaluation(
+            test_step, current_step, eval_metric, eval_iterator)
+        logging.info(
+            'Step: %s evalation metric = %s.', current_step, eval_metric_result)
+        test_summary_writer(
+            metrics=eval_metric_result, step=optimizer.iterations)
+        reset_states(eval_metric)
+
     logging.info('Training started')
     last_save_checkpoint_step = current_step
     while current_step < total_steps:
@@ -422,23 +480,19 @@ class DistributedExecutor(object):
       _run_callbacks_on_batch_begin(current_step)
       train_loss = train_step(train_iterator,
                               tf.convert_to_tensor(num_steps, dtype=tf.int32))
-      _run_callbacks_on_batch_end(current_step)
       current_step += num_steps
       train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                          train_loss)
+      _run_callbacks_on_batch_end(current_step - 1)
       if not isinstance(train_loss, dict):
         train_loss = {'total_loss': train_loss}
       if np.isnan(train_loss['total_loss']):
         raise ValueError('total loss is NaN.')
       if train_metric:
-        train_metric_result = train_metric.result()
-        if isinstance(train_metric, tf.keras.metrics.Metric):
-          train_metric_result = tf.nest.map_structure(
-              lambda x: x.numpy().astype(float), train_metric_result)
-        if not isinstance(train_metric_result, dict):
-          train_metric_result = {'metric': train_metric_result}
+        train_metric_result = metric_results(train_metric)
         train_metric_result.update(train_loss)
       else:
         train_metric_result = train_loss
@@ -475,9 +529,9 @@ class DistributedExecutor(object):
         # Re-initialize evaluation metric, except the last step.
         if eval_metric and current_step < total_steps:
-          eval_metric.reset_states()
+          reset_states(eval_metric)
         if train_metric and current_step < total_steps:
-          train_metric.reset_states()
+          reset_states(train_metric)
     # Reaches the end of training and saves the last checkpoint.
     if last_save_checkpoint_step < total_steps:
@@ -493,6 +547,9 @@ class DistributedExecutor(object):
       test_summary_writer(
          metrics=eval_metric_result, step=optimizer.iterations)
+    self.train_summary_writer.close()
+    self.eval_summary_writer.close()
+
     return train_loss, eval_metric_result
   def _run_evaluation(self, test_step, current_training_step, metric,
@@ -510,9 +567,7 @@ class DistributedExecutor(object):
       except (StopIteration, tf.errors.OutOfRangeError):
         break
-    metric_result = metric.result()
-    if isinstance(metric, tf.keras.metrics.Metric):
-      metric_result = metric_result.numpy().astype(float)
+    metric_result = metric_results(metric)
     logging.info('Step: [%d] Validation metric = %f', current_training_step,
                  metric_result)
     return metric_result
@@ -629,7 +684,7 @@ class DistributedExecutor(object):
      logging.info('Step: %s evalation metric = %s.', current_step,
                   eval_metric_result)
      summary_writer(metrics=eval_metric_result, step=current_step)
-     eval_metric.reset_states()
+     reset_states(eval_metric)
     return eval_metric_result, current_step
...
@@ -7,8 +7,9 @@ state-of-the-art models.
 The repository contains the following models, with implementations, pre-trained
 model weights, usage scripts and conversion utilities:
-* [Bert](bert)
 * [Albert](albert)
+* [Bert](bert)
+* [NHNet](nhnet)
 * [XLNet](xlnet)
 * [Transformer for translation](transformer)
@@ -16,6 +17,3 @@ Addtional features:
 * Distributed trainable on both multi-GPU and TPU
 * e2e training for custom models, including both pretraining and finetuning.
@@ -19,9 +19,12 @@ from __future__ import division
 from __future__ import print_function
 import json
+import os
+import time
 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 from official.nlp.albert import configs as albert_configs
@@ -53,7 +56,7 @@ def train_squad(strategy,
 def predict_squad(strategy, input_meta_data):
-  """Makes predictions for a squad dataset."""
+  """Makes predictions for the squad dataset."""
   bert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
   tokenizer = tokenization.FullSentencePieceTokenizer(
@@ -63,6 +66,18 @@ def predict_squad(strategy, input_meta_data):
       bert_config, squad_lib_sp)

+def eval_squad(strategy, input_meta_data):
+  """Evaluate on the squad dataset."""
+  bert_config = albert_configs.AlbertConfig.from_json_file(
+      FLAGS.bert_config_file)
+  tokenizer = tokenization.FullSentencePieceTokenizer(
+      sp_model_file=FLAGS.sp_model_file)
+
+  eval_metrics = run_squad_helper.eval_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib_sp)
+  return eval_metrics
+
+
 def export_squad(model_export_path, input_meta_data):
   """Exports a trained model as a `SavedModel` for inference.
@@ -97,10 +112,25 @@ def main(_):
       num_gpus=FLAGS.num_gpus,
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
-  if FLAGS.mode in ('train', 'train_and_predict'):
+  if 'train' in FLAGS.mode:
     train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
-  if FLAGS.mode in ('predict', 'train_and_predict'):
+  if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
+  if 'eval' in FLAGS.mode:
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+    summary_writer = tf.summary.create_file_writer(summary_dir)
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Also write eval_metrics to json file.
+    squad_lib_sp.write_to_json_files(
+        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+    time.sleep(60)

 if __name__ == '__main__':
...
@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy,
-        name='next_sentence_accuracy',
-        aggregation='mean')
-
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
-
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
+
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
@@ -120,8 +132,12 @@ def get_transformer_encoder(bert_config,
       dropout_rate=bert_config.hidden_dropout_prob,
       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
   )
-  kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
-                num_hidden_instances=bert_config.num_hidden_layers,)
+  kwargs = dict(
+      embedding_cfg=embedding_cfg,
+      hidden_cfg=hidden_cfg,
+      num_hidden_instances=bert_config.num_hidden_layers,
+      pooled_output_dim=bert_config.hidden_size,
+  )
   # Relies on gin configuration to define the Transformer encoder arguments.
   return transformer_encoder_cls(**kwargs)
@@ -151,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.
   Args:
@@ -160,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.
   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -181,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None
   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -202,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
@@ -309,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(
-      hub_module_url, trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
...
@@ -39,7 +39,6 @@ def define_common_bert_flags():
       stop_threshold=False,
       batch_size=False,
       num_gpu=True,
-      hooks=False,
       export_dir=False,
       distribution_strategy=True,
       run_eagerly=True)
@@ -63,6 +62,10 @@ def define_common_bert_flags():
       'inside.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
+  flags.DEFINE_float('end_lr', 0.0,
+                     'The end learning rate for learning rate decay.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
   flags.DEFINE_boolean(
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
...
@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 from typing import Text
 from official.nlp.bert import bert_models
@@ -34,6 +35,9 @@ flags.DEFINE_string("model_checkpoint_path", None,
 flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_bool("do_lower_case", None, "Whether to lowercase. If None, "
+                  "do_lower_case will be enabled if 'uncased' appears in the "
+                  "name of --vocab_file")

 def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
@@ -65,21 +69,26 @@ def create_bert_model(bert_config: configs.BertConfig) -> tf.keras.Model:
 def export_bert_tfhub(bert_config: configs.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
-                      vocab_file: Text):
+                      vocab_file: Text, do_lower_case: bool = None):
   """Restores a tf.keras.Model and saves for TF-Hub."""
+  # If do_lower_case is not explicit, default to checking whether "uncased" is
+  # in the vocab file name
+  if do_lower_case is None:
+    do_lower_case = "uncased" in vocab_file
+    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+                 do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
-  core_model.do_lower_case = tf.Variable(
-      "uncased" in vocab_file, trainable=False)
+  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")

 def main(_):
   bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
-                    FLAGS.vocab_file)
+                    FLAGS.vocab_file, FLAGS.do_lower_case)

 if __name__ == "__main__":
...