# coding=utf-8
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""

from __future__ import absolute_import, division, print_function

import collections
import json
import math
import os
import random
import shutil
import time
import pickle

import horovod.tensorflow as hvd
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.client import device_lib

import modeling
import optimization
import tokenization
from utils.create_squad_data import *
from utils.utils import LogEvalRunHook, LogTrainRunHook

from ft_tensorflow_quantization import get_calibrators, QuantDense, QuantDescriptor

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None,
                    "SQuAD json for training. E.g., train-v1.1.json")

flags.DEFINE_string(
    "predict_file", None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 384,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 8, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-6, "The initial learning rate for Adam.")

flags.DEFINE_bool("use_trt", False, "Whether to use TF-TRT")

flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

# The two literals below are concatenated by Python; the ". " separator keeps
# the help text readable ("...update. Global batch size = ...").
flags.DEFINE_integer("num_accumulation_steps", 1,
                     "Number of accumulation steps before gradient update. "
                     "Global batch size = num_accumulation_steps * train_batch_size")

flags.DEFINE_integer(
    "n_best_size", 20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.")

flags.DEFINE_integer(
    "max_answer_length", 30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.")


flags.DEFINE_bool(
    "verbose_logging", False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal SQuAD evaluation.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_float(
    "null_score_diff_threshold", 0.0,
    "If null_score - best_non_null is greater than the threshold predict null.")

flags.DEFINE_bool("use_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU.")
flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.")
flags.DEFINE_integer("num_eval_iterations", None,
                     "How many eval iterations to run - performs inference on subset")

# TRTIS Specific flags
flags.DEFINE_bool("export_trtis", False, "Whether to export saved model or run inference with TRTIS")
flags.DEFINE_string("trtis_model_name", "bert", "exports to appropriate directory for TRTIS")
flags.DEFINE_integer("trtis_model_version", 1, "exports to appropriate directory for TRTIS")
flags.DEFINE_string("trtis_server_url", "localhost:8001", "The address (host:port) of the TRTIS server.")
flags.DEFINE_bool("trtis_model_overwrite", False, "If True, will overwrite an existing directory with the specified 'model_name' and 'version_name'")
flags.DEFINE_integer("trtis_max_batch_size", 8, "Specifies the 'max_batch_size' in the TRTIS model config. See the TRTIS documentation for more info.")
flags.DEFINE_float("trtis_dyn_batching_delay", 0, "Determines the dynamic_batching queue delay in milliseconds(ms) for the TRTIS model config. Use '0' or '-1' to specify static batching. See the TRTIS documentation for more info.")
flags.DEFINE_integer("trtis_engine_count", 1, "Specifies the 'instance_group' count value in the TRTIS model config. See the TRTIS documentation for more info.")

# Quantization / calibration flags.
flags.DEFINE_bool("do_calib", False, "Whether to do calibration.")
flags.DEFINE_bool("if_quant", False, "Whether to quantize.")
flags.DEFINE_integer("calib_batch", 4, "Number of batches for calibration.")
flags.DEFINE_string("calib_method", "percentile", "calibration method [percentile, mse, max, entropy]")
flags.DEFINE_float("percentile", 99.99, "percentile for percentile calibrator")
flags.DEFINE_string("calibrator_file", "calibrators.pkl", "pickle file for calibrators")
flags.DEFINE_string("quant_mode", 'ft2', "predefined quantization mode, choices: ['ft1', 'ft2', 'ft3', 'trt']")

# Knowledge-distillation flags.
flags.DEFINE_bool("distillation", False, "Whether or not to use the teacher-student model for finetuning (Knowledge distillation)")
flags.DEFINE_string("teacher", None, "teacher checkpoint file for distillation")
flags.DEFINE_float("distillation_loss_scale", 10000., "scale applied to distillation component of loss")

# Predefined quantization settings, selected by --quant_mode. Each entry is
#   (KERNEL_AXIS, ACTIVATION_NARROW_RANGE, DISABLE_LIST, FUSE_QKV):
#   - KERNEL_AXIS: `axis` of the kernel QuantDescriptor (and kernel calibrators)
#   - ACTIVATION_NARROW_RANGE: `narrow_range` of the input QuantDescriptor
#   - DISABLE_LIST: `disable_key_words` for both descriptors
#   - FUSE_QKV: whether CalibrationHook fuses the Q/K/V ranges of each layer
_QUANT_MODE_SETTINGS = {
  'ft1': (1, False, ['aftergemm', 'softmax_input', 'residual_input', 'local_input', 'final_input'], False),
  'ft2': (None, False, ['local_input', 'softmax_input', 'final_input'], True),
  'ft3': (None, False, ['local_input', 'final_input'], True),
  # for demobert
  'trt': (None, False, ['aftergemm', 'softmax_input'], True),
}

if FLAGS.quant_mode not in _QUANT_MODE_SETTINGS:
  raise ValueError("wrong argument value for 'quant_mode'")
(KERNEL_AXIS, ACTIVATION_NARROW_RANGE, DISABLE_LIST,
 FUSE_QKV) = _QUANT_MODE_SETTINGS[FLAGS.quant_mode]

# Register the descriptors as QuantDense's defaults for its input and kernel quantizers.
input_desc = QuantDescriptor('input', narrow_range=ACTIVATION_NARROW_RANGE, disable_key_words=DISABLE_LIST)
kernel_desc = QuantDescriptor('kernel', axis=KERNEL_AXIS, disable_key_words=DISABLE_LIST)
QuantDense.set_default_quant_desc_input(input_desc)
QuantDense.set_default_quant_desc_kernel(kernel_desc)

class CalibrationHook(tf.train.SessionRunHook):
  """Session hook that calibrates quantization ranges and saves the result.

  `begin` gathers the 'input' (histogram collector) and 'kernel' (max
  collector) calibrators and builds the assign ops that later load the
  computed ranges into the graph. Every `session.run` also fetches each
  calibrator's calibration-step op. `end` computes the ranges (input method
  from `FLAGS.calib_method`; kernels always 'max'), optionally fuses the
  Q/K/V ranges of every attention layer (`FUSE_QKV`), loads the ranges into
  the session, pickles the calibrators to
  `FLAGS.output_dir/FLAGS.calibrator_file` and saves a
  `model.ckpt-calibrated` checkpoint in `FLAGS.output_dir`.

  Args:
    layer_num: number of encoder layers whose Q/K/V ranges are fused.
  """

  def __init__(self, layer_num):
    self.layer_num = layer_num
    # {'input': [calibrators], 'kernel': [calibrators]}, filled in `begin`.
    self.calibrator_lists = {}

  def _all_calibrators(self):
    """Yields every calibrator, in `calibrator_lists` insertion order."""
    for calib_list in self.calibrator_lists.values():
      yield from calib_list

  def begin(self):
    self.saver = tf.train.Saver()
    tf.compat.v1.logging.info("initializing calibrators")
    graph = tf.compat.v1.get_default_graph()
    self.calibrator_lists['input'] = get_calibrators('input', collector_type='histogram')
    self.calibrator_lists['kernel'] = get_calibrators('kernel', collector_type='max', axis=KERNEL_AXIS)
    for k in ['input', 'kernel']:
      tf.compat.v1.logging.info("There are {} calibrators in collection '{}'".format(len(self.calibrator_lists[k]), k))

    # Ops fetched on every step (see `before_run`).
    self.calib_step = [calibrator.calib_step_op(graph) for calibrator in self._all_calibrators()]

    # All maps below are keyed by the calibrator's public `tensor_name_prefix`.
    self.placeholders = {}
    self.load_min_op = {}
    self.load_max_op = {}
    self.calibrator_reverse_map = {}

    for calibrator in self._all_calibrators():
      prefix = calibrator.tensor_name_prefix
      if prefix in self.placeholders:
        raise ValueError("repeated name prefix")
      # One placeholder per calibrator feeds both its min and its max assign op;
      # `end` runs the two op groups separately with different feeds.
      placeholder = tf.placeholder(tf.float32)
      self.placeholders[prefix] = placeholder
      self.load_min_op[prefix] = tf.compat.v1.assign(
          graph.get_tensor_by_name(calibrator.quant_min_name), placeholder)
      self.load_max_op[prefix] = tf.compat.v1.assign(
          graph.get_tensor_by_name(calibrator.quant_max_name), placeholder)
      self.calibrator_reverse_map[prefix] = calibrator

  def before_run(self, run_context):
    tf.compat.v1.logging.info("registering calibration step")
    return tf.estimator.SessionRunArgs(
        fetches=self.calib_step)

  def end(self, session):
    tf.compat.v1.logging.info("computing calibration ranges")
    method = FLAGS.calib_method
    if method == 'percentile':
      tf.compat.v1.logging.info("percentile calibration with value {}.".format(FLAGS.percentile))
      range_kwargs = {'percentile': FLAGS.percentile}
    elif method in ('max', 'mse', 'entropy'):
      tf.compat.v1.logging.info("{} calibration.".format(method))
      range_kwargs = {}
    else:
      raise ValueError("Unsupported calibration method.")
    for calibrator in self.calibrator_lists['input']:
      calibrator.compute_range(method, **range_kwargs)
    # Kernel (weight) ranges always use max calibration.
    for calibrator in self.calibrator_lists['kernel']:
      calibrator.compute_range('max')

    if FUSE_QKV:
      tf.compat.v1.logging.info("fusing QKV")
      for i in range(self.layer_num):
        prefix = f"bert/encoder/layer_{i}/attention/self"
        tf.compat.v1.logging.info(f"FUSE_QKV: {prefix:50}")
        suffix_groups = (
            [f"/{name}/kernel_quantizer" for name in ['query', 'key', 'value']],
            [f"/{name}/aftergemm_quantizer" for name in ['query', 'key', 'value']],
            [f"/matmul_{name}_input_quantizer" for name in ['q', 'k', 'v']],
        )
        for suffixes in suffix_groups:
          self._fuse_group([prefix + suffix for suffix in suffixes])

    tf.compat.v1.logging.info("loading calibration ranges")
    calibrators = list(self._all_calibrators())
    session.run(self.load_min_op, {self.placeholders[c.tensor_name_prefix]: c.calib_min for c in calibrators})
    session.run(self.load_max_op, {self.placeholders[c.tensor_name_prefix]: c.calib_max for c in calibrators})
    tf.compat.v1.logging.info("saving calibrated model")
    with open(os.path.join(FLAGS.output_dir, FLAGS.calibrator_file), 'wb') as f:
      pickle.dump(self.calibrator_lists, f)
    self.saver.save(session, os.path.join(FLAGS.output_dir, 'model.ckpt-calibrated'))

  def _fuse_group(self, names):
    """Fuses the three calibrators registered under the full prefixes `names`.

    If none of the three is registered, the group is skipped with a warning.
    Presumably this happens when those quantizers are disabled, e.g.
    'aftergemm' under quant_mode 'trt' (TODO confirm against get_calibrators).
    A partially registered group is an error and raises KeyError.
    """
    missing = [name for name in names if name not in self.calibrator_reverse_map]
    if len(missing) == len(names):
      tf.compat.v1.logging.warning("no calibrators found for {}, skipping fuse".format(names))
      return
    if missing:
      raise KeyError(missing)
    self.fuse3(*[self.calibrator_reverse_map[name] for name in names])

  def fuse3(self, qq, qk, qv):
    """Gives three calibrators one shared range (min of mins, max of maxes).

    Raises:
      RuntimeError: if any calibrator lacks `calib_min` or `calib_max`.
    """
    quantizers = (qq, qk, qv)
    if not all(hasattr(q, attr) for attr in ('calib_min', 'calib_max') for q in quantizers):
      raise RuntimeError('missing min/max buffer, unable to fuse')
    qmin, kmin, vmin = (q.calib_min for q in quantizers)
    qmax, kmax, vmax = (q.calib_max for q in quantizers)
    amax = max(qmax, kmax, vmax)
    amin = min(qmin, kmin, vmin)
    for q in quantizers:
      # Written through the private attributes; presumably calib_min/calib_max
      # are read-only properties -- TODO confirm.
      q._calib_max = amax
      q._calib_min = amin
    tf.compat.v1.logging.info(
      f'          q={qmin:7.4f}/{qmax:7.4f} k={kmin:7.4f}/{kmax:7.4f} v={vmin:7.4f}/{vmax:7.4f} -> {amin:7.4f}/{amax:7.4f}')


def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings, if_quant=False):
  """Creates the SQuAD span-prediction model (start/end logits) on top of BERT.

  A single projection (`cls/squad/output_weights` [2, hidden_size] plus
  `cls/squad/output_bias` [2]) maps every position of the final hidden layer
  to two logits, which are then split into start and end logits.

  Args:
    bert_config: configuration passed to `modeling.BertModel`.
    is_training: passed to `modeling.BertModel`.
    input_ids, input_mask, segment_ids: model inputs; `segment_ids` is passed
      as `token_type_ids`.
    use_one_hot_embeddings: passed to `modeling.BertModel`.
    if_quant: passed to `modeling.BertModel`. Defaults to False so callers that
      do not care about quantization (e.g. the TF-TRT export) can omit it.

  Returns:
    (start_logits, end_logits), each of shape [batch_size, seq_length].
  """
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      compute_type=tf.float32,
      if_quant=if_quant)

  final_hidden = model.get_sequence_output()

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]

  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  # [batch, seq, 2] -> [2, batch, seq] so that unstacking axis 0 yields the
  # start and end logits.
  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  logits = tf.transpose(logits, [2, 0, 1])

  # The op name 'unstack' is used as the output node by the TF-TRT export.
  unstacked_logits = tf.unstack(logits, axis=0, name='unstack')

  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

  return (start_logits, end_logits)

def get_frozen_tftrt_model(bert_config, shape, use_one_hot_embeddings, init_checkpoint,
                           if_quant=False):
  """Builds the SQuAD model, freezes it from a checkpoint and converts it with TF-TRT.

  Args:
    bert_config: configuration passed to `create_model`.
    shape: shape of the int32 `input_ids`/`input_mask`/`segment_ids` placeholders.
    use_one_hot_embeddings: passed to `create_model`.
    init_checkpoint: checkpoint the trainable variables are restored from.
    if_quant: passed to `create_model` (which requires it); defaults to False.

  Returns:
    The TF-TRT converted `GraphDef`. It is also written to
    `frozen_modelTRT.pb` in the current working directory.
  """
  tf_config = tf.compat.v1.ConfigProto()
  tf_config.gpu_options.allow_growth = True
  # Name of the `tf.unstack` op created in `create_model`; its two outputs are
  # the start and end logits.
  output_node_names = ['unstack']

  with tf.Session(config=tf_config) as tf_sess:
    input_ids = tf.placeholder(tf.int32, shape, 'input_ids')
    input_mask = tf.placeholder(tf.int32, shape, 'input_mask')
    segment_ids = tf.placeholder(tf.int32, shape, 'segment_ids')

    # Only the graph is needed here; the outputs are located by node name.
    create_model(bert_config=bert_config,
                 is_training=False,
                 input_ids=input_ids,
                 input_mask=input_mask,
                 segment_ids=segment_ids,
                 use_one_hot_embeddings=use_one_hot_embeddings,
                 if_quant=if_quant)

    tvars = tf.trainable_variables()
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    # init_from_checkpoint overrides the variables' initializers, so running
    # the global initializer afterwards loads the checkpoint values.
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    tf_sess.run(tf.global_variables_initializer())
    print("LOADED!")
    tf.compat.v1.logging.info("**** Trainable Variables ****")
    for var in tvars:
      # Only variables that were NOT restored from the checkpoint are logged.
      if var.name not in initialized_variable_names:
        tf.compat.v1.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                                  ", *NOTTTTTTTTTTTTTTTTTTTTT")

    frozen_graph = tf.graph_util.convert_variables_to_constants(
        tf_sess, tf_sess.graph.as_graph_def(), output_node_names)

    num_nodes = len(frozen_graph.node)
    print('Converting graph using TensorFlow-TensorRT...')
    from tensorflow.python.compiler.tensorrt import trt_convert as trt
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=output_node_names,
        max_workspace_size_bytes=(4096 << 20) - 1000,
        precision_mode="FP16" if FLAGS.use_fp16 else "FP32",
        minimum_segment_size=4,
        is_dynamic_op=True,
        maximum_cached_engines=1000
    )
    frozen_graph = converter.convert()

    print('Total node count before and after TF-TRT conversion:',
          num_nodes, '->', len(frozen_graph.node))
    print('TRT node count:',
          len([1 for n in frozen_graph.node if str(n.op) == 'TRTEngineOp']))

    with tf.io.gfile.GFile("frozen_modelTRT.pb", "wb") as f:
      f.write(frozen_graph.SerializeToString())

  return frozen_graph


def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
                     hvd=None, use_fp16=False, use_one_hot_embeddings=False):
  """Returns `model_fn` closure for Estimator.

  The returned `model_fn` supports two modes:
    - TRAIN: fine-tuning, optionally with teacher-student distillation
      (`FLAGS.distillation`). When `FLAGS.do_calib` is set, TRAIN only
      increments the global step, so that a calibration hook can observe
      activations.
    - PREDICT: optionally through a TF-TRT converted graph (`FLAGS.use_trt`).

  Args:
    bert_config: configuration passed to `create_model`.
    init_checkpoint: checkpoint the trainable variables and the
      "quantization_variables" collection are restored from.
    learning_rate, num_train_steps, num_warmup_steps: passed to
      `optimization.create_optimizer`.
    hvd: optional Horovod module. When given, checkpoints are restored on
      rank 0 only.
    use_fp16: passed to `optimization.create_optimizer`.
    use_one_hot_embeddings: passed to `create_model`.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for Estimator."""
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("*** Features ***")
      for name in sorted(features.keys()):
        tf.compat.v1.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if_quant = bool(FLAGS.if_quant)
    if FLAGS.do_calib and mode == tf.estimator.ModeKeys.TRAIN:
      # Calibration is driven through TRAIN mode, but the model is built with
      # is_training=False and if_quant=False.
      is_training = False
      if_quant = False
    # The teacher is only built when actually training. During calibration
    # there are no teacher logits, so distillation must be skipped too.
    use_distillation = is_training and FLAGS.distillation

    # Only PREDICT may use TF-TRT: this path returns a spec without a loss or
    # train_op, which a (calibration) TRAIN run cannot use.
    if mode == tf.estimator.ModeKeys.PREDICT and FLAGS.use_trt:
      trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, use_one_hot_embeddings, init_checkpoint)
      (start_logits, end_logits) = tf.import_graph_def(trt_graph,
              input_map={'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids},
              return_elements=['unstack:0', 'unstack:1'],
              name='')
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if use_distillation:
      if not FLAGS.teacher:
        raise ValueError("no teacher checkpoint is supplied.")
      tf.compat.v1.logging.info("initializing teacher model.")
      with tf.variable_scope("teacher"):
        (start_logits_teacher, end_logits_teacher) = create_model(
            bert_config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
            if_quant=False)
      if hvd is None or hvd.rank() == 0:
        (assignment_map_t, _) = modeling.get_assignment_map_from_checkpoint(
            tf.trainable_variables(), FLAGS.teacher, "teacher/")
        tf.train.init_from_checkpoint(FLAGS.teacher, assignment_map_t)
      # Empty the TRAINABLE_VARIABLES collection so that the teacher's variables
      # are not returned by the `tf.trainable_variables()` calls made after the
      # student is built.
      del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]

    tf.compat.v1.logging.info("!!!!!!!!!!if_quant is {}!!!!!!!!!!".format(if_quant))
    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
        if_quant=if_quant)

    tvars = tf.trainable_variables()
    qvars = tf.get_collection("quantization_variables")

    initialized_variable_names = {}
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
      tf.compat.v1.logging.info("restore from checkpoint: " + init_checkpoint)
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      (assignment_map_q, initialized_variable_names_q
      ) = modeling.get_assignment_map_from_checkpoint(qvars, init_checkpoint, allow_shape_mismatch=True)
      assignment_map.update(assignment_map_q)
      initialized_variable_names.update(initialized_variable_names_q)

      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("**** Trainable Variables ****")
      for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
          init_string = ", *INIT_FROM_CKPT*"
        tf.compat.v1.logging.info(" %d name = %s, shape = %s%s", 0 if hvd is None else hvd.rank(), var.name, var.shape,
                                  init_string)

    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

      def compute_loss(logits, positions):
        """Mean cross-entropy of `logits` against the gold token `positions`."""
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      def fixprob(att, T=3):
        """Temperature-T softmax, offset by 1e-9 to keep probabilities non-zero."""
        att = tf.nn.softmax(att/T, axis=-1) + 1e-9
        return att

      def kl_loss(x, y):
        """Mean KL(softened x || softened y) between two logit tensors."""
        x = fixprob(x)
        y = fixprob(y)
        X = tf.distributions.Categorical(probs=x)
        Y = tf.distributions.Categorical(probs=y)
        return tf.math.reduce_mean(tf.distributions.kl_divergence(X, Y))

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0

      if use_distillation:
        dloss = kl_loss(start_logits, start_logits_teacher) + kl_loss(end_logits, end_logits_teacher)
        total_loss = total_loss + dloss * FLAGS.distillation_loss_scale

      if FLAGS.do_calib:
        # Calibration does not update weights; it only advances the global step.
        global_step = tf.compat.v1.train.get_or_create_global_step()
        new_global_step = global_step + 1
        new_global_step = tf.identity(new_global_step, name='step_update')
        train_op = tf.group(tf.no_op(), [global_step.assign(new_global_step)])
      else:
        train_op = optimization.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, False, use_fp16, FLAGS.num_accumulation_steps)

      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec

  return model_fn


