"vscode:/vscode.git/clone" did not exist on "1a3a1adea3975adbd2e27770bf174cff3c7a3df3"
Commit 986b0825 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add init_checkpoint for nhnet benchmark.

PiperOrigin-RevId: 313205490
parent 5110d3a2
...@@ -32,6 +32,7 @@ from official.utils.flags import core as flags_core ...@@ -32,6 +32,7 @@ from official.utils.flags import core as flags_core
MIN_LOSS = 0.40 MIN_LOSS = 0.40
MAX_LOSS = 0.55 MAX_LOSS = 0.55
NHNET_DATA = 'gs://tf-perfzero-data/nhnet/v1/processed/train.tfrecord*' NHNET_DATA = 'gs://tf-perfzero-data/nhnet/v1/processed/train.tfrecord*'
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -126,6 +127,7 @@ class NHNetAccuracyBenchmark(NHNetBenchmark): ...@@ -126,6 +127,7 @@ class NHNetAccuracyBenchmark(NHNetBenchmark):
FLAGS.train_steps = 50000 FLAGS.train_steps = 50000
FLAGS.checkpoint_interval = FLAGS.train_steps FLAGS.checkpoint_interval = FLAGS.train_steps
FLAGS.distribution_strategy = 'tpu' FLAGS.distribution_strategy = 'tpu'
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_accuracy_4x4_tpu_bf32_50k_steps') 'benchmark_accuracy_4x4_tpu_bf32_50k_steps')
self._run_and_report_benchmark() self._run_and_report_benchmark()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment