# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """XLNet classification finetuning runner in tf2.0.""" from __future__ import absolute_import from __future__ import division # from __future__ import google_type_annotations from __future__ import print_function import os import re from absl import logging # pytype: disable=attribute-error # pylint: disable=g-bare-generic,unused-import import tensorflow as tf # Initialize TPU System. 
from official.nlp.xlnet import data_utils
from official.nlp import xlnet_modeling as modeling
from typing import Any, Callable, Dict, Text, Optional

# Minimum number of steps per host loop for training summaries to be written;
# shorter loops do not aggregate the statistics over enough steps.
_MIN_SUMMARY_STEPS = 10


def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves `checkpoint` under `model_dir` using `checkpoint_prefix`.

  Args:
    checkpoint: The `tf.train.Checkpoint` to save.
    model_dir: Directory in which the checkpoint files are written.
    checkpoint_prefix: Checkpoint file prefix, relative to `model_dir`.
  """
  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
  saved_path = checkpoint.save(checkpoint_path)
  logging.info("Saving model as TF checkpoint: %s", saved_path)


def _float_metric_value(metric):
  """Returns the current result of a float-valued Keras metric.

  Args:
    metric: A Keras metric object.

  Returns:
    `metric.result()` converted to a NumPy float.
  """
  return metric.result().numpy().astype(float)


def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Returns how many steps the next host loop should run.

  A loop that starts part-way through an epoch is capped so that it does not
  go past the next multiple of `steps_per_epoch`. A loop that starts exactly
  on an epoch boundary runs the full `steps_per_loop`.
  NOTE(review): when `steps_per_loop > steps_per_epoch`, such a loop can step
  over epoch boundaries; confirm that is intended.

  Args:
    current_step: Number of steps already completed.
    steps_per_epoch: Number of steps per epoch.
    steps_per_loop: Maximum number of steps per host loop.

  Returns:
    The number of steps to run in the next loop.

  Raises:
    ValueError: if `steps_per_loop` is not positive.
  """
  if steps_per_loop <= 0:
    raise ValueError("steps_per_loop should be positive integer.")
  if steps_per_loop == 1:
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  return steps_per_loop


def _apply_lr_layer_decay(grads, tvars, lr_layer_decay_rate):
  """Scales transformer-layer gradients by a depth-dependent factor.

  Gradients of variables whose name contains `model/transformer/layer_{l}/`
  are multiplied by `lr_layer_decay_rate ** (n_layer - 1 - l)`. Here
  `n_layer` is one more than the largest layer index found in `tvars`.
  The top layer therefore keeps a factor of 1, and each lower layer gets one
  more power of the rate. Other gradients, and `None` gradients, are left
  untouched.

  Args:
    grads: List of gradients aligned index-for-index with `tvars`. It is
      modified in place.
    tvars: List of trainable variables corresponding to `grads`.
    lr_layer_decay_rate: Per-layer multiplicative factor.

  Returns:
    The same `grads` list, with the layer gradients rescaled.
  """
  n_layer = 0
  for var in tvars:
    match = re.search(r"model/transformer/layer_(\d+?)/", var.name)
    if match:
      n_layer = max(n_layer, int(match.group(1)) + 1)

  for i, var in enumerate(tvars):
    if grads[i] is None:
      # The variable did not contribute to the loss, so there is nothing to
      # scale (multiplying None would raise).
      continue
    for layer_idx in range(n_layer):
      # The trailing "/" keeps e.g. "layer_1/" from matching "layer_11/".
      if "model/transformer/layer_{}/".format(layer_idx) in var.name:
        abs_rate = lr_layer_decay_rate**(n_layer - 1 - layer_idx)
        grads[i] *= abs_rate
        logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
            abs_rate, layer_idx, var.name))
        break
  return grads


def train(
    strategy: tf.distribute.Strategy,
    model_fn: Callable,
    input_meta_data: Dict,
    train_input_fn: Callable,
    total_training_steps: int,
    steps_per_epoch: int,
    steps_per_loop: int,
    optimizer: tf.keras.optimizers.Optimizer,
    learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
    eval_fn: Optional[Callable[[tf.keras.Model, int, tf.summary.SummaryWriter],
                               Any]] = None,
    metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
    test_input_fn: Optional[Callable] = None,
    init_checkpoint: Optional[Text] = None,
    model_dir: Optional[Text] = None,
    save_steps: Optional[int] = None,
    run_eagerly: Optional[bool] = False):
  """Runs a customized XLNet training loop.

  Args:
    strategy: Distribution strategy on which to run the low-level training
      loop. The model, metrics and checkpoints are created under its scope.
    model_fn: Zero-argument function returning a `tf.keras.Model`. The model
      is called as `model(inputs, training=True)` and must return a
      `(mems, logits)` pair. Its loss is read from `model.losses`, and the
      first entry of that list is the one that is differentiated.
    input_meta_data: Dictionary with keys `mem_len`, `lr_layer_decay_rate`,
      `n_layer`, `batch_size_per_core` and `d_model`. When `mem_len > 0`,
      `n_layer` zero-initialized float32 memory tensors of shape
      `[mem_len, batch_size_per_core, d_model]` are created at the start of
      each host loop and carried from step to step. When
      `lr_layer_decay_rate != 1.0`, transformer layer gradients are scaled by
      powers of that rate depending on layer depth.
    train_input_fn: Function returning a `tf.data.Dataset` used for training.
    total_training_steps: Total number of training steps. The last host loop
      is shortened so training stops exactly at this count.
    steps_per_epoch: Number of steps per epoch. If `test_input_fn` is set,
      `eval_fn` is called whenever a host loop ends on a multiple of it.
    steps_per_loop: Maximum number of steps per call to the (optionally
      `tf.function`-compiled) training loop. Training logs are emitted once
      per loop. Train summaries are only written when this is at least
      `_MIN_SUMMARY_STEPS`.
    optimizer: Optimizer that applies the clipped gradients. Its `iterations`
      counter gives the starting step, which matters when resuming.
    learning_rate_fn: Learning rate schedule. In this function it is only
      evaluated at the current step for logging and summaries.
    eval_fn: Evaluation callback taking the model, the current step and the
      evaluation summary writer. Required when `test_input_fn` is set.
    metric_fn: Optional zero-argument function returning a Keras metric. The
      metric is updated on every training step with `inputs["label_ids"]`
      and the model logits, and is reported after each host loop.
    test_input_fn: Function returning an evaluation dataset. If None,
      evaluation is skipped. NOTE(review): this function only checks whether
      it is set; the evaluation data is presumably consumed inside `eval_fn`
      (confirm with callers).
    init_checkpoint: Optional checkpoint restored into the model returned by
      `model_fn` before training. A checkpoint found in `model_dir` is
      restored afterwards.
    model_dir: Directory for checkpoints and summaries. It is created, along
      with any missing parents, if it does not exist.
    save_steps: Checkpoint frequency. If None, a checkpoint is saved after
      every host loop. Otherwise one is saved after each loop that ends on a
      multiple of `save_steps`; a value of 0 disables intermediate
      checkpoints. A final checkpoint is always saved after training.
    run_eagerly: If True, the training loop runs eagerly instead of being
      wrapped in `tf.function`.

  Returns:
    The trained model returned by `model_fn`.

  Raises:
    ValueError: if a required argument is None, if `test_input_fn` is given
      without `eval_fn`, or if `steps_per_loop` is not positive.
    TypeError: if `model_dir` is not specified.
  """
  required_arguments = [
      train_input_fn, total_training_steps, steps_per_epoch, steps_per_loop,
      optimizer, learning_rate_fn
  ]
  if any(arg is None for arg in required_arguments):
    raise ValueError("`train_input_fn`, `total_training_steps`, "
                     "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
                     "`learning_rate_fn` are required parameters.")
  if not model_dir:
    raise TypeError("Model directory must be specified.")
  if test_input_fn and eval_fn is None:
    # Fail fast. Otherwise the missing callback would only surface as a
    # TypeError at the first evaluation, after training has already run.
    raise ValueError("`eval_fn` is required when `test_input_fn` is given.")

  # pylint: disable=protected-access
  train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
  # pylint: enable=protected-access

  train_summary_writer = None
  eval_summary_writer = None
  if not tf.io.gfile.exists(model_dir):
    # Unlike mkdir, makedirs also creates missing parent directories.
    tf.io.gfile.makedirs(model_dir)
  if test_input_fn:
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, "summaries/eval"))
  if steps_per_loop >= _MIN_SUMMARY_STEPS:
    # Only writes summary when the stats are collected sufficiently over
    # enough steps.
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, "summaries/train"))

  with strategy.scope():
    model = model_fn()

    if init_checkpoint:
      logging.info("restore from %s", init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint.restore(init_checkpoint)

    model.optimizer = optimizer

    train_loss_metric = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
    train_metric = metric_fn() if metric_fn else None

    def _replicated_step(inputs, mem=None):
      """Runs one forward/backward pass and optimizer update on a replica.

      Args:
        inputs: Dict of input features for one replica. Its `mems` entry is
          overwritten with `mem`. It must contain `label_ids` when a train
          metric is configured.
        mem: Memory tensors from the previous step, or None.

      Returns:
        The memory returned by the model when `mem_len > 0`, otherwise None.
      """
      inputs["mems"] = mem
      with tf.GradientTape() as tape:
        mem, logits = model(inputs, training=True)
        loss = model.losses
        # The metric averages every entry of `model.losses`; only the first
        # entry is differentiated below.
        train_loss_metric.update_state(loss)
        if train_metric:
          train_metric.update_state(inputs["label_ids"], logits)
        # Scale by 1/num_replicas. Assuming the default SUM aggregation of
        # gradients across replicas, this yields their average.
        scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)

      # Collects training variables.
      tvars = model.trainable_variables
      grads = tape.gradient(scaled_loss, tvars)
      clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)

      if input_meta_data["lr_layer_decay_rate"] != 1.0:
        clipped = _apply_lr_layer_decay(clipped, tvars,
                                        input_meta_data["lr_layer_decay_rate"])

      optimizer.apply_gradients(zip(clipped, tvars))
      if input_meta_data["mem_len"] > 0:
        return mem
      return None

    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: The distributed iterator of training datasets.
        steps: A tf.int32 scalar tensor giving the number of steps to run
          inside the host training loop.

      Raises:
        ValueError: if `steps` is not a `tf.Tensor`. A Python value would
          retrace the compiled function for every distinct value.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError("steps should be an Tensor. Python object may cause "
                         "retracing.")

      def cache_fn():
        """Builds zero-initialized XLNet memory, one tensor per layer.

        Returns:
          A list of `n_layer` float32 zero tensors of shape
          `[mem_len, batch_size_per_core, d_model]`, or an empty list when
          `mem_len` is not positive.
        """
        if input_meta_data["mem_len"] <= 0:
          return []
        shape = [
            input_meta_data["mem_len"],
            input_meta_data["batch_size_per_core"],
            input_meta_data["d_model"],
        ]
        return [
            tf.zeros(shape, dtype=tf.float32)
            for _ in range(input_meta_data["n_layer"])
        ]

      if input_meta_data["mem_len"] > 0:
        # Thread memory through the loop: each step consumes the memory
        # produced by the previous step.
        mem = strategy.experimental_run_v2(cache_fn)
        for _ in tf.range(steps):
          mem = strategy.experimental_run_v2(
              _replicated_step, args=(
                  next(iterator),
                  mem,
              ))
      else:
        for _ in tf.range(steps):
          strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

    if not run_eagerly:
      train_steps = tf.function(train_steps)

    logging.info("Start training...")
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info("Checkpoint file %s found and restoring from checkpoint",
                   latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info("Loading from checkpoint file completed")

    # Resume the step count from the (possibly restored) optimizer.
    current_step = optimizer.iterations.numpy()
    checkpoint_name = "xlnet_step_{step}.ckpt"

    while current_step < total_training_steps:
      train_loss_metric.reset_states()
      if train_metric:
        train_metric.reset_states()

      # Clamp the loop so training never overshoots `total_training_steps`.
      steps = int(
          min(
              _steps_to_run(current_step, steps_per_epoch, steps_per_loop),
              total_training_steps - current_step))
      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      current_step += steps

      train_loss = _float_metric_value(train_loss_metric)
      log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
          current_step, total_training_steps, learning_rate_fn(current_step),
          train_loss)
      if train_metric:
        log_stream += " / %s = %f" % (train_metric.name,
                                       _float_metric_value(train_metric))
      logging.info(log_stream)

      if train_summary_writer:
        with train_summary_writer.as_default():
          tf.summary.scalar(
              "learning_rate", learning_rate_fn(current_step), step=current_step)
          tf.summary.scalar(
              train_loss_metric.name, train_loss, step=current_step)
          if train_metric:
            tf.summary.scalar(
                train_metric.name,
                _float_metric_value(train_metric),
                step=current_step)
          train_summary_writer.flush()

      if save_steps is None or (save_steps and current_step % save_steps == 0):
        _save_checkpoint(checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))

      if test_input_fn and current_step % steps_per_epoch == 0:
        logging.info("Running evaluation after step: %s.", current_step)
        eval_fn(model, current_step, eval_summary_writer)

    # Always save a final checkpoint, even if the last loop already saved one.
    _save_checkpoint(checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if test_input_fn:
      logging.info("Running final evaluation after training is complete.")
      eval_fn(model, current_step, eval_summary_writer)

    return model