def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remainder, hvd=None):
  """Creates an `input_fn` closure to be passed to Estimator.

  The returned `input_fn` reads serialized `tf.Example` records from
  `input_file` and yields batches of feature dicts. The keys are
  `unique_ids`, `input_ids`, `input_mask`, `segment_ids` and, when
  `is_training`, also `start_positions` and `end_positions`. Every int64
  feature is cast to int32.

  Args:
    input_file: TFRecord file(s) to read.
    batch_size: number of examples per batch.
    seq_length: fixed length of `input_ids`, `input_mask` and `segment_ids`.
    is_training: if True, read with 4 parallel readers, shard across Horovod
      workers (when `hvd` is given), ignore records that raise errors,
      shuffle (buffer 100) and repeat indefinitely.
    drop_remainder: whether to drop the final partial batch.
    hvd: optional Horovod module, used only for sharding while training.
  """
  name_to_features = {"unique_ids": tf.io.FixedLenFeature([], tf.int64)}
  for sequence_key in ("input_ids", "input_mask", "segment_ids"):
    name_to_features[sequence_key] = tf.io.FixedLenFeature([seq_length], tf.int64)
  if is_training:
    for position_key in ("start_positions", "end_positions"):
      name_to_features[position_key] = tf.io.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Parses one serialized tf.Example, casting its int64 tensors to int32."""
    parsed = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    return {name: (tf.to_int32(tensor) if tensor.dtype == tf.int64 else tensor)
            for name, tensor in parsed.items()}

  def input_fn():
    """Builds the batched dataset."""
    if not is_training:
      # For eval, we want no shuffling and parallel reading doesn't matter.
      dataset = tf.data.TFRecordDataset(input_file)
    else:
      # For training, we want a lot of parallel reading and shuffling.
      dataset = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)
      if hvd is not None:
        dataset = dataset.shard(hvd.size(), hvd.rank())
      dataset = dataset.apply(tf.data.experimental.ignore_errors())
      dataset = dataset.shuffle(buffer_size=100).repeat()

    decode = lambda record: _decode_record(record, name_to_features)
    return dataset.apply(
        tf.contrib.data.map_and_batch(
            decode,
            batch_size=batch_size,
            drop_remainder=drop_remainder))

  return input_fn



# Raw model output for one feature: its unique id plus start and end logits.
RawResult = collections.namedtuple(
    "RawResult", "unique_id start_logits end_logits")


def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file):
  """Write final predictions to the json file and log-odds of null if needed.

  For every example, candidate spans are gathered from all of its features
  and filtered for validity. They are ranked by `start_logit + end_logit`,
  mapped back to original text and de-duplicated, and the top `n_best_size`
  are kept.

  Args:
    all_examples: examples; `qas_id` and `doc_tokens` are read.
    all_features: features; `example_index`, `unique_id`, `tokens`,
      `token_to_orig_map` and `token_is_max_context` are read.
    all_results: per-feature `RawResult`s, matched to features by `unique_id`.
    n_best_size: number of best start/end indexes considered per feature, and
      maximum number of n-best entries kept per example.
    max_answer_length: maximum answer span length, in tokens.
    do_lower_case: passed to `get_final_text`.
    output_prediction_file: path of the {qas_id: answer text} json.
    output_nbest_file: path of the {qas_id: [n-best entries]} json.
    output_null_log_odds_file: path of the {qas_id: null-score difference}
      json. Written only when `FLAGS.version_2_with_negative` is set.
  """
  tf.compat.v1.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.compat.v1.logging.info("Writing nbest to: %s" % (output_nbest_file))

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {result.unique_id: result for result in all_results}

  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
  _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "NbestPrediction", ["text", "start_logit", "end_logit"])

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
      # if we could have irrelevant answers, get the min score of irrelevant
      if FLAGS.version_2_with_negative:
        feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

    if FLAGS.version_2_with_negative:
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
              start_index=0,
              end_index=0,
              start_logit=null_start_logit,
              end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    seen_predictions = set()
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
        tok_text = " ".join(tok_tokens)

        # De-tokenize WordPieces that have been split off.
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")

        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, do_lower_case)
        if final_text in seen_predictions:
          continue
      else:
        final_text = ""
      seen_predictions.add(final_text)

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    if FLAGS.version_2_with_negative:
      # if we didn't include the empty option in the n-best, include it
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="", start_logit=null_start_logit,
                end_logit=null_end_logit))
      # In rare edge cases no candidate has non-empty text (e.g. the null
      # answer is the only candidate). Add a nonce non-null entry so that
      # `best_non_null_entry` below is never None.
      if not any(entry.text for entry in nbest):
        nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

    total_scores = [entry.start_logit + entry.end_logit for entry in nbest]
    best_non_null_entry = next((entry for entry in nbest if entry.text), None)

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)

    if not FLAGS.version_2_with_negative:
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      # predict "" iff the null score - the score of best non-null > threshold
      score_diff = score_null - best_non_null_entry.start_logit - (
          best_non_null_entry.end_logit)
      scores_diff_json[example.qas_id] = score_diff
      if score_diff > FLAGS.null_score_diff_threshold:
        all_predictions[example.qas_id] = ""
      else:
        all_predictions[example.qas_id] = best_non_null_entry.text

    all_nbest_json[example.qas_id] = nbest_json

  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  if FLAGS.version_2_with_negative:
    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")


def get_final_text(pred_text, orig_text, do_lower_case):
  """Project the tokenized prediction back to the original text.

  Args:
    pred_text: the predicted answer text, already normalized by the
      tokenizer (see the explanation below).
    orig_text: the original (whitespace-tokenized) text span that the
      prediction covers.
    do_lower_case: forwarded to `tokenization.BasicTokenizer` so `orig_text`
      is re-tokenized with the same casing rule (assumed to match how
      `pred_text` was produced -- confirm against the caller).

  Returns:
    The substring of `orig_text` aligned with `pred_text`, or `orig_text`
    unchanged whenever any step of the alignment heuristic fails.
  """

  # When we created the data, we kept track of the alignment between original
  # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
  # now `orig_text` contains the span of our original text corresponding to the
  # span that we predicted.
  #
  # However, `orig_text` may contain extra characters that we don't want in
  # our prediction.
  #
  # For example, let's say:
  #   pred_text = steve smith
  #   orig_text = Steve Smith's
  #
  # We don't want to return `orig_text` because it contains the extra "'s".
  #
  # We don't want to return `pred_text` because it's already been normalized
  # (the SQuAD eval script also does punctuation stripping/lower casing but
  # our tokenizer does additional normalization like stripping accent
  # characters).
  #
  # What we really want to return is "Steve Smith".
  #
  # Therefore, we have to apply a semi-complicated alignment heuristic between
  # `pred_text` and `orig_text` to get a character-to-character alignment. This
  # can fail in certain cases in which case we just return `orig_text`.

  def _strip_spaces(text):
    """Remove ' ' characters from `text`.

    Returns:
      (ns_text, ns_to_s_map): the text without spaces, and a map from each
      index in `ns_text` to the corresponding index in `text`.
    """
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for (i, c) in enumerate(text):
      if c == " ":
        continue
      ns_to_s_map[len(ns_chars)] = i
      ns_chars.append(c)
    ns_text = "".join(ns_chars)
    return (ns_text, ns_to_s_map)

  # We first tokenize `orig_text`, strip whitespace from the result
  # and `pred_text`, and check if they are the same length. If they are
  # NOT the same length, the heuristic has failed. If they are the same
  # length, we assume the characters are one-to-one aligned.
  tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

  tok_text = " ".join(tokenizer.tokenize(orig_text))

  start_position = tok_text.find(pred_text)
  if start_position == -1:
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info(
          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
    return orig_text
  # Inclusive index of the last character of `pred_text` inside `tok_text`.
  end_position = start_position + len(pred_text) - 1

  (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

  if len(orig_ns_text) != len(tok_ns_text):
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                      orig_ns_text, tok_ns_text)
    return orig_text

  # We then project the characters in `pred_text` back to `orig_text` using
  # the character-to-character alignment.
  # Invert the tok_text map: index in `tok_text` -> index in the space-free
  # tok_text.
  tok_s_to_ns_map = {}
  for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
    tok_s_to_ns_map[tok_index] = i

  orig_start_position = None
  if start_position in tok_s_to_ns_map:
    ns_start_position = tok_s_to_ns_map[start_position]
    if ns_start_position in orig_ns_to_s_map:
      orig_start_position = orig_ns_to_s_map[ns_start_position]

  if orig_start_position is None:
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("Couldn't map start position")
    return orig_text

  orig_end_position = None
  if end_position in tok_s_to_ns_map:
    ns_end_position = tok_s_to_ns_map[end_position]
    if ns_end_position in orig_ns_to_s_map:
      orig_end_position = orig_ns_to_s_map[ns_end_position]

  if orig_end_position is None:
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("Couldn't map end position")
    return orig_text

  # `orig_end_position` is inclusive, hence the + 1 in the slice.
  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
  return output_text


def _get_best_indexes(logits, n_best_size):
  """Get the n-best logits from a list."""
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes


def _compute_softmax(scores):
  """Compute softmax probability over raw logits."""
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs



def validate_flags_or_throw(bert_config):
  """Validate the input FLAGS or throw an exception.

  Args:
    bert_config: the model config; its `max_position_embeddings` is the upper
      bound for `--max_seq_length`.

  Raises:
    ValueError: if no stage (`do_train`, `do_calib`, `do_predict`,
      `export_trtis`) is enabled, a required input file for an enabled stage
      is missing, or the sequence/query length flags are inconsistent.
  """
  # The `do_lower_case` / `init_checkpoint` consistency check is delegated to
  # the tokenization module.
  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  # The message lists exactly the flags that this condition checks.
  if not (FLAGS.do_train or FLAGS.do_calib or FLAGS.do_predict or
          FLAGS.export_trtis):
    raise ValueError(
        "At least one of `do_train`, `do_calib`, `do_predict` or "
        "`export_trtis` must be True.")

  # Calibration consumes the training set, so it needs `train_file` too.
  if FLAGS.do_train or FLAGS.do_calib:
    if not FLAGS.train_file:
      raise ValueError(
          "If `do_train` or `do_calib` is True, then `train_file` must be specified.")
  if FLAGS.do_predict:
    if not FLAGS.predict_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_file` must be specified.")

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  # The + 3 presumably reserves room for the special tokens around the query
  # and document ([CLS] and two [SEP]).
  if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
    raise ValueError(
        "The max_seq_length (%d) must be greater than max_query_length "
        "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))


def export_model(estimator, export_dir, init_checkpoint):
    """Exports a checkpoint in SavedModel format in a directory structure compatible with TRTIS.

    The SavedModel ends up in
    `<export_dir>/trtis_models/<trtis_model_name>/<trtis_model_version>/model.savedmodel`
    and a `config.pbtxt` describing the model's inputs and outputs is written to
    `<export_dir>/trtis_models/<trtis_model_name>/`.

    Args:
      estimator: the `tf.estimator.Estimator` whose graph is exported.
      export_dir: base directory of the export.
      init_checkpoint: checkpoint whose weights are exported.

    An existing model directory or config file is replaced only when
    `--trtis_model_overwrite` is set. Otherwise an error is printed and the
    function returns without raising.
    """

    def serving_input_fn():
        # Raw placeholders fed directly by the TRTIS client. The batch
        # dimension is dynamic; sequence inputs are fixed to max_seq_length.
        label_ids = tf.placeholder(tf.int32, [None,], name='unique_ids')
        input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
        input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
        segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'unique_ids': label_ids,
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids,
        })()
        return input_fn

    saved_dir = estimator.export_savedmodel(
        export_dir,
        serving_input_fn,
        assets_extra=None,
        as_text=False,
        checkpoint_path=init_checkpoint,
        strip_default_attrs=False)
    # TF 1.x estimators return the export path as `bytes`; os.rename refuses to
    # mix bytes and str paths, so normalize it to str.
    if isinstance(saved_dir, bytes):
        saved_dir = saved_dir.decode("utf-8")

    model_name = FLAGS.trtis_model_name

    model_folder = os.path.join(export_dir, "trtis_models", model_name)
    version_folder = os.path.join(model_folder, str(FLAGS.trtis_model_version))
    final_model_folder = os.path.join(version_folder, "model.savedmodel")

    os.makedirs(version_folder, exist_ok=True)

    if not os.path.exists(final_model_folder):
        os.rename(saved_dir, final_model_folder)
        print("Model saved to dir", final_model_folder)
    elif FLAGS.trtis_model_overwrite:
        shutil.rmtree(final_model_folder)
        os.rename(saved_dir, final_model_folder)
        print("WARNING: Existing model was overwritten. Model dir: {}".format(final_model_folder))
    else:
        print("ERROR: Could not save TRTIS model. Folder already exists. Use '--trtis_model_overwrite=True' if you would like to overwrite an existing model. Model dir: {}".format(final_model_folder))
        return

    # Now build the config for TRTIS. Check to make sure we can overwrite it, if it exists
    config_filename = os.path.join(model_folder, "config.pbtxt")

    if os.path.exists(config_filename) and not FLAGS.trtis_model_overwrite:
        print("ERROR: Could not save TRTIS model config. Config file already exists. Use '--trtis_model_overwrite=True' if you would like to overwrite an existing model config. Model config: {}".format(config_filename))
        return

    # Doubled braces survive str.format_map as literal braces in the pbtxt.
    config_template = r"""
name: "{model_name}"
platform: "tensorflow_savedmodel"
max_batch_size: {max_batch_size}
input [
    {{
        name: "unique_ids"
        data_type: TYPE_INT32
        dims: [ 1 ]
        reshape: {{ shape: [ ] }}
    }},
    {{
        name: "segment_ids"
        data_type: TYPE_INT32
        dims: {seq_length}
    }},
    {{
        name: "input_ids"
        data_type: TYPE_INT32
        dims: {seq_length}
    }},
    {{
        name: "input_mask"
        data_type: TYPE_INT32
        dims: {seq_length}
    }}
    ]
    output [
    {{
        name: "end_logits"
        data_type: TYPE_FP32
        dims: {seq_length}
    }},
    {{
        name: "start_logits"
        data_type: TYPE_FP32
        dims: {seq_length}
    }}
]
{dynamic_batching}
instance_group [
    {{
        count: {engine_count}
        kind: KIND_GPU
        gpus: [{gpu_list}]
    }}
]"""

    batching_str = ""
    max_batch_size = FLAGS.trtis_max_batch_size

    if FLAGS.trtis_dyn_batching_delay > 0:

        # Use only full and half full batches. Deduplicate and drop 0 so that
        # e.g. max_batch_size=1 does not yield an invalid preferred size of 0.
        pref_batch_size = sorted(
            {size for size in (max_batch_size // 2, max_batch_size) if size > 0})

        # The delay flag is multiplied by 1000 to get microseconds, i.e. it is
        # presumably given in milliseconds.
        batching_str = r"""
