Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
(The following 2021 → 2022 copyright-header update is repeated across many files in this commit; one representative hunk is shown.)
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -16,7 +16,7 @@
 from absl import logging
-from official.nlp.xlnet import data_utils
+from official.legacy.xlnet import data_utils
 SEG_ID_A = 0
 SEG_ID_B = 1
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common flags used in XLNet model."""
from absl import flags
flags.DEFINE_string("master", default=None, help="master")
flags.DEFINE_string(
    "tpu",
    default=None,
    help="The Cloud TPU to use for training. This should be "
    "either the name used when creating the Cloud TPU, or a "
    "url like grpc://ip.address.of.tpu:8470.")
flags.DEFINE_bool(
    "use_tpu", default=True, help="Use TPUs rather than plain CPUs.")
flags.DEFINE_string("tpu_topology", "2x2", help="TPU topology.")
flags.DEFINE_integer(
    "num_core_per_host", default=8, help="Number of cores per host.")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_string(
    "init_checkpoint",
    default=None,
    help="Checkpoint path for initializing the model.")
flags.DEFINE_bool(
    "init_from_transformerxl",
    default=False,
    help="Init from a transformerxl model checkpoint. Otherwise, init from "
    "the entire model checkpoint.")

# Optimization config
flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=1.0, help="Gradient clipping value.")
flags.DEFINE_float("weight_decay_rate", default=0.0, help="Weight decay rate.")

# lr decay
flags.DEFINE_integer(
    "warmup_steps", default=0, help="Number of steps for linear lr warmup.")
flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon.")
flags.DEFINE_float(
    "lr_layer_decay_rate",
    default=1.0,
    help="Top layer: lr[L] = FLAGS.learning_rate. "
    "Lower layers: lr[l-1] = lr[l] * lr_layer_decay_rate.")
flags.DEFINE_float(
    "min_lr_ratio", default=0.0, help="Minimum ratio learning rate.")

# Training config
flags.DEFINE_integer(
    "train_batch_size",
    default=16,
    help="Size of the train batch across all hosts.")
flags.DEFINE_integer(
    "train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer(
    "iterations", default=1000, help="Number of iterations per repeat loop.")

# Data config
flags.DEFINE_integer(
    "seq_len", default=0, help="Sequence length for pretraining.")
flags.DEFINE_integer(
    "reuse_len",
    default=0,
    help="How many tokens to be reused in the next batch. "
    "Could be half of `seq_len`.")
flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
flags.DEFINE_bool(
    "bi_data",
    default=False,
    help="Use bidirectional data streams, i.e., forward & backward.")
flags.DEFINE_integer("n_token", 32000, help="Vocab size.")

# Model config
flags.DEFINE_integer("mem_len", default=0, help="Number of steps to cache.")
flags.DEFINE_bool("same_length", default=False, help="Same length attention.")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length.")
flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=32, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=32, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=4, help="Number of attention heads.")
flags.DEFINE_integer(
    "d_head", default=8, help="Dimension of each attention head.")
flags.DEFINE_integer(
    "d_inner",
    default=32,
    help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropout_att", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False, help="Untie r_w_bias and r_r_bias.")
flags.DEFINE_string(
    "ff_activation",
    default="relu",
    help="Activation type used in position-wise feed-forward.")
flags.DEFINE_string(
    "strategy_type",
    default="tpu",
    help="Distribution strategy type to use, e.g. `tpu` or `mirrored`.")
flags.DEFINE_bool("use_bfloat16", False, help="Whether to use bfloat16.")
# Parameter initialization
flags.DEFINE_enum(
    "init_method",
    default="normal",
    enum_values=["normal", "uniform"],
    help="Initialization method.")
flags.DEFINE_float(
    "init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float(
    "init_range", default=0.1, help="Initialization std when init is uniform.")
flags.DEFINE_integer(
    "test_data_size", default=12048, help="Number of test data samples.")
flags.DEFINE_string(
    "train_tfrecord_path",
    default=None,
    help="Path to preprocessed training set tfrecord.")
flags.DEFINE_string(
    "test_tfrecord_path",
    default=None,
    help="Path to preprocessed test set tfrecord.")
flags.DEFINE_integer(
    "test_batch_size",
    default=16,
    help="Size of the test batch across all hosts.")
flags.DEFINE_integer(
    "save_steps", default=1000, help="Number of steps for saving checkpoint.")

FLAGS = flags.FLAGS
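A minimal sketch of how these flags behave once parsed, assuming the module above has been imported so its DEFINE_* calls are registered with absl (the flag values below are made up for illustration):

from absl import flags

FLAGS = flags.FLAGS

# FlagValues is callable: passing an argv list parses it in place.
# Flags not mentioned keep the defaults defined above.
FLAGS(["prog", "--learning_rate=5e-5", "--train_batch_size=32"])
print(FLAGS.learning_rate)     # 5e-05 (overridden)
print(FLAGS.train_batch_size)  # 32 (overridden)
print(FLAGS.n_layer)           # 6 (default)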
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""
from absl import logging
import tensorflow as tf
from official.nlp import optimization
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or "WarmUp") as name:
      # Implements polynomial warmup: if global_step < warmup_steps, the
      # learning rate is `(global_step / warmup_steps) ** power * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step - self.warmup_steps),
          name=name)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_schedule_fn": self.decay_schedule_fn,
        "warmup_steps": self.warmup_steps,
        "power": self.power,
        "name": self.name
    }

def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     min_lr_ratio=0.0,
                     adam_epsilon=1e-8,
                     weight_decay_rate=0.0):
  """Creates an optimizer with a learning rate schedule."""
  # Implements linear decay of the learning rate.
  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps - num_warmup_steps,
      end_learning_rate=init_lr * min_lr_ratio)
  if num_warmup_steps:
    learning_rate_fn = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=learning_rate_fn,
        warmup_steps=num_warmup_steps)
  if weight_decay_rate > 0.0:
    logging.info(
        "Using AdamWeightDecay with adam_epsilon=%.9f weight_decay_rate=%.3f",
        adam_epsilon, weight_decay_rate)
    optimizer = optimization.AdamWeightDecay(
        learning_rate=learning_rate_fn,
        weight_decay_rate=weight_decay_rate,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=adam_epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
  else:
    logging.info("Using Adam with adam_epsilon=%.9f", adam_epsilon)
    optimizer = tf.keras.optimizers.legacy.Adam(
        learning_rate=learning_rate_fn, epsilon=adam_epsilon)
  return optimizer, learning_rate_fn
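To make the schedule concrete, a small sketch with illustrative values: with init_lr=1e-4, 1000 training steps and 100 warmup steps, the rate ramps linearly from 0 to 1e-4 during warmup, then decays linearly toward init_lr * min_lr_ratio (0.0 by default):

import tensorflow as tf

optimizer, lr_fn = create_optimizer(
    init_lr=1e-4, num_train_steps=1000, num_warmup_steps=100)

for step in (0, 50, 100, 550, 1000):
  # Expected values: 0.0 and 5e-5 during warmup, 1e-4 at the warmup
  # boundary, ~5e-5 halfway through the decay, and 0.0 at the final step.
  print(step, float(lr_fn(tf.constant(step, tf.float32))))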
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -26,8 +26,8 @@ import numpy as np
 import tensorflow as tf
 import sentencepiece as spm
-from official.nlp.xlnet import classifier_utils
-from official.nlp.xlnet import preprocess_utils
+from official.legacy.xlnet import classifier_utils
+from official.legacy.xlnet import preprocess_utils
 flags.DEFINE_bool(
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -28,7 +28,7 @@ import numpy as np
 import tensorflow.compat.v1 as tf
 import sentencepiece as spm
-from official.nlp.xlnet import preprocess_utils
+from official.legacy.xlnet import preprocess_utils
 FLAGS = flags.FLAGS
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
@@ -25,7 +25,7 @@ from absl import logging
 import tensorflow as tf
 import sentencepiece as spm
-from official.nlp.xlnet import squad_utils
+from official.legacy.xlnet import squad_utils
 flags.DEFINE_integer(
     "num_proc", default=1, help="Number of preprocessing processes.")
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet classification finetuning runner in tf2.0."""
import functools
# Import libraries
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import
from official.common import distribute_utils
from official.legacy.xlnet import common_flags
from official.legacy.xlnet import data_utils
from official.legacy.xlnet import optimization
from official.legacy.xlnet import training_utils
from official.legacy.xlnet import xlnet_config
from official.legacy.xlnet import xlnet_modeling as modeling
flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
    "summary_type",
    default="last",
    help="Method used to summarize a sequence into a vector.")

FLAGS = flags.FLAGS


def get_classificationxlnet_model(model_config,
                                  run_config,
                                  n_class,
                                  summary_type="last"):
  model = modeling.ClassificationXLNetModel(
      model_config, run_config, n_class, summary_type, name="model")
  return model

def run_evaluation(strategy,
                   test_input_fn,
                   eval_steps,
                   model,
                   step,
                   eval_summary_writer=None):
  """Runs evaluation for the classification task.

  Because the validation set is padded with fake samples, a mask
  (`is_real_example`) is used to exclude them when computing accuracy. Since
  this involves dynamic-shape tensors, logits, labels and masks are first
  collected from the TPU and the accuracy is then computed locally via numpy.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_steps: total number of evaluation steps.
    model: keras model object.
    step: current train step.
    eval_summary_writer: summary writer used to record evaluation metrics.

  Returns:
    A float metric, accuracy.
  """

  def _test_step_fn(inputs):
    """Replicated validation step."""
    inputs["mems"] = None
    _, logits = model(inputs, training=False)
    return logits, inputs["label_ids"], inputs["is_real_example"]

  @tf.function
  def _run_evaluation(test_iterator):
    """Runs validation steps."""
    logits, labels, masks = strategy.run(
        _test_step_fn, args=(next(test_iterator),))
    return logits, labels, masks

  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
  correct = 0
  total = 0
  for _ in range(eval_steps):
    logits, labels, masks = _run_evaluation(test_iterator)
    logits = strategy.experimental_local_results(logits)
    labels = strategy.experimental_local_results(labels)
    masks = strategy.experimental_local_results(masks)
    merged_logits = []
    merged_labels = []
    merged_masks = []
    for i in range(strategy.num_replicas_in_sync):
      merged_logits.append(logits[i].numpy())
      merged_labels.append(labels[i].numpy())
      merged_masks.append(masks[i].numpy())
    merged_logits = np.vstack(np.array(merged_logits))
    merged_labels = np.hstack(np.array(merged_labels))
    merged_masks = np.hstack(np.array(merged_masks))
    real_index = np.where(np.equal(merged_masks, 1))
    correct += np.sum(
        np.equal(
            np.argmax(merged_logits[real_index], axis=-1),
            merged_labels[real_index]))
    total += np.shape(real_index)[-1]
  accuracy = float(correct) / float(total)
  logging.info("Train step: %d / acc = %d/%d = %f", step, correct, total,
               accuracy)
  if eval_summary_writer:
    with eval_summary_writer.as_default():
      tf.summary.scalar("eval_acc", float(correct) / float(total), step=step)
      eval_summary_writer.flush()
  return accuracy

def get_metric_fn():
  train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
      "acc", dtype=tf.float32)
  return train_acc_metric

def main(unused_argv):
  del unused_argv
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
  train_input_fn = functools.partial(data_utils.get_classification_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     strategy, True, FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    strategy, False, FLAGS.test_tfrecord_path)
  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps)
  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model, model_config,
                               run_config, FLAGS.n_class, FLAGS.summary_type)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["n_class"] = FLAGS.n_class

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=get_metric_fn,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)


if __name__ == "__main__":
  app.run(main)
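The masked-accuracy logic in run_evaluation above can be checked in isolation. A toy sketch with made-up arrays (not part of the runner): the is_real_example mask drops the padding sample before accuracy is computed:

import numpy as np

logits = np.array([[2.0, 1.0], [0.1, 0.9], [5.0, 0.0]])  # 3 examples, 2 classes
labels = np.array([0, 1, 1])
masks = np.array([1, 1, 0])  # last example is a fake/padding sample

real_index = np.where(np.equal(masks, 1))
correct = np.sum(
    np.equal(np.argmax(logits[real_index], axis=-1), labels[real_index]))
total = np.shape(real_index)[-1]
print(float(correct) / float(total))  # 1.0: both real examples are correct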