Commit 7bd795db authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move bert benchmark stuff to a central place.

PiperOrigin-RevId: 269835917
parent e2293a97
......@@ -26,10 +26,10 @@ import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.bert.benchmark import benchmark_utils
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp import bert_modeling as modeling
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
......
......@@ -25,7 +25,7 @@ import time
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
......
......@@ -25,11 +25,11 @@ import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.bert.benchmark import benchmark_utils
from official.bert.benchmark import squad_evaluate_v1_1
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
......@@ -82,7 +82,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training."""
assert tf.version.VERSION.startswith('2.')
tf.enable_v2_behavior()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
......@@ -95,7 +95,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver
def _evaluate_squad(self, use_ds=True):
"""Runs BERT SQuAD evaluation."""
assert tf.version.VERSION.startswith('2.')
tf.enable_v2_behavior()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
......
# Copyright 2019 Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev and
# Percy Liang. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment