Commit 7bd795db authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move bert benchmark stuff to a central place.

PiperOrigin-RevId: 269835917
parent e2293a97
...@@ -26,10 +26,10 @@ import time ...@@ -26,10 +26,10 @@ import time
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.bert.benchmark import benchmark_utils from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp import bert_modeling as modeling from official.nlp import bert_modeling as modeling
from official.nlp.bert import run_classifier from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
......
...@@ -25,7 +25,7 @@ import time ...@@ -25,7 +25,7 @@ import time
import numpy as np import numpy as np
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
......
...@@ -25,11 +25,11 @@ import time ...@@ -25,11 +25,11 @@ import time
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.bert.benchmark import benchmark_utils from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.bert.benchmark import squad_evaluate_v1_1 from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
...@@ -82,7 +82,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -82,7 +82,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver @flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False): def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training.""" """Runs BERT SQuAD training."""
assert tf.version.VERSION.startswith('2.') tf.enable_v2_behavior()
input_meta_data = self._read_input_meta_data_from_file() input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds) strategy = self._get_distribution_strategy(use_ds)
...@@ -95,7 +95,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -95,7 +95,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver @flagsaver.flagsaver
def _evaluate_squad(self, use_ds=True): def _evaluate_squad(self, use_ds=True):
"""Runs BERT SQuAD evaluation.""" """Runs BERT SQuAD evaluation."""
assert tf.version.VERSION.startswith('2.') tf.enable_v2_behavior()
input_meta_data = self._read_input_meta_data_from_file() input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds) strategy = self._get_distribution_strategy(use_ds)
......
# Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and # Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and
# Percy Liang. All Rights Reserved. # Percy Liang. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment