Unverified Commit d76e39e7 authored by Hongjun Choi, committed by GitHub

Merged commit includes the following changes: (#6921)

250606180  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix BERT benchmark test errors.

--
250589623  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Change BERT benchmark test pretrained checkpoint url.

--
250587892  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix error in BERT custom training loop checkpoint restoration.

--
250577163  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add logic to inject callback that measures performance in BERT custom training
    loop.

--
250529526  by hongkuny<hongkuny@google.com>:

    Internal clean up

--
250428976  by hongkuny<hongkuny@google.com>:

    Internal change

--
250415383  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add min/max value to BERT classifier benchmark test.

--
250376246  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Add benchmark performance test to run BERT on multiple numbers of GPUs.

--

PiperOrigin-RevId: 250606180
parent ab993a21
@@ -19,34 +19,67 @@ from __future__ import division
from __future__ import print_function

import json
+import math
import os
import time

+# pylint: disable=g-bad-import-order
+import numpy as np
from absl import flags
from absl.testing import flagsaver
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf
+# pylint: enable=g-bad-import-order

+from official.bert import modeling
from official.bert import run_classifier
+from official.utils.misc import distribution_utils

# pylint: disable=line-too-long
-PRETRAINED_CHECKPOINT_PATH = 'gs://tf-perfzero-data/bert/bert_model.ckpt'
+PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
-MODEL_CONFIG_FILE_PATH = 'gs://tf-perfzero-data/bert/bert_config'
+MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long

FLAGS = flags.FLAGS


+class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
+  """Callback that records time it takes to run each batch."""
+
+  def __init__(self, num_batches_to_skip=10):
+    super(BenchmarkTimerCallback, self).__init__()
+    self.num_batches_to_skip = num_batches_to_skip
+    self.timer_records = []
+    self.start_time = None
+
+  def on_batch_start(self, batch, logs=None):
+    if batch < self.num_batches_to_skip:
+      return
+    self.start_time = time.time()
+
+  def on_batch_end(self, batch, logs=None):
+    if batch < self.num_batches_to_skip:
+      return
+    assert self.start_time
+    self.timer_records.append(time.time() - self.start_time)
+
+  def get_examples_per_sec(self, batch_size):
+    return batch_size / np.mean(self.timer_records)
+
+
class BertBenchmarkBase(tf.test.Benchmark):
  """Base class to hold methods common to test classes in the module."""
  local_flags = None

  def __init__(self, output_dir=None):
+    self.num_gpus = 8
    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
+    self.timer_callback = None

  def _get_model_dir(self, folder_name):
    """Returns directory to store info, e.g. saved model and event log."""
@@ -55,6 +88,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
+    self.timer_callback = BenchmarkTimerCallback()
    if BertBenchmarkBase.local_flags is None:
      # Loads flags to get defaults to then override. List cannot be empty.
@@ -64,22 +98,33 @@ class BertBenchmarkBase(tf.test.Benchmark):
    else:
      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

-  def _report_benchmark(self, stats, wall_time_sec):
+  def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
    """Report benchmark results by writing to local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds
+      min_accuracy: Minimum classification accuracy constraint to verify
+        correctness of the model.
+      max_accuracy: Maximum classification accuracy constraint to verify
+        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
+    }, {
+        'name':
+            'examples_per_second',
+        'value':
+            self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
    }]
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
+          'min_value': min_accuracy,
+          'max_value': max_accuracy,
      })
    self.report_benchmark(
@@ -88,16 +133,42 @@ class BertBenchmarkBase(tf.test.Benchmark):
        metrics=metrics)

  @flagsaver.flagsaver
-  def _run_bert_classifier(self):
+  def _run_bert_classifier(self, callbacks=None):
+    """Starts BERT classification task."""
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))

-    strategy = tf.distribute.MirroredStrategy()
-    run_classifier.run_bert(strategy, input_meta_data)
+    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+    epochs = FLAGS.num_train_epochs
+    train_data_size = input_meta_data['train_data_size']
+    steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
+    warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
+    eval_steps = int(
+        math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
+    strategy = distribution_utils.get_distribution_strategy(
+        distribution_strategy='mirrored', num_gpus=self.num_gpus)
+    run_classifier.run_customized_training(
+        strategy,
+        bert_config,
+        input_meta_data,
+        FLAGS.model_dir,
+        epochs,
+        steps_per_epoch,
+        eval_steps,
+        warmup_steps,
+        FLAGS.learning_rate,
+        FLAGS.init_checkpoint,
+        custom_callbacks=callbacks)


-class BertBenchmarkPerformanceTest(BertBenchmarkBase):
-  """Short benchmark performance tests for BERT model."""
+class BertClassifyBenchmark(BertBenchmarkBase):
+  """Short benchmark performance tests for BERT model.
+
+  Tests BERT classification performance in different GPU configurations.
+  The naming convention of the test cases below follows the
+  `benchmark_(number of gpus)_gpu_(dataset type)` format.
+  """

  def __init__(self, output_dir=None, **kwargs):
    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
@@ -105,13 +176,16 @@ class BertBenchmarkPerformanceTest(BertBenchmarkBase):
    self.bert_config_file = MODEL_CONFIG_FILE_PATH
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH

-    super(BertBenchmarkPerformanceTest, self).__init__(output_dir=output_dir)
+    super(BertClassifyBenchmark, self).__init__(output_dir=output_dir)

-  def _run_and_report_benchmark(self, training_summary_path):
+  def _run_and_report_benchmark(self,
+                                training_summary_path,
+                                min_accuracy=0,
+                                max_accuracy=1):
    """Starts BERT performance benchmark test."""
    start_time_sec = time.time()
-    self._run_bert_classifier()
+    self._run_bert_classifier(callbacks=[self.timer_callback])
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
@@ -120,14 +194,64 @@ class BertBenchmarkPerformanceTest(BertBenchmarkBase):
    # Since we do not load from any pretrained checkpoints, we ignore all
    # accuracy metrics.
    summary.pop('eval_metrics', None)
-    super(BertBenchmarkPerformanceTest, self)._report_benchmark(
-        stats=summary, wall_time_sec=wall_time_sec)
+    super(BertClassifyBenchmark, self)._report_benchmark(
+        stats=summary,
+        wall_time_sec=wall_time_sec,
+        min_accuracy=min_accuracy,
+        max_accuracy=max_accuracy)
+
+  def benchmark_1_gpu_mrpc(self):
+    """Test BERT model performance with 1 GPU."""
+    self._setup()
+    self.num_gpus = 1
+    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    FLAGS.train_batch_size = 4
+    FLAGS.eval_batch_size = 4
+
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path)
+
+  def benchmark_2_gpu_mrpc(self):
+    """Test BERT model performance with 2 GPUs."""
+    self._setup()
+    self.num_gpus = 2
+    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_mrpc')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    FLAGS.train_batch_size = 8
+    FLAGS.eval_batch_size = 8
+
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path)
+
+  def benchmark_4_gpu_mrpc(self):
+    """Test BERT model performance with 4 GPUs."""
+    self._setup()
+    self.num_gpus = 4
+    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_mrpc')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    FLAGS.train_batch_size = 16
+
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path)

-  def benchmark_8_gpu(self):
+  def benchmark_8_gpu_mrpc(self):
    """Test BERT model performance with 8 GPUs."""
    self._setup()
-    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
@@ -137,8 +261,13 @@ class BertBenchmarkPerformanceTest(BertBenchmarkBase):
    self._run_and_report_benchmark(summary_path)


-class BertBenchmarkAccuracyTest(BertBenchmarkBase):
-  """Short benchmark test for BERT model that tests accuracy metrics."""
+class BertClassifyAccuracy(BertBenchmarkBase):
+  """Short accuracy test for BERT model.
+
+  Tests BERT classification task model accuracy. The naming
+  convention of the test cases below follows the
+  `benchmark_(number of gpus)_gpu_(dataset type)` format.
+  """

  def __init__(self, output_dir=None, **kwargs):
    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
@@ -147,26 +276,37 @@ class BertBenchmarkAccuracyTest(BertBenchmarkBase):
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
    self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH

-    super(BertBenchmarkAccuracyTest, self).__init__(output_dir=output_dir)
+    super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)

-  def _run_and_report_benchmark(self, training_summary_path):
+  def _run_and_report_benchmark(self,
+                                training_summary_path,
+                                min_accuracy=0.84,
+                                max_accuracy=0.88):
    """Starts BERT accuracy benchmark test."""
    start_time_sec = time.time()
-    self._run_bert_classifier()
+    self._run_bert_classifier(callbacks=[self.timer_callback])
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

-    super(BertBenchmarkAccuracyTest, self)._report_benchmark(
-        stats=summary, wall_time_sec=wall_time_sec)
+    super(BertClassifyAccuracy, self)._report_benchmark(
+        stats=summary,
+        wall_time_sec=wall_time_sec,
+        min_accuracy=min_accuracy,
+        max_accuracy=max_accuracy)

-  def benchmark_8_gpu(self):
-    """Run BERT model accuracy test with 8 GPUs."""
+  def benchmark_8_gpu_mrpc(self):
+    """Run BERT model accuracy test with 8 GPUs.
+
+    Due to the comparatively small cardinality of the MRPC dataset, the
+    training accuracy metric has high variance between runs. As such, we
+    set a wide range of allowed accuracy (84% to 88%).
+    """
    self._setup()
-    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
......
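To illustrate how the timing callback added above turns per-batch wall time into an examples-per-second metric, here is a minimal, self-contained sketch. It mirrors `BenchmarkTimerCallback` but drives it by hand with a fake workload; the batch size and sleep duration are arbitrary stand-ins, not values taken from the benchmark.

import time

import numpy as np


class TimerCallback(object):
  """Simplified stand-in for the BenchmarkTimerCallback shown above."""

  def __init__(self, num_batches_to_skip=2):
    # Skip the first few batches so warm-up cost does not skew the average.
    self.num_batches_to_skip = num_batches_to_skip
    self.timer_records = []
    self.start_time = None

  def on_batch_start(self, batch):
    if batch < self.num_batches_to_skip:
      return
    self.start_time = time.time()

  def on_batch_end(self, batch):
    if batch < self.num_batches_to_skip:
      return
    self.timer_records.append(time.time() - self.start_time)

  def get_examples_per_sec(self, batch_size):
    # batch_size examples are processed per recorded batch interval.
    return batch_size / np.mean(self.timer_records)


if __name__ == '__main__':
  callback = TimerCallback()
  batch_size = 4  # arbitrary; the benchmark sets this via FLAGS.train_batch_size
  for step in range(10):
    callback.on_batch_start(step)
    time.sleep(0.01)  # fake train step
    callback.on_batch_end(step)
  print('examples/sec: %.1f' % callback.get_examples_per_sec(batch_size))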
@@ -23,28 +23,6 @@ import os
from absl import logging
import tensorflow as tf

-try:
-  import h5py as _  # pylint: disable=g-import-not-at-top
-  HAS_H5PY = True
-except ImportError:
-  logging.warning('`h5py` is not installed. Please consider installing it '
-                  'to save weights for long-running training.')
-  HAS_H5PY = False
-
-
-def save_model(model, model_dir, weights_file):
-  """Saves the model weights."""
-  weights_file_path = os.path.join(model_dir, weights_file)
-  del model_dir, weights_file  # avoid accident usages.
-
-  if not HAS_H5PY:
-    logging.warning('`h5py` is not installed. Skip saving model weights.')
-    return
-  logging.info('Saving weights and optimizer states into %s', weights_file_path)
-  logging.info('This might take a while...')
-  model.save(weights_file_path, overwrite=True, include_optimizer=True)


def export_bert_model(model_export_path,
                      model=None,
......
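The deleted `save_model` helper existed mainly to guard against a missing `h5py` dependency when writing HDF5 files. For reference only, a minimal sketch (the model and paths are placeholders, not from this repository) of writing weights in the TensorFlow checkpoint format, which does not require h5py:

import os

import tensorflow as tf

# Placeholder model; any Keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

os.makedirs('/tmp/demo_weights', exist_ok=True)  # assumed scratch directory
# save_format='tf' writes a TensorFlow checkpoint (no h5py needed); a path
# ending in '.h5' would instead require h5py.
model.save_weights('/tmp/demo_weights/ckpt', save_format='tf')

# Restoring later:
model.load_weights('/tmp/demo_weights/ckpt')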
@@ -58,7 +58,8 @@ def run_customized_training_loop(
    eval_steps=None,
    metric_fn=None,
    init_checkpoint=None,
-    use_remote_tpu=False):
+    use_remote_tpu=False,
+    custom_callbacks=None):
  """Run BERT pretrain model training using low-level API.

  Arguments:
@@ -87,6 +88,9 @@ def run_customized_training_loop(
        `model_fn`.
      use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
        as an optimization.
+      custom_callbacks: A list of Keras Callback objects to run during
+        training. More specifically, their `on_batch_start()` and
+        `on_batch_end()` methods are invoked during training.

  Returns:
      Trained model.
@@ -134,7 +138,12 @@ def run_customized_training_loop(
    optimizer = model.optimizer

    if init_checkpoint:
-      sub_model.load_weights(init_checkpoint)
+      logging.info(
+          'Checkpoint file %s found and restoring from '
+          'initial checkpoint for core model.', init_checkpoint)
+      checkpoint = tf.train.Checkpoint(model=sub_model)
+      checkpoint.restore(init_checkpoint).assert_consumed()
+      logging.info('Loading from checkpoint file completed')

    metric = metric_fn() if metric_fn else None
    # If evaluation is required, make a copy of metric as it will be used by
@@ -193,15 +202,28 @@ def run_customized_training_loop(
                   metric_result)
      return metric_result

+    def _run_callbacks_on_batch_start(batch):
+      """Runs custom callbacks at the start of every step."""
+      if not custom_callbacks:
+        return
+      for callback in custom_callbacks:
+        callback.on_batch_start(batch)
+
+    def _run_callbacks_on_batch_end(batch):
+      """Runs custom callbacks at the end of every step."""
+      if not custom_callbacks:
+        return
+      for callback in custom_callbacks:
+        callback.on_batch_end(batch)
+
    # Training loop starts here.
-    checkpoint = tf.train.Checkpoint(model=model)
+    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info(
          'Checkpoint file %s found and restoring from '
          'checkpoint', latest_checkpoint_file)
-      checkpoint.restore(
-          latest_checkpoint_file).assert_existing_objects_matched()
+      checkpoint.restore(latest_checkpoint_file)
      logging.info('Loading from checkpoint file completed')

    current_step = optimizer.iterations.numpy()
@@ -211,8 +233,10 @@ def run_customized_training_loop(
    eval_metric_result = None
    train_loss = None
    while current_step < total_training_steps:
-      train_loss = train_step(train_iterator).numpy().astype(float)
      current_step += 1
+      _run_callbacks_on_batch_start(current_step)
+      train_loss = train_step(train_iterator).numpy().astype(float)
      if train_metric:
        train_metric_result = train_metric.result().numpy().astype(float)
@@ -223,6 +247,8 @@ def run_customized_training_loop(
        logging.info('Train Step: %d/%d / loss = %s', current_step,
                     total_training_steps, train_loss)
+      _run_callbacks_on_batch_end(current_step)
+
      # Saves model checkpoints and run validation steps at every epoch end.
      if current_step % steps_per_epoch == 0:
        # To avoid repeated model saving, we do not save after the last
......
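The checkpoint changes above do two things: the initial pretrained checkpoint is restored into the core `sub_model` with a strict `assert_consumed()`, while the resumable training checkpoint now tracks the optimizer as well, so the step counter survives a restart. A minimal sketch of that save/restore pattern, assuming a toy model and a scratch directory:

import os

import tensorflow as tf

# Toy model and optimizer standing in for the BERT objects above.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
optimizer = tf.keras.optimizers.Adam()

# Track both objects so optimizer state (including optimizer.iterations)
# is written alongside the model variables.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
os.makedirs('/tmp/ckpt_demo', exist_ok=True)  # assumed scratch directory
checkpoint.save('/tmp/ckpt_demo/ckpt')

latest = tf.train.latest_checkpoint('/tmp/ckpt_demo')
if latest:
  # A plain restore() tolerates objects that have no saved counterpart,
  # which is why the training loop above drops the strict assertion.
  checkpoint.restore(latest)
  print('resumed at step', optimizer.iterations.numpy())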
@@ -109,7 +109,8 @@ def run_customized_training(strategy,
                            warmup_steps,
                            initial_lr,
                            init_checkpoint,
-                            use_remote_tpu=False):
+                            use_remote_tpu=False,
+                            custom_callbacks=None):
  """Run BERT classifier training using low-level API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']
@@ -155,7 +156,8 @@ def run_customized_training(strategy,
      eval_steps=eval_steps,
      init_checkpoint=init_checkpoint,
      metric_fn=metric_fn,
-      use_remote_tpu=use_remote_tpu)
+      use_remote_tpu=use_remote_tpu,
+      custom_callbacks=custom_callbacks)


def export_classifier(model_export_path, input_meta_data):
......
@@ -116,7 +116,7 @@ def run_customized_training(strategy,
        initial_lr, steps_per_epoch * epochs, warmup_steps)
    return pretrain_model, core_model

-  model_training_utils.run_customized_training_loop(
+  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
@@ -175,7 +175,7 @@ def main(_):
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)
-  run_bert_pretrain(strategy)
+  return run_bert_pretrain(strategy)


if __name__ == '__main__':
......
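The new multi-GPU benchmark cases pick a strategy through `distribution_utils.get_distribution_strategy(distribution_strategy='mirrored', num_gpus=...)`, which is not reproduced here. As a rough illustration only, a similar effect can be approximated with stock `tf.distribute` APIs; the device names are assumptions, and running with more than zero GPUs requires a machine that actually has them:

import tensorflow as tf


def pick_strategy(num_gpus):
  """Rough approximation of choosing a distribution strategy by GPU count."""
  if num_gpus <= 0:
    return tf.distribute.OneDeviceStrategy('/cpu:0')
  if num_gpus == 1:
    return tf.distribute.OneDeviceStrategy('/gpu:0')
  # MirroredStrategy replicates the model across the listed GPUs.
  return tf.distribute.MirroredStrategy(
      devices=['/gpu:%d' % i for i in range(num_gpus)])


# Use 0 here so the example also runs on CPU-only hosts.
strategy = pick_strategy(0)
with strategy.scope():
  # Variables created inside the scope are placed per the strategy.
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])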