dynamic_batching {{
    preferred_batch_size: [{0}]
    max_queue_delay_microseconds: {1}
}}""".format(", ".join([str(x) for x in pref_batch_size]), int(FLAGS.trtis_dyn_batching_delay * 1000.0))

    config_values = {
        "model_name": model_name,
        "max_batch_size": max_batch_size,
        "seq_length": FLAGS.max_seq_length,
        "dynamic_batching": batching_str,
        # Device names end in ":<index>"; keep only the index of every local GPU.
        "gpu_list": ", ".join([x.name.split(":")[-1] for x in device_lib.list_local_devices() if x.device_type == "GPU"]),
        "engine_count": FLAGS.trtis_engine_count
    }

    with open(config_filename, "w") as config_file:
        config_file.write(config_template.format_map(config_values))

def _latency_at_fraction(sorted_times, fraction):
  """Return the worst latency among the fastest `fraction` of samples.

  For an ascending `sorted_times` this equals
  `max(sorted_times[:int(len(sorted_times) * fraction)])`. The difference is
  that at least one sample is always kept, so a short run does not raise on
  `max([])`. Returns NaN when there are no samples at all.
  """
  if not sorted_times:
    return float("nan")
  count = max(1, int(len(sorted_times) * fraction))
  return sorted_times[count - 1]


def _rate(count, seconds):
  """Return `count / seconds`, or NaN when no time was measured."""
  if seconds <= 0:
    return float("nan")
  return count * 1.0 / seconds


def main(_):
  """Run the stages selected by flags: calibration/training, TRTIS export, prediction.

  Stages are enabled by `--do_train`, `--do_calib`, `--export_trtis` and
  `--do_predict`. With `--horovod`:
    * the training examples are split across ranks,
    * the learning rate is scaled by the number of ranks,
    * only rank 0 writes checkpoints, exports the model and runs prediction.
  """
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

  if FLAGS.horovod:
    hvd.init()
  if FLAGS.use_fp16:
    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  validate_flags_or_throw(bert_config)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  master_process = True
  training_hooks = []
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  hvd_rank = 0

  config = tf.compat.v1.ConfigProto()
  learning_rate = FLAGS.learning_rate
  if FLAGS.horovod:

      tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
      tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
      global_batch_size = FLAGS.train_batch_size * hvd.size() * FLAGS.num_accumulation_steps
      learning_rate = learning_rate * hvd.size()
      master_process = (hvd.rank() == 0)
      hvd_rank = hvd.rank()
      # Pin each process to its own local GPU.
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      if hvd.size() > 1:
          training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
  if FLAGS.use_xla:
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
  # Only the master process writes checkpoints.
  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      keep_checkpoint_max=1)

  if master_process:
      tf.compat.v1.logging.info("***** Configuration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_calib:
    training_hooks.append(CalibrationHook(bert_config.num_hidden_layers))
  # LogTrainRunHook must stay the LAST hook: training_hooks[-1] is read below
  # for timing statistics.
  training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps))

  # Prepare Training Data
  if FLAGS.do_train or (FLAGS.do_calib and master_process):
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True,
        version_2_with_negative=FLAGS.version_2_with_negative)
    num_train_steps = int(
        len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)
    # Calibration runs a fixed number of steps and only needs enough
    # examples to fill them.
    if FLAGS.do_calib:
      num_train_steps = FLAGS.calib_batch

    start_index = 0
    if FLAGS.do_calib:
      end_index = min(len(train_examples), num_train_steps * global_batch_size)
    else:
      end_index = len(train_examples)
    tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

    if FLAGS.horovod:
      # One record file per rank; the first `remainder` ranks take one extra
      # example so every example is assigned exactly once.
      tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
      num_examples_per_rank = len(train_examples) // hvd.size()
      remainder = len(train_examples) % hvd.size()
      if hvd.rank() < remainder:
        start_index = hvd.rank() * (num_examples_per_rank+1)
        end_index = start_index + num_examples_per_rank + 1
      else:
        start_index = hvd.rank() * num_examples_per_rank + remainder
        end_index = start_index + (num_examples_per_rank)


  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      hvd=None if not FLAGS.horovod else hvd,
      use_fp16=FLAGS.use_fp16)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train or (FLAGS.do_calib and master_process):

    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    train_writer = FeatureWriter(
        filename=tmp_filenames[hvd_rank],
        is_training=True)
    convert_examples_to_features(
        examples=train_examples[start_index:end_index],
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature,
        verbose_logging=FLAGS.verbose_logging)
    train_writer.close()

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num orig examples = %d", end_index - start_index)
    tf.compat.v1.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    tf.compat.v1.logging.info("  LR = %f", learning_rate)
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=tmp_filenames,
        batch_size=FLAGS.train_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        hvd=None if not FLAGS.horovod else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
    train_time_elapsed = time.time() - train_start_time
    train_time_wo_overhead = training_hooks[-1].total_time
    # _rate returns NaN instead of raising when no time was recorded (e.g.
    # every step was skipped by the logging hook).
    avg_sentences_per_second = _rate(num_train_steps * global_batch_size, train_time_elapsed)
    ss_sentences_per_second = _rate((num_train_steps - training_hooks[-1].skipped) * global_batch_size, train_time_wo_overhead)

    if master_process:
        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        tf.compat.v1.logging.info("-----------------------------")


  if FLAGS.export_trtis and master_process:
    export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)

  if FLAGS.do_predict and master_process:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False,
        version_2_with_negative=FLAGS.version_2_with_negative)

    # Perform evaluation on subset, useful for profiling
    if FLAGS.num_eval_iterations is not None:
        eval_examples = eval_examples[:FLAGS.num_eval_iterations*FLAGS.predict_batch_size]

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    # Keep features in memory (needed by write_predictions) and also
    # serialize them for the input pipeline.
    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        verbose_logging=FLAGS.verbose_logging)
    eval_writer.close()

    tf.compat.v1.logging.info("***** Running predictions *****")
    tf.compat.v1.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.compat.v1.logging.info("  Num split examples = %d", len(eval_features))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        batch_size=FLAGS.predict_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    all_results = []
    eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
    eval_start_time = time.time()
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True, hooks=eval_hooks):
      if len(all_results) % 1000 == 0:
        tf.compat.v1.logging.info("Processing example: %d" % (len(all_results)))
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    eval_time_elapsed = time.time() - eval_start_time
    eval_time_wo_overhead = eval_hooks[-1].total_time

    time_list = eval_hooks[-1].time_list
    time_list.sort()
    num_sentences = (eval_hooks[-1].count - eval_hooks[-1].skipped) * FLAGS.predict_batch_size

    # Each "confidence level" is the worst latency within the fastest
    # 50/90/95/99/100% of the step times recorded by LogEvalRunHook. The
    # helper keeps at least one sample, so short runs (e.g. a small
    # --num_eval_iterations) no longer crash on max([]).
    avg = np.mean(time_list) if time_list else float("nan")
    cf_50 = _latency_at_fraction(time_list, 0.50)
    cf_90 = _latency_at_fraction(time_list, 0.90)
    cf_95 = _latency_at_fraction(time_list, 0.95)
    cf_99 = _latency_at_fraction(time_list, 0.99)
    cf_100 = _latency_at_fraction(time_list, 1.0)
    ss_sentences_per_second = _rate(num_sentences, eval_time_wo_overhead)

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    (eval_hooks[-1].count - eval_hooks[-1].skipped) * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Summary Inference Statistics")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.use_fp16 else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    tf.compat.v1.logging.info("-----------------------------")

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file)


if __name__ == "__main__":
  # The script cannot do anything useful without these three flags, so let
  # absl/tf.flags reject the invocation early if any is missing.
  for _required_flag in ("vocab_file", "bert_config_file", "output_dir"):
    flags.mark_flag_as_required(_required_flag)
  tf.compat.v1.app.run()
