Commit ab834d35 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use distribution utils in XLNET

PiperOrigin-RevId: 331015243
parent 7d5c47aa
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
# ============================================================================== # ==============================================================================
"""XLNet classification finetuning runner in tf2.0.""" """XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
# Import libraries # Import libraries
from absl import app from absl import app
...@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization ...@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib from official.utils.misc import distribution_utils
flags.DEFINE_integer("n_class", default=2, help="Number of classes.") flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string( flags.DEFINE_string(
...@@ -135,14 +130,9 @@ def get_metric_fn(): ...@@ -135,14 +130,9 @@ def get_metric_fn():
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
if FLAGS.strategy_type == "mirror": strategy = distribution_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
...@@ -12,12 +12,7 @@ ...@@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""XLNet classification finetuning runner in tf2.0.""" """XLNet pretraining runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
import os import os
...@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization ...@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib from official.utils.misc import distribution_utils
flags.DEFINE_integer( flags.DEFINE_integer(
"num_predict", "num_predict",
...@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config): ...@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config):
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
num_hosts = 1 num_hosts = 1
if FLAGS.strategy_type == "mirror": strategy = distribution_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu) if FLAGS.strategy_type == "tpu":
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) num_hosts = strategy.extended.num_hosts
topology = FLAGS.tpu_topology.split("x")
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts = total_num_core // FLAGS.num_core_per_host
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
# ============================================================================== # ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0.""" """XLNet SQUAD finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
import json import json
import os import os
...@@ -39,7 +34,7 @@ from official.nlp.xlnet import squad_utils ...@@ -39,7 +34,7 @@ from official.nlp.xlnet import squad_utils
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib from official.utils.misc import distribution_utils
flags.DEFINE_string( flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.") "test_feature_path", default=None, help="Path to feature of test set.")
...@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top): ...@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
if FLAGS.strategy_type == "mirror": strategy = distribution_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment