Commit 35112d1c authored by Hongkun Yu's avatar Hongkun Yu Committed by saberkun

Internal change

PiperOrigin-RevId: 268543563
parent 1577ed07
# TensorFlow Natural Language Processing Models
tensorflow/models/official/nlp is a library of state-of-the-art models for
Natural Language Processing (NLP).
The library currently contains TensorFlow 2.x implementations, pre-trained
model weights, usage scripts and conversion utilities for the following models:
* BERT
* [XLNet](xlnet)
* Transformer for translation
# XLNet: Generalized Autoregressive Pretraining for Language Understanding
The academic paper which describes XLNet in detail and provides full results on
a number of tasks can be found here: https://arxiv.org/abs/1906.08237.
Instructions and a user guide will be added soon.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for pre-processing classification data."""
from absl import flags
from absl import logging
from official.nlp.xlnet import data_utils
FLAGS = flags.FLAGS
SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class PaddingInputExample(object):
"""Fake example so the num input examples is a multiple of the batch size.
When running eval/predict on the TPU, we need to pad the number of examples
to be a multiple of the batch size, because the TPU requires a fixed batch
size. The alternative is to drop the last batch, which is bad because it means
the entire output data won't be generated.
We use this class instead of `None` because treating `None` as padding
  batches could cause silent errors.
"""
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
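# Illustrative sanity check (not part of the original file): with
# max_length=6, the longer sequence is popped one token at a time until
# the pair fits.
def _truncate_seq_pair_demo():
  tokens_a, tokens_b = [1, 2, 3, 4, 5], [6, 7, 8]
  _truncate_seq_pair(tokens_a, tokens_b, max_length=6)
  assert tokens_a == [1, 2, 3] and tokens_b == [6, 7, 8]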
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenize_fn):
"""Converts a single `InputExample` into a single `InputFeatures`."""
if isinstance(example, PaddingInputExample):
return InputFeatures(
input_ids=[0] * max_seq_length,
input_mask=[1] * max_seq_length,
segment_ids=[0] * max_seq_length,
label_id=0,
is_real_example=False)
if label_list is not None:
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenize_fn(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenize_fn(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for two [SEP] & one [CLS] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for one [SEP] & one [CLS] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[:max_seq_length - 2]
tokens = []
segment_ids = []
for token in tokens_a:
tokens.append(token)
segment_ids.append(SEG_ID_A)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_A)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(SEG_ID_B)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_B)
tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS)
input_ids = tokens
# The mask has 0 for real tokens and 1 for padding tokens. Only real
# tokens are attended to.
input_mask = [0] * len(input_ids)
# Zero-pad up to the sequence length.
if len(input_ids) < max_seq_length:
delta_len = max_seq_length - len(input_ids)
input_ids = [0] * delta_len + input_ids
input_mask = [1] * delta_len + input_mask
segment_ids = [SEG_ID_PAD] * delta_len + segment_ids
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
if label_list is not None:
label_id = label_map[example.label]
else:
label_id = example.label
if ex_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s", (example.guid))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %d (id = %d)", example.label, label_id)
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id)
return feature
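# Minimal usage sketch (illustrative: `ToyExample` and the toy tokenizer
# are assumptions; a real tokenize_fn would emit SentencePiece ids).
def _convert_single_example_demo():
  import collections
  ToyExample = collections.namedtuple(
      "ToyExample", ["guid", "text_a", "text_b", "label"])
  example = ToyExample(
      guid="demo-0", text_a="hello world", text_b=None, label="0")
  feature = convert_single_example(
      ex_index=0,
      example=example,
      label_list=["0", "1"],
      max_seq_length=8,
      tokenize_fn=lambda text: [len(t) for t in text.split()])
  # Ids are left-padded to max_seq_length; the label string maps to id 0.
  assert len(feature.input_ids) == 8 and feature.label_id == 0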
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common flags used in XLNet model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import flags
flags.DEFINE_string("master", default=None, help="master")
flags.DEFINE_string(
"tpu",
default=None,
help="The Cloud TPU to use for training. This should be "
"either the name used when creating the Cloud TPU, or a "
"url like grpc://ip.address.of.tpu:8470.")
flags.DEFINE_bool(
"use_tpu", default=True, help="Use TPUs rather than plain CPUs.")
flags.DEFINE_string("tpu_topology", "2x2", help="TPU topology.")
flags.DEFINE_integer(
"num_core_per_host", default=8, help="number of cores per host")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_string(
"init_checkpoint",
default=None,
help="Checkpoint path for initializing the model.")
# Optimization config
flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=1.0, help="Gradient clipping value.")
flags.DEFINE_float("weight_decay_rate", default=0.0, help="Weight decay rate.")
# lr decay
flags.DEFINE_integer(
"warmup_steps", default=0, help="Number of steps for linear lr warmup.")
flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon.")
flags.DEFINE_float(
"lr_layer_decay_rate",
default=1.0,
help="Top layer: lr[L] = FLAGS.learning_rate."
"Lower layers: lr[l-1] = lr[l] * lr_layer_decay_rate.")
flags.DEFINE_float(
"min_lr_ratio", default=0.0, help="Minimum ratio learning rate.")
# Training config
flags.DEFINE_integer(
"train_batch_size",
default=16,
help="Size of the train batch across all hosts.")
flags.DEFINE_integer(
"train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer(
"iterations", default=1000, help="Number of iterations per repeat loop.")
# Data config
flags.DEFINE_integer(
"seq_len", default=0, help="Sequence length for pretraining.")
flags.DEFINE_integer(
"reuse_len",
default=0,
help="How many tokens to be reused in the next batch. "
"Could be half of `seq_len`.")
flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
flags.DEFINE_bool(
"bi_data",
default=False,
help="Use bidirectional data streams, "
"i.e., forward & backward.")
flags.DEFINE_integer("n_token", 32000, help="Vocab size")
# Model config
flags.DEFINE_integer("mem_len", default=0, help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False, help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")
flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=32, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=32, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=4, help="Number of attention heads.")
flags.DEFINE_integer(
"d_head", default=8, help="Dimension of each attention head.")
flags.DEFINE_integer(
"d_inner",
default=32,
help="Dimension of inner hidden size in positionwise "
"feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropout_att", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False, help="Untie r_w_bias and r_r_bias")
flags.DEFINE_string(
"ff_activation",
default="relu",
help="Activation type used in position-wise feed-forward.")
flags.DEFINE_string(
"strategy_type",
default="tpu",
help="Activation type used in position-wise feed-forward.")
flags.DEFINE_bool("use_bfloat16", False, help="Whether to use bfloat16.")
# Parameter initialization
flags.DEFINE_enum(
"init_method",
default="normal",
enum_values=["normal", "uniform"],
help="Initialization method.")
flags.DEFINE_float(
"init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float(
"init_range", default=0.1, help="Initialization std when init is uniform.")
flags.DEFINE_integer(
"train_data_size", default=130738, help="Number of training data samples.")
flags.DEFINE_integer(
"test_data_size", default=12048, help="Number of test data samples.")
flags.DEFINE_string(
"train_tfrecord_path",
default=None,
help="Path to preprocessed training set tfrecord.")
flags.DEFINE_string(
"test_tfrecord_path",
default=None,
help="Path to preprocessed test set tfrecord.")
flags.DEFINE_integer(
"test_batch_size",
default=16,
help="Size of the test batch across all hosts.")
flags.DEFINE_integer(
"save_steps", default=None, help="Number of steps for saving checkpoint.")
FLAGS = flags.FLAGS
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used for data preparation."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import json
import os
from absl import logging
import tensorflow as tf
special_symbols = {
"<unk>": 0,
"<s>": 1,
"</s>": 2,
"<cls>": 3,
"<sep>": 4,
"<pad>": 5,
"<mask>": 6,
"<eod>": 7,
"<eop>": 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
is_training):
"""Creates an `input_fn` closure."""
logging.info("Input tfrecord file %s", input_file)
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def input_fn():
"""Returns dataset for training/evaluation."""
num_threads = 8
if isinstance(input_file, str):
d = tf.data.TFRecordDataset(input_file)
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
if is_training:
d = d.shuffle(2048)
d = d.repeat()
else:
cycle_length = min(num_threads, len(input_file))
d = tf.data.Dataset.from_tensor_slices(input_file)
# file level shuffle
d = d.shuffle(len(input_file)).repeat()
d = d.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=is_training,
cycle_length=cycle_length))
if is_training:
# sample level shuffle
d = d.shuffle(buffer_size=2048)
# TODO(b/138223458): Hard-code drop_remainder=True to get around the bug
# that under TPU strategy, setting drop_remainder=False in
# tf.data.Dataset.batch() while data_size can be divided by global
# batch_size will trigger dynamic_dimension related TPU compilation error.
d = d.apply(
tf.data.experimental.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
num_parallel_batches=num_threads,
drop_remainder=True))
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
# same input file is sent to all workers.
if isinstance(input_file, str) or len(input_file) == 1:
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
d = d.with_options(options)
d = d.prefetch(tf.data.experimental.AUTOTUNE)
return d
return input_fn
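# Usage sketch (the path and feature spec below are assumptions): building
# a shuffled, repeated, batched training dataset from one tfrecord file.
def _input_fn_builder_demo():
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([128], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
  }
  input_fn = file_based_input_fn_builder(
      input_file="/tmp/train.tf_record",  # hypothetical path
      name_to_features=name_to_features,
      batch_size=32,
      is_training=True)
  return input_fn()  # a tf.data.Dataset of int32 feature batches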
def create_classification_dataset(file_path, seq_length, batch_size,
is_training):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
"input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
"segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([], tf.int64),
"is_real_example": tf.io.FixedLenFeature([], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features,
batch_size, is_training)
dataset = input_fn()
return dataset
def create_squad_dataset(file_path, seq_length, batch_size, is_training):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
"unique_ids": tf.io.FixedLenFeature([], tf.int64),
"input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
"segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
"cls_index": tf.io.FixedLenFeature([], tf.int64),
"p_mask": tf.io.FixedLenFeature([seq_length], tf.float32)
}
if is_training:
name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32)
input_fn = file_based_input_fn_builder(file_path, name_to_features,
batch_size, is_training)
dataset = input_fn()
return dataset
def _get_input_iterator(input_fn, strategy):
"""Returns distributed dataset iterator."""
  # When training with TPU pods, datasets need to be cloned across
  # workers. Since a Dataset instance cannot be cloned in eager mode, we
  # instead pass a callable that returns a dataset.
input_data = input_fn()
if callable(input_data):
iterator = iter(
strategy.experimental_distribute_datasets_from_function(input_data))
else:
iterator = iter(strategy.experimental_distribute_dataset(input_data))
return iterator
def get_classification_input_data(batch_size, seq_len, strategy, is_training,
file_path):
"""Returns input dataset from input file string."""
  # When using TPU pods, we need to clone the dataset across workers and
  # pass in a function that returns the dataset, rather than the dataset
  # instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
"Batch size must be divisible by number of replicas : {}".format(
strategy.num_replicas_in_sync))
    # As auto rebatching is not supported by the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning a dataset to multiple workers in eager mode,
    # we use the per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
def _dataset_fn(ctx=None):
del ctx
train_dataset = create_classification_dataset(
file_path=file_path,
seq_length=seq_len,
batch_size=batch_size,
is_training=is_training)
return train_dataset
return _dataset_fn if use_dataset_fn else _dataset_fn()
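# Worked example of the rebatching above (illustrative numbers): with a
# global batch size of 16 and 8 replicas in sync, each replica receives
# 16 / 8 = 2 examples per step; a global batch size of 15 would raise
# ValueError because it is not divisible by 8.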
def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
file_path):
"""Returns input dataset from input file string."""
  # When using TPU pods, we need to clone the dataset across workers and
  # pass in a function that returns the dataset, rather than the dataset
  # instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
"Batch size must be divisible by number of replicas : {}".format(
strategy.num_replicas_in_sync))
    # As auto rebatching is not supported by the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning a dataset to multiple workers in eager mode,
    # we use the per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
if is_training:
input_glob = os.path.join(
file_path,
"spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))
global_input_paths = tf.io.gfile.glob(input_glob)
else:
global_input_paths = file_path
def _dataset_fn(ctx=None):
del ctx
train_dataset = create_squad_dataset(
file_path=global_input_paths,
seq_length=seq_len,
batch_size=batch_size,
is_training=is_training)
return train_dataset
return _dataset_fn if use_dataset_fn else _dataset_fn()
def create_pretrain_dataset(file_names,
bsz_per_core,
seq_len,
reuse_len,
perm_size,
num_predict=None,
input_pipeline_context=None):
"""Creates pretrain dataset."""
def parser(record):
"""Function used to parse tfrecord."""
record_spec = {
"input": tf.io.FixedLenFeature([seq_len], tf.int64),
"target": tf.io.FixedLenFeature([seq_len], tf.int64),
"seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
"label": tf.io.FixedLenFeature([1], tf.int64),
"is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
}
# retrieve serialized example
example = tf.io.parse_single_example(
serialized=record, features=record_spec)
inputs = example.pop("input")
target = example.pop("target")
is_masked = tf.cast(example.pop("is_masked"), tf.bool)
non_reuse_len = seq_len - reuse_len
    # perm_size should not be larger than reuse_len or non_reuse_len;
    # otherwise there will be data leaks.
assert perm_size <= reuse_len and perm_size <= non_reuse_len
# Creates permutation mask and target mask for the first reuse_len tokens.
# The tokens in this part are reused from the last sequence.
perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
perm_size, reuse_len)
    # Creates permutation mask and target mask for the rest of the tokens
    # in the current example, which are the concatenation of two new
    # segments.
perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
perm_size, non_reuse_len)
perm_mask_0 = tf.concat(
[perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target = tf.concat([target_0, target_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
input_k = tf.concat([input_k_0, input_k_1], axis=0)
input_q = tf.concat([input_q_0, input_q_1], axis=0)
if num_predict is not None:
indices = tf.range(seq_len, dtype=tf.int64)
bool_target_mask = tf.cast(target_mask, tf.bool)
indices = tf.boolean_mask(indices, bool_target_mask)
##### extra padding due to CLS/SEP introduced after prepro
actual_num_predict = tf.shape(indices)[0]
pad_len = num_predict - actual_num_predict
##### target_mapping
target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
target_mapping = tf.concat([target_mapping, paddings], axis=0)
example["target_mapping"] = tf.reshape(target_mapping,
[num_predict, seq_len])
##### target
target = tf.boolean_mask(target, bool_target_mask)
paddings = tf.zeros([pad_len], dtype=target.dtype)
target = tf.concat([target, paddings], axis=0)
example["target"] = tf.reshape(target, [num_predict])
##### target mask
target_mask = tf.concat([
tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)
],
axis=0)
example["target_mask"] = tf.reshape(target_mask, [num_predict])
else:
example["target"] = tf.reshape(target, [seq_len])
example["target_mask"] = tf.reshape(target_mask, [seq_len])
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
for key in list(example.keys()):
val = example[key]
if tf.keras.backend.is_sparse(val):
val = tf.sparse.to_dense(val)
if val.dtype == tf.int64:
val = tf.cast(val, tf.int32)
example[key] = val
for k, v in example.items():
logging.info("%s: %s", k, v)
return example
dataset = parse_files_to_dataset(
parser=parser,
file_paths=file_names,
bsz_per_core=bsz_per_core,
input_pipeline_context=input_pipeline_context)
return dataset
def format_filename(prefix,
bsz_per_host,
seq_len,
bi_data,
suffix,
mask_alpha=5,
mask_beta=1,
reuse_len=None,
uncased=False,
fixed_num_predict=None):
"""Generates input file name pattern."""
if reuse_len is None:
reuse_len_str = ""
else:
reuse_len_str = "reuse-{}.".format(reuse_len)
if not uncased:
uncased_str = ""
else:
uncased_str = "uncased."
if bi_data:
bi_data_str = "bi"
else:
bi_data_str = "uni"
if fixed_num_predict is not None:
fnp_str = "fnp-{}.".format(fixed_num_predict)
else:
fnp_str = ""
file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
mask_alpha, mask_beta, fnp_str, suffix)
return file_name
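# Worked example (illustrative values; mask_alpha and mask_beta keep
# their defaults of 5 and 1):
def _format_filename_demo():
  name = format_filename(
      prefix="record_info-train-0-0",
      bsz_per_host=32,
      seq_len=512,
      bi_data=True,
      suffix="json",
      reuse_len=256,
      fixed_num_predict=85)
  assert name == ("record_info-train-0-0.bsz-32.seqlen-512."
                  "reuse-256.bi.alpha-5.beta-1.fnp-85.json")
  return name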
def get_pretrain_input_data(batch_size,
seq_len,
strategy,
file_path,
reuse_len,
perm_size,
mask_alpha,
mask_beta,
num_predict,
bi_data,
uncased,
num_hosts=1):
"""Returns input dataset from input file string."""
  # When using TPU pods, we need to clone the dataset across workers and
  # pass in a function that returns the dataset, rather than the dataset
  # instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
split = "train"
record_glob_base = format_filename(
prefix="record_info-{}-*".format(split),
bsz_per_host=int(batch_size / num_hosts),
seq_len=seq_len,
bi_data=bi_data,
suffix="json",
mask_alpha=mask_alpha,
mask_beta=mask_beta,
reuse_len=reuse_len,
uncased=uncased,
fixed_num_predict=num_predict)
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
"Batch size must be divisible by number of replicas : {}".format(
strategy.num_replicas_in_sync))
    # As auto rebatching is not supported by the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning a dataset to multiple workers in eager mode,
    # we use the per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
record_info = {"num_batch": 0, "filenames": []}
tfrecord_dirs = file_path.split(",")
logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
for idx, record_dir in enumerate(tfrecord_dirs):
record_glob = os.path.join(record_dir, record_glob_base)
logging.info("[%d] Record glob: %s", idx, record_glob)
record_paths = sorted(tf.io.gfile.glob(record_glob))
logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
cur_record_info = {"num_batch": 0, "filenames": []}
for record_info_path in record_paths:
with tf.io.gfile.GFile(record_info_path, "r") as fp:
info = json.load(fp)
cur_record_info["num_batch"] += info["num_batch"]
cur_record_info["filenames"] += info["filenames"]
# overwrite directory for `cur_record_info`
new_filenames = []
for filename in cur_record_info["filenames"]:
basename = os.path.basename(filename)
new_filename = os.path.join(record_dir, basename)
new_filenames.append(new_filename)
cur_record_info["filenames"] = new_filenames
logging.info("[Dir %d] Number of chosen batches: %s", idx,
cur_record_info["num_batch"])
logging.info("[Dir %d] Number of chosen files: %s", idx,
len(cur_record_info["filenames"]))
logging.info(cur_record_info["filenames"])
# add `cur_record_info` to global `record_info`
record_info["num_batch"] += cur_record_info["num_batch"]
record_info["filenames"] += cur_record_info["filenames"]
logging.info("Total number of batches: %d", record_info["num_batch"])
logging.info("Total number of files: %d", len(record_info["filenames"]))
logging.info(record_info["filenames"])
def _dataset_fn(ctx=None):
"""Function that can create a pretrain dataset."""
train_dataset = create_pretrain_dataset(
file_names=record_info["filenames"],
bsz_per_core=batch_size,
seq_len=seq_len,
reuse_len=reuse_len,
perm_size=perm_size,
num_predict=num_predict,
input_pipeline_context=ctx)
return train_dataset
return _dataset_fn if use_dataset_fn else _dataset_fn()
def parse_files_to_dataset(parser,
file_paths,
bsz_per_core,
input_pipeline_context=None):
"""Creates the dataset given file paths."""
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
  # Note: we cannot perform sample-level shuffle here because it would
  # violate the requirement that the data stream be consecutive.
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
# file-level shuffle
if len(file_paths) > 1:
dataset = dataset.shuffle(len(file_paths))
dataset = tf.data.TFRecordDataset(dataset)
  # (zihang): since we are doing online preprocessing, the parsed result of
  # the same input will be different each time, so caching processed data
  # is not helpful. It would use a lot of memory and lead to container OOM.
  # So, cache the non-parsed raw data instead.
dataset = dataset.cache().map(parser).repeat()
dataset = dataset.batch(bsz_per_core, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
"""Samples a permutation of the factorization order.
Creates perm_mask and target_mask accordingly.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected for
partial prediction.
    perm_size: the length of the longest permutation. Could be set to
      reuse_len. Should not be larger than reuse_len, or there will be
      data leaks.
seq_len: int, sequence length.
Returns:
    perm_mask: float32 Tensor in shape [seq_len, seq_len] consisting of 0s
      and 1s. If perm_mask[i][j] == 1, the ith token (in original order)
      cannot attend to the jth token (in original order). This happens only
      when the ith token's permuted position <= the jth token's permuted
      position and the jth token is masked or is a functional token. If
      perm_mask[i][j] == 0, the ith token (in original order) can attend to
      the jth token (in original order). Note that non-masked tokens can be
      attended to by all other tokens, which differs from the description
      in the original paper.
    new_targets: int64 Tensor in shape [seq_len], target token ids to be
      predicted in XLNet. In XLNet, the target does not need to be shifted
      by one position.
    target_mask: float32 Tensor in shape [seq_len] consisting of 0s and 1s.
      If target_mask[i] == 1, the ith token needs to be predicted and a
      mask is used as input; this token counts toward the loss. If
      target_mask[i] == 0, the token (or [SEP], [CLS]) is used as input and
      does not count toward the loss.
    inputs_k: int64 Tensor in shape [seq_len], input ids.
    inputs_q: float32 Tensor in shape [seq_len], the same as target_mask.
"""
# Generate permutation indices
index = tf.range(seq_len, dtype=tf.int64)
index = tf.transpose(tf.reshape(index, [-1, perm_size]))
index = tf.random.shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens
non_func_tokens = tf.logical_not(
tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
  # Set the permutation indices of non-masked (& non-functional) tokens to
  # the smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
  # Create `target_mask`: non-functional and masked tokens
# 1: use mask as input and have loss
# 0: use token (or [SEP], [CLS]) as input and do not have loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# Create `perm_mask`
# `target_tokens` cannot see themselves
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked
perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
# construct inputs_k
inputs_k = inputs
# construct inputs_q
inputs_q = target_mask
return perm_mask, new_targets, target_mask, inputs_k, inputs_q
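# Sketch of the factorization-order sampling at the top of _local_perm
# (a numpy stand-in, illustrative only): with seq_len=6 and perm_size=2,
# indices are shuffled in aligned groups, so each position pair
# (2i, 2i + 1) receives the same relative reordering.
def _perm_index_demo():
  import numpy as np
  seq_len, perm_size = 6, 2
  index = np.arange(seq_len).reshape(-1, perm_size).T.copy()  # shape [2, 3]
  np.random.shuffle(index)  # shuffles the rows
  # Flattened result is e.g. [1, 0, 3, 2, 5, 4] or [0, 1, 2, 3, 4, 5].
  return index.T.reshape(-1)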
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import logging
import tensorflow as tf
from official.bert.optimization import AdamWeightDecay
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(self,
initial_learning_rate,
decay_schedule_fn,
warmup_steps,
power=1.0,
name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or "WarmUp") as name:
      # Implements polynomial warmup: if global_step < warmup_steps, the
      # learning rate will be `global_step / num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = (
self.initial_learning_rate *
tf.math.pow(warmup_percent_done, self.power))
return tf.cond(
global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step - self.warmup_steps),
name=name)
def get_config(self):
return {
"initial_learning_rate": self.initial_learning_rate,
"decay_schedule_fn": self.decay_schedule_fn,
"warmup_steps": self.warmup_steps,
"power": self.power,
"name": self.name
}
def create_optimizer(init_lr,
num_train_steps,
num_warmup_steps,
min_lr_ratio=0.0,
adam_epsilon=1e-8,
weight_decay_rate=0.0):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
decay_steps=num_train_steps - num_warmup_steps,
end_learning_rate=init_lr * min_lr_ratio)
if num_warmup_steps:
learning_rate_fn = WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=num_warmup_steps)
if weight_decay_rate > 0.0:
logging.info(
"Using AdamWeightDecay with adam_epsilon=%.9f weight_decay_rate=%.3f",
adam_epsilon, weight_decay_rate)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=weight_decay_rate,
beta_1=0.9,
beta_2=0.999,
epsilon=adam_epsilon,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
else:
logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
optimizer = tf.keras.optimizers.Adam(
learning_rate=learning_rate_fn, epsilon=adam_epsilon)
return optimizer, learning_rate_fn
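# Usage sketch (hyperparameters are illustrative): the returned schedule
# warms up for 1000 steps and then decays polynomially toward
# init_lr * min_lr_ratio over the remaining 9000 steps.
def _create_optimizer_demo():
  optimizer, lr_fn = create_optimizer(
      init_lr=1e-4,
      num_train_steps=10000,
      num_warmup_steps=1000,
      min_lr_ratio=0.0,
      adam_epsilon=1e-8,
      weight_decay_rate=0.0)  # 0.0 selects the plain Adam branch
  return optimizer, lr_fn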
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to pre-process classification data into tfrecords."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
import sentencepiece as spm
from official.nlp.xlnet import classifier_utils
from official.nlp.xlnet import preprocess_utils
flags.DEFINE_bool(
"overwrite_data",
default=False,
help="If False, will use cached data if available.")
flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
flags.DEFINE_string(
"spiece_model_file", default="", help="Sentence Piece model path.")
flags.DEFINE_string("data_dir", default="", help="Directory for input data.")
# task specific
flags.DEFINE_string("eval_split", default="dev", help="could be dev or test")
flags.DEFINE_string("task_name", default=None, help="Task name")
flags.DEFINE_integer(
"eval_batch_size", default=64, help="batch size for evaluation")
flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length")
flags.DEFINE_integer(
"num_passes",
default=1,
help="Num passes for processing training data. "
"This is use to batch data without loss for TPUs.")
flags.DEFINE_bool("uncased", default=False, help="Use uncased.")
flags.DEFINE_bool(
"is_regression", default=False, help="Whether it's a regression task.")
FLAGS = flags.FLAGS
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.io.gfile.GFile(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
# pylint: disable=g-explicit-length-test
if len(line) == 0:
continue
lines.append(line)
return lines
class GLUEProcessor(DataProcessor):
"""GLUEProcessor."""
def __init__(self):
self.train_file = "train.tsv"
self.dev_file = "dev.tsv"
self.test_file = "test.tsv"
self.label_column = None
self.text_a_column = None
self.text_b_column = None
self.contains_header = True
self.test_text_a_column = None
self.test_text_b_column = None
self.test_contains_header = True
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
if self.test_text_a_column is None:
self.test_text_a_column = self.text_a_column
if self.test_text_b_column is None:
self.test_text_b_column = self.text_b_column
return self._create_examples(
self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0 and self.contains_header and set_type != "test":
continue
if i == 0 and self.test_contains_header and set_type == "test":
continue
guid = "%s-%s" % (set_type, i)
a_column = (
self.text_a_column if set_type != "test" else self.test_text_a_column)
b_column = (
self.text_b_column if set_type != "test" else self.test_text_b_column)
# there are some incomplete lines in QNLI
if len(line) <= a_column:
logging.warning("Incomplete line, ignored.")
continue
text_a = line[a_column]
if b_column is not None:
if len(line) <= b_column:
logging.warning("Incomplete line, ignored.")
continue
text_b = line[b_column]
else:
text_b = None
if set_type == "test":
label = self.get_labels()[0]
else:
if len(line) <= self.label_column:
logging.warning("Incomplete line, ignored.")
continue
label = line[self.label_column]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class Yelp5Processor(DataProcessor):
"""Yelp5Processor."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"))
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"))
def get_labels(self):
"""See base class."""
return ["1", "2", "3", "4", "5"]
def _create_examples(self, input_file):
"""Creates examples for the training and dev sets."""
examples = []
with tf.io.gfile.GFile(input_file) as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
label = line[0]
text_a = line[1].replace('""', '"').replace('\\"', '"')
examples.append(
InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
return examples
class ImdbProcessor(DataProcessor):
"""ImdbProcessor."""
def get_labels(self):
return ["neg", "pos"]
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train"))
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test"))
def _create_examples(self, data_dir):
"""Creates examples."""
examples = []
for label in ["neg", "pos"]:
cur_dir = os.path.join(data_dir, label)
for filename in tf.io.gfile.listdir(cur_dir):
if not filename.endswith("txt"):
continue
if len(examples) % 1000 == 0:
logging.info("Loading dev example %d", len(examples))
path = os.path.join(cur_dir, filename)
with tf.io.gfile.GFile(path) as f:
text = f.read().strip().replace("<br />", " ")
examples.append(
InputExample(
guid="unused_id", text_a=text, text_b=None, label=label))
return examples
class MnliMatchedProcessor(GLUEProcessor):
"""MnliMatchedProcessor."""
def __init__(self):
super(MnliMatchedProcessor, self).__init__()
self.dev_file = "dev_matched.tsv"
self.test_file = "test_matched.tsv"
self.label_column = -1
self.text_a_column = 8
self.text_b_column = 9
def get_labels(self):
return ["contradiction", "entailment", "neutral"]
class MnliMismatchedProcessor(MnliMatchedProcessor):
def __init__(self):
super(MnliMismatchedProcessor, self).__init__()
self.dev_file = "dev_mismatched.tsv"
self.test_file = "test_mismatched.tsv"
class StsbProcessor(GLUEProcessor):
"""StsbProcessor."""
def __init__(self):
super(StsbProcessor, self).__init__()
self.label_column = 9
self.text_a_column = 7
self.text_b_column = 8
def get_labels(self):
return [0.0]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0 and self.contains_header and set_type != "test":
continue
if i == 0 and self.test_contains_header and set_type == "test":
continue
guid = "%s-%s" % (set_type, i)
a_column = (
self.text_a_column if set_type != "test" else self.test_text_a_column)
b_column = (
self.text_b_column if set_type != "test" else self.test_text_b_column)
# there are some incomplete lines in QNLI
if len(line) <= a_column:
logging.warning("Incomplete line, ignored.")
continue
text_a = line[a_column]
if b_column is not None:
if len(line) <= b_column:
logging.warning("Incomplete line, ignored.")
continue
text_b = line[b_column]
else:
text_b = None
if set_type == "test":
label = self.get_labels()[0]
else:
if len(line) <= self.label_column:
logging.warning("Incomplete line, ignored.")
continue
label = float(line[self.label_column])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def file_based_convert_examples_to_features(examples,
label_list,
max_seq_length,
tokenize_fn,
output_file,
num_passes=1):
"""Convert a set of `InputExample`s to a TFRecord file."""
# do not create duplicated records
if tf.io.gfile.exists(output_file) and not FLAGS.overwrite_data:
logging.info("Do not overwrite tfrecord %s exists.", output_file)
return
logging.info("Create new tfrecord %s.", output_file)
writer = tf.io.TFRecordWriter(output_file)
examples *= num_passes
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples))
feature = classifier_utils.convert_single_example(ex_index, example,
label_list,
max_seq_length,
tokenize_fn)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_float_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if label_list is not None:
features["label_ids"] = create_int_feature([feature.label_id])
else:
features["label_ids"] = create_float_feature([float(feature.label_id)])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def main(_):
logging.set_verbosity(logging.INFO)
processors = {
"mnli_matched": MnliMatchedProcessor,
"mnli_mismatched": MnliMismatchedProcessor,
"sts-b": StsbProcessor,
"imdb": ImdbProcessor,
"yelp5": Yelp5Processor
}
task_name = FLAGS.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels() if not FLAGS.is_regression else None
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)
def tokenize_fn(text):
text = preprocess_utils.preprocess_text(text, lower=FLAGS.uncased)
return preprocess_utils.encode_ids(sp, text)
spm_basename = os.path.basename(FLAGS.spiece_model_file)
train_file_base = "{}.len-{}.train.tf_record".format(spm_basename,
FLAGS.max_seq_length)
train_file = os.path.join(FLAGS.output_dir, train_file_base)
logging.info("Use tfrecord file %s", train_file)
train_examples = processor.get_train_examples(FLAGS.data_dir)
np.random.shuffle(train_examples)
logging.info("Num of train samples: %d", len(train_examples))
file_based_convert_examples_to_features(train_examples, label_list,
FLAGS.max_seq_length, tokenize_fn,
train_file, FLAGS.num_passes)
if FLAGS.eval_split == "dev":
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
else:
eval_examples = processor.get_test_examples(FLAGS.data_dir)
logging.info("Num of eval samples: %d", len(eval_examples))
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on. These do NOT count towards the metric (all tf.metrics
# support a per-instance weight, and these get a weight of 0.0).
#
# Modified in XL: We also adopt the same mechanism for GPUs.
while len(eval_examples) % FLAGS.eval_batch_size != 0:
eval_examples.append(classifier_utils.PaddingInputExample())
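  # For instance (illustrative numbers): with 1,000 dev examples and
  # eval_batch_size=64, this loop appends 24 PaddingInputExample instances
  # so that the padded 1,024 examples form exactly 16 full batches.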
eval_file_base = "{}.len-{}.{}.eval.tf_record".format(spm_basename,
FLAGS.max_seq_length,
FLAGS.eval_split)
eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
file_based_convert_examples_to_features(eval_examples, label_list,
FLAGS.max_seq_length, tokenize_fn,
eval_file)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to pre-process pre-training data into tfrecords."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import random
from absl import flags
import absl.logging as _logging # pylint: disable=unused-import
import numpy as np
import tensorflow as tf
from official.nlp.xlnet import preprocess_utils
import sentencepiece as spm
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
def _int64_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def _float_feature(values):
return tf.train.Feature(float_list=tf.train.FloatList(value=values))
def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
fixed_num_predict=None):
"""docs."""
if reuse_len is None:
reuse_len_str = ""
else:
reuse_len_str = "reuse-{}.".format(reuse_len)
if not uncased:
uncased_str = ""
else:
uncased_str = "uncased."
if bi_data:
bi_data_str = "bi"
else:
bi_data_str = "uni"
if fixed_num_predict is not None:
fnp_str = "fnp-{}.".format(fixed_num_predict)
else:
fnp_str = ""
file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
mask_alpha, mask_beta, fnp_str, suffix)
return file_name
def _create_data(idx, input_paths):
# Load sentence-piece model
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.sp_path)
input_shards = []
total_line_cnt = 0
for input_path in input_paths:
input_data, sent_ids = [], []
sent_id, line_cnt = True, 0
tf.logging.info("Processing %s", input_path)
for line in tf.gfile.Open(input_path):
if line_cnt % 100000 == 0:
tf.logging.info("Loading line %d", line_cnt)
line_cnt += 1
if not line.strip():
if FLAGS.use_eod:
sent_id = not sent_id
cur_sent = [EOD_ID]
else:
continue
else:
if FLAGS.from_raw_text:
cur_sent = preprocess_utils.preprocess_text(
line.strip(), lower=FLAGS.uncased)
cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
else:
cur_sent = list(map(int, line.strip().split()))
input_data.extend(cur_sent)
sent_ids.extend([sent_id] * len(cur_sent))
sent_id = not sent_id
tf.logging.info("Finish with line %d", line_cnt)
if line_cnt == 0:
continue
input_data = np.array(input_data, dtype=np.int64)
sent_ids = np.array(sent_ids, dtype=np.bool)
total_line_cnt += line_cnt
input_shards.append((input_data, sent_ids))
tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
filenames, num_batch = [], 0
# Randomly shuffle input shards (with a fixed but distinct random seed)
np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
perm_indices = np.random.permutation(len(input_shards))
tf.logging.info("Using perm indices %s for pass %d",
perm_indices.tolist(), FLAGS.pass_id)
input_data_list, sent_ids_list = [], []
prev_sent_id = None
for perm_idx in perm_indices:
input_data, sent_ids = input_shards[perm_idx]
    # make sure that `sent_ids[0] == not prev_sent_id`
if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
sent_ids = np.logical_not(sent_ids)
# append to temporary list
input_data_list.append(input_data)
sent_ids_list.append(sent_ids)
# update `prev_sent_id`
prev_sent_id = sent_ids[-1]
input_data = np.concatenate(input_data_list)
sent_ids = np.concatenate(sent_ids_list)
file_name, cur_num_batch = create_tfrecords(
save_dir=tfrecord_dir,
basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
data=[input_data, sent_ids],
bsz_per_host=FLAGS.bsz_per_host,
seq_len=FLAGS.seq_len,
bi_data=FLAGS.bi_data,
sp=sp,
)
filenames.append(file_name)
num_batch += cur_num_batch
record_info = {
"filenames": filenames,
"num_batch": num_batch
}
return record_info
def create_data(_):
# Validate FLAGS
assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
if not FLAGS.use_tpu:
FLAGS.num_core_per_host = 1 # forced to be one
# Make workdirs
if not tf.gfile.Exists(FLAGS.save_dir):
tf.gfile.MakeDirs(FLAGS.save_dir)
tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
if not tf.gfile.Exists(tfrecord_dir):
tf.gfile.MakeDirs(tfrecord_dir)
# Create and dump corpus_info from task 0
if FLAGS.task == 0 and FLAGS.pass_id == 0:
corpus_info = {
"vocab_size": VOCAB_SIZE,
"bsz_per_host": FLAGS.bsz_per_host,
"num_core_per_host": FLAGS.num_core_per_host,
"seq_len": FLAGS.seq_len,
"reuse_len": FLAGS.reuse_len,
"uncased": FLAGS.uncased,
"bi_data": FLAGS.bi_data,
"mask_alpha": FLAGS.mask_alpha,
"mask_beta": FLAGS.mask_beta,
"num_predict": FLAGS.num_predict,
"use_eod": FLAGS.use_eod,
"sp_path": FLAGS.sp_path,
"input_glob": FLAGS.input_glob,
}
corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
with tf.gfile.Open(corpus_info_path, "w") as fp:
json.dump(corpus_info, fp)
  # Split the work into FLAGS.num_task interleaved splits
file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
tf.logging.info("Use glob: %s", FLAGS.input_glob)
tf.logging.info("Find %d files: %s", len(file_paths), file_paths)
task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
if not task_file_paths:
tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
return
tf.logging.info("Task %d process %d files: %s",
FLAGS.task, len(task_file_paths), task_file_paths)
record_info = _create_data(FLAGS.task, task_file_paths)
record_prefix = "record_info-{}-{}-{}".format(
FLAGS.split, FLAGS.task, FLAGS.pass_id)
record_name = format_filename(
prefix=record_prefix,
bsz_per_host=FLAGS.bsz_per_host,
seq_len=FLAGS.seq_len,
mask_alpha=FLAGS.mask_alpha,
mask_beta=FLAGS.mask_beta,
reuse_len=FLAGS.reuse_len,
bi_data=FLAGS.bi_data,
suffix="json",
uncased=FLAGS.uncased,
fixed_num_predict=FLAGS.num_predict)
record_info_path = os.path.join(tfrecord_dir, record_name)
with tf.gfile.Open(record_info_path, "w") as fp:
json.dump(record_info, fp)
def batchify(data, bsz_per_host, sent_ids=None):
num_step = len(data) // bsz_per_host
data = data[:bsz_per_host * num_step]
data = data.reshape(bsz_per_host, num_step)
if sent_ids is not None:
sent_ids = sent_ids[:bsz_per_host * num_step]
sent_ids = sent_ids.reshape(bsz_per_host, num_step)
if sent_ids is not None:
return data, sent_ids
return data
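# Worked example (illustrative): 10 tokens with bsz_per_host=4 keep only
# 4 * (10 // 4) = 8 tokens, reshaped to [4, 2], i.e. 4 parallel streams
# of 2 consecutive steps each.
def _batchify_demo():
  data = np.arange(10)
  out = batchify(data, bsz_per_host=4)
  assert out.shape == (4, 2)
  return out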
def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
"""Split two segments from `data` starting from the index `begin_idx`."""
data_len = data.shape[0]
if begin_idx + tot_len >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"begin_idx %d + tot_len %d >= data_len %d",
begin_idx, tot_len, data_len)
return None
end_idx = begin_idx + 1
cut_points = []
while end_idx < data_len:
if sent_ids[end_idx] != sent_ids[end_idx - 1]:
if end_idx - begin_idx >= tot_len: break
cut_points.append(end_idx)
end_idx += 1
a_begin = begin_idx
if len(cut_points) == 0 or random.random() < 0.5:
label = 0
if len(cut_points) == 0:
a_end = end_idx
else:
a_end = random.choice(cut_points)
b_len = max(1, tot_len - (a_end - a_begin))
# (zihangd): `data_len - 1` to account for extend_target
b_begin = random.randint(0, data_len - 1 - b_len)
b_end = b_begin + b_len
while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
b_begin -= 1
# (zihangd): `data_len - 1` to account for extend_target
while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
b_end += 1
new_begin = a_end
else:
label = 1
a_end = random.choice(cut_points)
b_begin = a_end
b_end = end_idx
new_begin = b_end
while a_end - a_begin + b_end - b_begin > tot_len:
if a_end - a_begin > b_end - b_begin:
# delete the right side only for the LM objective
a_end -= 1
else:
b_end -= 1
ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]
if extend_target:
if a_end >= data_len or b_end >= data_len:
tf.logging.info("[_split_a_and_b] returns None: "
"a_end %d or b_end %d >= data_len %d",
a_end, b_end, data_len)
return None
a_target = data[a_begin + 1: a_end + 1]
b_target = data[b_begin: b_end + 1]
ret.extend([a_target, b_target])
return ret
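# Summary of the branching above (illustrative restatement, not original
# text): with probability ~0.5, or when no sentence boundary exists in the
# window, label == 0 and segment B is drawn from a random position
# elsewhere in the stream; otherwise label == 1 and segment B is the
# immediate continuation of segment A.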
def _is_start_piece(piece):
special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
if (piece.startswith("▁") or piece.startswith("<")
or piece in special_pieces):
return True
else:
return False
def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
num_predict = 0
ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
pvals = 1. / np.arange(1, max_gram + 1)
pvals /= pvals.sum(keepdims=True)
if reverse:
seg = np.flip(seg, 0)
cur_len = 0
while cur_len < seg_len:
if goal_num_predict is not None and num_predict >= goal_num_predict: break
n = np.random.choice(ngrams, p=pvals)
if goal_num_predict is not None:
n = min(n, goal_num_predict - num_predict)
ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
l_ctx = np.random.choice(ctx_size)
r_ctx = ctx_size - l_ctx
# Find the start position of a complete token
beg = cur_len + l_ctx
while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
beg += 1
if beg >= seg_len:
break
# Find the end position of the n-gram (start pos of the n+1-th gram)
end = beg + 1
cnt_ngram = 1
while end < seg_len:
cnt_ngram += 1
if cnt_ngram > n:
break
end += 1
if end >= seg_len:
break
# Update
mask[beg:end] = True
num_predict += end - beg
cur_len = end + r_ctx
while goal_num_predict is not None and num_predict < goal_num_predict:
i = np.random.randint(seg_len)
if not mask[i]:
mask[i] = True
num_predict += 1
if reverse:
mask = np.flip(mask, 0)
return mask
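# Worked example of the span arithmetic above (hypothetical flag values):
# with mask_alpha=6, mask_beta=1 and a sampled n-gram length n=2, ctx_size is
# (2 * 6) // 1 = 12, so the 2 masked positions sit inside a window of 12
# tokens, i.e. roughly mask_beta of every mask_alpha tokens are masked.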
def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
goal_num_predict=None):
"""Sample `goal_num_predict` tokens for partial prediction.
About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens."""
seg_len = len(seg)
mask = np.array([False] * seg_len, dtype=np.bool)
num_predict = 0
ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
pvals = 1. / np.arange(1, max_gram + 1)
pvals /= pvals.sum(keepdims=True)
if reverse:
seg = np.flip(seg, 0)
cur_len = 0
while cur_len < seg_len:
if goal_num_predict is not None and num_predict >= goal_num_predict: break
n = np.random.choice(ngrams, p=pvals)
if goal_num_predict is not None:
n = min(n, goal_num_predict - num_predict)
ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
l_ctx = np.random.choice(ctx_size)
r_ctx = ctx_size - l_ctx
# Find the start position of a complete token
beg = cur_len + l_ctx
while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
beg += 1
if beg >= seg_len:
break
# Find the end position of the n-gram (start pos of the n+1-th gram)
end = beg
cnt_ngram = 0
while end < seg_len:
if _is_start_piece(sp.IdToPiece(seg[end].item())):
cnt_ngram += 1
if cnt_ngram > n:
break
# select current piece
mask[end] = True
# update the end pointer and increment num_predict
end += 1
num_predict += 1
if goal_num_predict is not None and num_predict >= goal_num_predict:
break
cur_len = end + r_ctx
while goal_num_predict is not None and num_predict < goal_num_predict:
i = np.random.randint(seg_len)
if not mask[i]:
mask[i] = True
num_predict += 1
if reverse:
mask = np.flip(mask, 0)
return mask
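# Note on the two samplers (illustration): `_sample_mask` masks exactly n raw
# SentencePiece pieces per sampled span, whereas `_sample_mask_ngram` advances
# its gram counter only at word-start pieces, so each of its spans covers n
# whole words with all of their constituent pieces masked.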
def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                     bi_data, sp):
  """Creates pretraining TFRecords from a (data, sent_ids) token stream."""
  data, sent_ids = data[0], data[1]
num_core = FLAGS.num_core_per_host
bsz_per_core = bsz_per_host // num_core
if bi_data:
assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
bwd_data = fwd_data[:, :, :, ::-1]
bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
data = np.concatenate(
[fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
sent_ids = np.concatenate(
[fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
else:
data, sent_ids = batchify(data, bsz_per_host, sent_ids)
tf.logging.info("Raw data shape %s.", data.shape)
file_name = format_filename(
prefix=basename,
bsz_per_host=bsz_per_host,
seq_len=seq_len,
bi_data=bi_data,
suffix="tfrecords",
mask_alpha=FLAGS.mask_alpha,
mask_beta=FLAGS.mask_beta,
reuse_len=FLAGS.reuse_len,
uncased=FLAGS.uncased,
fixed_num_predict=FLAGS.num_predict
)
save_path = os.path.join(save_dir, file_name)
record_writer = tf.python_io.TFRecordWriter(save_path)
tf.logging.info("Start writing %s.", save_path)
num_batch = 0
reuse_len = FLAGS.reuse_len
# [sep] x 2 + [cls]
assert reuse_len < seq_len - 3
data_len = data.shape[1]
sep_array = np.array([SEP_ID], dtype=np.int64)
cls_array = np.array([CLS_ID], dtype=np.int64)
i = 0
while i + seq_len <= data_len:
if num_batch % 500 == 0:
tf.logging.info("Processing batch %d", num_batch)
all_ok = True
features = []
for idx in range(bsz_per_host):
inp = data[idx, i: i + reuse_len]
tgt = data[idx, i + 1: i + reuse_len + 1]
results = _split_a_and_b(
data[idx],
sent_ids[idx],
begin_idx=i + reuse_len,
tot_len=seq_len - reuse_len - 3,
extend_target=True)
if results is None:
tf.logging.info("Break out with seq idx %d", i)
all_ok = False
break
# unpack the results
(a_data, b_data, label, _, a_target, b_target) = tuple(results)
# sample ngram spans to predict
reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
if FLAGS.num_predict is None:
num_predict_0 = num_predict_1 = None
else:
num_predict_1 = FLAGS.num_predict // 2
num_predict_0 = FLAGS.num_predict - num_predict_1
mask_0 = _sample_mask(sp, inp, reverse=reverse,
goal_num_predict=num_predict_0)
mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
sep_array, cls_array]),
reverse=reverse, goal_num_predict=num_predict_1)
# concatenate data
cat_data = np.concatenate([inp, a_data, sep_array, b_data,
sep_array, cls_array])
seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
[1] * b_data.shape[0] + [1] + [2])
assert cat_data.shape[0] == seq_len
assert mask_0.shape[0] == seq_len // 2
assert mask_1.shape[0] == seq_len // 2
# the last two CLS's are not used, just for padding purposes
tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
assert tgt.shape[0] == seq_len
is_masked = np.concatenate([mask_0, mask_1], 0)
if FLAGS.num_predict is not None:
assert np.sum(is_masked) == FLAGS.num_predict
feature = {
"input": _int64_feature(cat_data),
"is_masked": _int64_feature(is_masked),
"target": _int64_feature(tgt),
"seg_id": _int64_feature(seg_id),
"label": _int64_feature([label]),
}
features.append(feature)
if all_ok:
assert len(features) == bsz_per_host
for feature in features:
example = tf.train.Example(features=tf.train.Features(feature=feature))
record_writer.write(example.SerializeToString())
num_batch += 1
else:
break
i += reuse_len
record_writer.close()
tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
return save_path, num_batch
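# Layout of each written sequence (follows from the concatenation in
# `create_tfrecords` above):
#   [reuse_len LM tokens | A | SEP | B | SEP | CLS]
# The three functional tokens are why `reuse_len < seq_len - 3` is asserted.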
################
# get_input_fn #
################
def _convert_example(example, use_bfloat16):
"""Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
for key in list(example.keys()):
val = example[key]
if tf.keras.backend.is_sparse(val):
val = tf.sparse.to_dense(val)
if val.dtype == tf.int64:
val = tf.cast(val, tf.int32)
if use_bfloat16 and val.dtype == tf.float32:
val = tf.cast(val, tf.bfloat16)
example[key] = val
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
host_id, num_core_per_host, bsz_per_core):
  """Shards the file paths across hosts and builds the training dataset."""
num_files = len(file_names)
num_files_per_host = num_files // num_hosts
my_start_file_id = host_id * num_files_per_host
my_end_file_id = (host_id + 1) * num_files_per_host
if host_id == num_hosts - 1:
my_end_file_id = num_files
file_paths = file_names[my_start_file_id: my_end_file_id]
tf.logging.info("Host %d handles %d files", host_id, len(file_paths))
assert split == "train"
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
# file-level shuffle
if len(file_paths) > 1:
dataset = dataset.shuffle(len(file_paths))
  # Note: we cannot perform sample-level shuffle here because it would break
  # the requirement that each stream consist of consecutive samples.
dataset = tf.data.TFRecordDataset(dataset)
  # Note: since we do online preprocessing, parsing the same input yields a
  # different result each time, so caching the processed data is not helpful.
  # It would use a lot of memory and can lead to container OOM.
  # So, we cache the non-parsed raw data instead.
dataset = dataset.cache().map(parser).repeat()
dataset = dataset.batch(bsz_per_core, drop_remainder=True)
dataset = dataset.prefetch(num_core_per_host * bsz_per_core)
return dataset
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
"""
Sample a permutation of the factorization order, and create an
attention mask accordingly.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected
for partial prediction.
    perm_size: the length of the longest permutation. Could be set to
      reuse_len. Should not be larger than reuse_len, or there will be data
      leaks.
seq_len: int, sequence length.
"""
# Generate permutation indices
index = tf.range(seq_len, dtype=tf.int64)
index = tf.transpose(tf.reshape(index, [-1, perm_size]))
index = tf.random_shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens
non_func_tokens = tf.logical_not(tf.logical_or(
tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
  # Set the permutation indices of non-masked (& non-functional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
  # Create `target_mask`: non-functional and masked tokens
# 1: use mask as input and have loss
# 0: use token (or [SEP], [CLS]) as input and do not have loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# Create `perm_mask`
# `target_tokens` cannot see themselves
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked
perm_mask = tf.logical_and(
self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0: 1], targets[: -1]],
axis=0)
# construct inputs_k
inputs_k = inputs
# construct inputs_q
inputs_q = target_mask
return perm_mask, new_targets, target_mask, inputs_k, inputs_q
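# Illustration of the index permutation above (hypothetical sizes): with
# seq_len=6 and perm_size=3, `index` is reshaped to [[0, 1, 2], [3, 4, 5]],
# transposed to [[0, 3], [1, 4], [2, 5]], row-shuffled (say, to
# [[2, 5], [0, 3], [1, 4]]) and flattened back to [2, 0, 1, 5, 3, 4]: the
# same random permutation of the factorization order is applied within each
# consecutive block of `perm_size` positions.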
def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
num_batch, seq_len, reuse_len, perm_size, mask_alpha,
mask_beta, use_bfloat16=False, num_predict=None):
bsz_per_core = params["batch_size"]
if num_hosts > 1:
host_id = params["context"].current_host
else:
host_id = 0
#### Function used to parse tfrecord
def parser(record):
"""function used to parse tfrecord."""
record_spec = {
"input": tf.FixedLenFeature([seq_len], tf.int64),
"target": tf.FixedLenFeature([seq_len], tf.int64),
"seg_id": tf.FixedLenFeature([seq_len], tf.int64),
"label": tf.FixedLenFeature([1], tf.int64),
"is_masked": tf.FixedLenFeature([seq_len], tf.int64),
}
# retrieve serialized example
example = tf.parse_single_example(
serialized=record,
features=record_spec)
inputs = example.pop("input")
target = example.pop("target")
is_masked = tf.cast(example.pop("is_masked"), tf.bool)
non_reuse_len = seq_len - reuse_len
assert perm_size <= reuse_len and perm_size <= non_reuse_len
perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
inputs[:reuse_len],
target[:reuse_len],
is_masked[:reuse_len],
perm_size,
reuse_len)
perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
inputs[reuse_len:],
target[reuse_len:],
is_masked[reuse_len:],
perm_size,
non_reuse_len)
perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
axis=1)
perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target = tf.concat([target_0, target_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
input_k = tf.concat([input_k_0, input_k_1], axis=0)
input_q = tf.concat([input_q_0, input_q_1], axis=0)
if num_predict is not None:
indices = tf.range(seq_len, dtype=tf.int64)
bool_target_mask = tf.cast(target_mask, tf.bool)
indices = tf.boolean_mask(indices, bool_target_mask)
##### extra padding due to CLS/SEP introduced after prepro
actual_num_predict = tf.shape(indices)[0]
pad_len = num_predict - actual_num_predict
##### target_mapping
target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
target_mapping = tf.concat([target_mapping, paddings], axis=0)
example["target_mapping"] = tf.reshape(target_mapping,
[num_predict, seq_len])
##### target
target = tf.boolean_mask(target, bool_target_mask)
paddings = tf.zeros([pad_len], dtype=target.dtype)
target = tf.concat([target, paddings], axis=0)
example["target"] = tf.reshape(target, [num_predict])
##### target mask
target_mask = tf.concat(
[tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)],
axis=0)
example["target_mask"] = tf.reshape(target_mask, [num_predict])
else:
example["target"] = tf.reshape(target, [seq_len])
example["target_mask"] = tf.reshape(target_mask, [seq_len])
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
_convert_example(example, use_bfloat16)
for k, v in example.items():
tf.logging.info("%s: %s", k, v)
return example
# Get dataset
dataset = parse_files_to_dataset(
parser=parser,
file_names=file_names,
split=split,
num_batch=num_batch,
num_hosts=num_hosts,
host_id=host_id,
num_core_per_host=num_core_per_host,
bsz_per_core=bsz_per_core)
return dataset
def get_input_fn(
tfrecord_dir,
split,
bsz_per_host,
seq_len,
reuse_len,
bi_data,
num_hosts=1,
num_core_per_host=1,
perm_size=None,
mask_alpha=None,
mask_beta=None,
uncased=False,
num_passes=None,
use_bfloat16=False,
num_predict=None):
# Merge all record infos into a single one
record_glob_base = format_filename(
prefix="record_info-{}-*".format(split),
bsz_per_host=bsz_per_host,
seq_len=seq_len,
bi_data=bi_data,
suffix="json",
mask_alpha=mask_alpha,
mask_beta=mask_beta,
reuse_len=reuse_len,
uncased=uncased,
fixed_num_predict=num_predict)
record_info = {"num_batch": 0, "filenames": []}
tfrecord_dirs = tfrecord_dir.split(",")
tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
for idx, record_dir in enumerate(tfrecord_dirs):
record_glob = os.path.join(record_dir, record_glob_base)
tf.logging.info("[%d] Record glob: %s", idx, record_glob)
record_paths = sorted(tf.gfile.Glob(record_glob))
tf.logging.info("[%d] Num of record info path: %d",
idx, len(record_paths))
cur_record_info = {"num_batch": 0, "filenames": []}
for record_info_path in record_paths:
if num_passes is not None:
record_info_name = os.path.basename(record_info_path)
fields = record_info_name.split(".")[0].split("-")
pass_id = int(fields[-1])
if len(fields) == 5 and pass_id >= num_passes:
tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
continue
with tf.gfile.Open(record_info_path, "r") as fp:
info = json.load(fp)
if num_passes is not None:
eff_num_passes = min(num_passes, len(info["filenames"]))
ratio = eff_num_passes / len(info["filenames"])
cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
else:
cur_record_info["num_batch"] += info["num_batch"]
cur_record_info["filenames"] += info["filenames"]
# overwrite directory for `cur_record_info`
new_filenames = []
for filename in cur_record_info["filenames"]:
basename = os.path.basename(filename)
new_filename = os.path.join(record_dir, basename)
new_filenames.append(new_filename)
cur_record_info["filenames"] = new_filenames
tf.logging.info("[Dir %d] Number of chosen batches: %s",
idx, cur_record_info["num_batch"])
tf.logging.info("[Dir %d] Number of chosen files: %s",
idx, len(cur_record_info["filenames"]))
tf.logging.info(cur_record_info["filenames"])
# add `cur_record_info` to global `record_info`
record_info["num_batch"] += cur_record_info["num_batch"]
record_info["filenames"] += cur_record_info["filenames"]
tf.logging.info("Total number of batches: %d",
record_info["num_batch"])
tf.logging.info("Total number of files: %d",
len(record_info["filenames"]))
tf.logging.info(record_info["filenames"])
def input_fn(params):
"""docs."""
assert params["batch_size"] * num_core_per_host == bsz_per_host
dataset = get_dataset(
params=params,
num_hosts=num_hosts,
num_core_per_host=num_core_per_host,
split=split,
file_names=record_info["filenames"],
num_batch=record_info["num_batch"],
seq_len=seq_len,
reuse_len=reuse_len,
perm_size=perm_size,
mask_alpha=mask_alpha,
mask_beta=mask_beta,
use_bfloat16=use_bfloat16,
num_predict=num_predict)
return dataset
return input_fn, record_info
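# Minimal usage sketch (hypothetical argument values, illustration only):
#   input_fn, record_info = get_input_fn(
#       tfrecord_dir="proc_data/example", split="train", bsz_per_host=32,
#       seq_len=512, reuse_len=256, bi_data=True, perm_size=256,
#       mask_alpha=6, mask_beta=1, num_predict=85)
#   dataset = input_fn({"batch_size": 32})
# `record_info` carries the merged batch count and the chosen file names.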
if __name__ == "__main__":
FLAGS = flags.FLAGS
flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
flags.DEFINE_integer("seq_len", 512,
help="Sequence length.")
flags.DEFINE_integer("reuse_len", 256,
help="Number of token that can be reused as memory. "
"Could be half of `seq_len`.")
flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
flags.DEFINE_bool("bi_data", True,
help="whether to create bidirectional data")
flags.DEFINE_integer("mask_alpha", default=6,
help="How many tokens to form a group.")
flags.DEFINE_integer("mask_beta", default=1,
help="How many tokens to mask within each group.")
flags.DEFINE_bool("use_eod", True,
help="whether to append EOD at the end of a doc.")
flags.DEFINE_bool("from_raw_text", True,
help="Whether the input is raw text or encoded ids.")
flags.DEFINE_integer("num_predict", default=85,
help="Num of tokens to predict.")
flags.DEFINE_string("input_glob", "data/example/*.txt",
help="Input file glob.")
flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
flags.DEFINE_string("save_dir", "proc_data/example",
help="Directory for saving the processed data.")
flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
help="Save the data as which split.")
flags.DEFINE_integer("pass_id", 0, help="ID of the current pass."
"Different passes sample different negative segment.")
flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
"using multiple workers to identify each worker.")
tf.logging.set_verbosity(tf.logging.INFO)
absl_app.run(create_data)
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to pre-process SQUAD data into tfrecords."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import random
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import sentencepiece as spm
from official.nlp.xlnet import squad_utils
flags.DEFINE_integer(
"num_proc", default=1, help="Number of preprocessing processes.")
flags.DEFINE_integer("proc_id", default=0, help="Process id for preprocessing.")
# I/O paths
flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
flags.DEFINE_string(
"spiece_model_file", default="", help="Sentence Piece model path.")
flags.DEFINE_string("train_file", default="", help="Path of train file.")
flags.DEFINE_string("predict_file", default="", help="Path of prediction file.")
# Data preprocessing config
flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length")
flags.DEFINE_integer("max_query_length", default=64, help="Max query length")
flags.DEFINE_integer("doc_stride", default=128, help="Doc stride")
flags.DEFINE_bool("uncased", default=False, help="Use uncased data.")
flags.DEFINE_bool(
"create_train_data", default=True, help="Whether to create training data.")
flags.DEFINE_bool(
"create_eval_data", default=False, help="Whether to create eval data.")
FLAGS = flags.FLAGS
def _get_spm_basename():
spm_basename = os.path.basename(FLAGS.spiece_model_file)
return spm_basename
def preprocess():
"""Preprocesses SQUAD data."""
sp_model = spm.SentencePieceProcessor()
sp_model.Load(FLAGS.spiece_model_file)
spm_basename = _get_spm_basename()
if FLAGS.create_train_data:
train_rec_file = os.path.join(
FLAGS.output_dir,
"{}.{}.slen-{}.qlen-{}.train.tf_record".format(spm_basename,
FLAGS.proc_id,
FLAGS.max_seq_length,
FLAGS.max_query_length))
logging.info("Read examples from %s", FLAGS.train_file)
train_examples = squad_utils.read_squad_examples(
FLAGS.train_file, is_training=True)
train_examples = train_examples[FLAGS.proc_id::FLAGS.num_proc]
# Pre-shuffle the input to avoid having to make a very large shuffle
# buffer in the `input_fn`.
random.shuffle(train_examples)
write_to_logging = "Write to " + train_rec_file
logging.info(write_to_logging)
train_writer = squad_utils.FeatureWriter(
filename=train_rec_file, is_training=True)
squad_utils.convert_examples_to_features(
examples=train_examples,
sp_model=sp_model,
max_seq_length=FLAGS.max_seq_length,
doc_stride=FLAGS.doc_stride,
max_query_length=FLAGS.max_query_length,
is_training=True,
output_fn=train_writer.process_feature,
uncased=FLAGS.uncased)
train_writer.close()
if FLAGS.create_eval_data:
eval_examples = squad_utils.read_squad_examples(
FLAGS.predict_file, is_training=False)
eval_rec_file = os.path.join(
FLAGS.output_dir,
"{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename,
FLAGS.max_seq_length,
FLAGS.max_query_length))
eval_feature_file = os.path.join(
FLAGS.output_dir,
"{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
FLAGS.max_seq_length,
FLAGS.max_query_length))
eval_writer = squad_utils.FeatureWriter(
filename=eval_rec_file, is_training=False)
eval_features = []
def append_feature(feature):
eval_features.append(feature)
eval_writer.process_feature(feature)
squad_utils.convert_examples_to_features(
examples=eval_examples,
sp_model=sp_model,
max_seq_length=FLAGS.max_seq_length,
doc_stride=FLAGS.doc_stride,
max_query_length=FLAGS.max_query_length,
is_training=False,
output_fn=append_feature,
uncased=FLAGS.uncased)
eval_writer.close()
with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
pickle.dump(eval_features, fout)
def main(_):
logging.set_verbosity(logging.INFO)
if not tf.io.gfile.exists(FLAGS.output_dir):
tf.io.gfile.mkdir(FLAGS.output_dir)
preprocess()
if __name__ == "__main__":
app.run(main)
# coding=utf-8
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for pre-processing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unicodedata
import six
SPIECE_UNDERLINE = '▁'
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode('utf-8', 'ignore')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode('utf-8')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
else:
raise ValueError('Not running on Python2 or Python 3?')
def print_(*args):
new_args = []
for arg in args:
if isinstance(arg, list):
s = [printable_text(i) for i in arg]
s = ' '.join(s)
new_args.append(s)
else:
new_args.append(printable_text(arg))
print(*new_args)
def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
"""Preprocesses texts."""
if remove_space:
outputs = ' '.join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace('``', '"').replace("''", '"')
if six.PY2 and isinstance(outputs, str):
outputs = outputs.decode('utf-8')
if not keep_accents:
outputs = unicodedata.normalize('NFKD', outputs)
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
if lower:
outputs = outputs.lower()
return outputs
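# Example behavior (illustration only):
#   preprocess_text(" He said, ``hello''  ")  # -> 'He said, "hello"'
#   preprocess_text("café", lower=True)       # -> 'cafe' (NFKD strips accents)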
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
"""Encodes pieces."""
# return_unicode is used only for py2
if six.PY2 and isinstance(text, unicode):
text = text.encode('utf-8')
if not sample:
pieces = sp_model.EncodeAsPieces(text)
else:
pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
cur_pieces = sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
cur_pieces[0] = cur_pieces[0][1:]
cur_pieces.append(piece[-1])
new_pieces.extend(cur_pieces)
else:
new_pieces.append(piece)
# note(zhiliny): convert back to unicode for py2
if six.PY2 and return_unicode:
ret_pieces = []
for piece in new_pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
ret_pieces.append(piece)
new_pieces = ret_pieces
return new_pieces
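# Note on the digit-comma special case above (illustration): a piece ending
# in a comma preceded by a digit, e.g. "▁1994,", is re-tokenized without the
# comma so the trailing "," becomes its own piece, keeping numbers and
# punctuation separate in the output (exact pieces depend on the vocabulary).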
def encode_ids(sp_model, text, sample=False):
pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
ids = [sp_model.PieceToId(piece) for piece in pieces]
return ids
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import
# Initialize TPU System.
from official.utils.misc import tpu_lib  # assumed path; provides tpu_initialize used below
from official.nlp import xlnet_config
from official.nlp import xlnet_modeling as modeling
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
FLAGS = flags.FLAGS
def get_classificationxlnet_model(model_config, run_config, n_class):
model = modeling.ClassificationXLNetModel(
model_config, run_config, n_class, name="model")
return model
def run_evaluation(strategy,
test_input_fn,
eval_steps,
model,
step,
eval_summary_writer=None):
"""Run evaluation for classification task.
Args:
strategy: distribution strategy.
test_input_fn: input function for evaluation data.
eval_steps: total number of evaluation steps.
model: keras model object.
step: current train step.
    eval_summary_writer: summary writer used to record evaluation metrics.
      Because the validation set is padded with fake examples, we use a mask
      to exclude them when computing the accuracy. Since this produces
      dynamic-shape tensors, we first collect logits, labels and masks from
      the TPU and compute the accuracy locally with numpy.
"""
def _test_step_fn(inputs):
"""Replicated validation step."""
inputs["mems"] = None
_, logits = model(inputs, training=False)
return logits, inputs["label_ids"], inputs["is_real_example"]
@tf.function
def _run_evaluation(test_iterator):
"""Runs validation steps."""
logits, labels, masks = strategy.experimental_run_v2(
_test_step_fn, args=(next(test_iterator),))
return logits, labels, masks
# pylint: disable=protected-access
test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
# pylint: enable=protected-access
correct = 0
total = 0
for _ in range(eval_steps):
logits, labels, masks = _run_evaluation(test_iterator)
logits = strategy.experimental_local_results(logits)
labels = strategy.experimental_local_results(labels)
masks = strategy.experimental_local_results(masks)
merged_logits = []
merged_labels = []
merged_masks = []
for i in range(strategy.num_replicas_in_sync):
merged_logits.append(logits[i].numpy())
merged_labels.append(labels[i].numpy())
merged_masks.append(masks[i].numpy())
merged_logits = np.vstack(np.array(merged_logits))
merged_labels = np.hstack(np.array(merged_labels))
merged_masks = np.hstack(np.array(merged_masks))
real_index = np.where(np.equal(merged_masks, 1))
correct += np.sum(
np.equal(
np.argmax(merged_logits[real_index], axis=-1),
merged_labels[real_index]))
total += np.shape(real_index)[-1]
logging.info("Train step: %d / acc = %d/%d = %f", step, correct, total,
float(correct) / float(total))
if eval_summary_writer:
with eval_summary_writer.as_default():
tf.summary.scalar("eval_acc", float(correct) / float(total), step=step)
eval_summary_writer.flush()
def get_metric_fn():
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
"acc", dtype=tf.float32)
return train_acc_metric
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
def main(unused_argv):
del unused_argv
use_remote_tpu = False
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
train_input_fn = functools.partial(data_utils.get_classification_input_data,
FLAGS.train_batch_size, FLAGS.seq_len,
strategy, True, FLAGS.train_tfrecord_path)
test_input_fn = functools.partial(data_utils.get_classification_input_data,
FLAGS.test_batch_size, FLAGS.seq_len,
strategy, False, FLAGS.test_tfrecord_path)
total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
steps_per_loop = FLAGS.iterations
eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
eval_steps)
optimizer, learning_rate_fn = optimization.create_optimizer(
FLAGS.learning_rate,
total_training_steps,
FLAGS.warmup_steps,
adam_epsilon=FLAGS.adam_epsilon)
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
model_fn = functools.partial(get_classificationxlnet_model, model_config,
run_config, FLAGS.n_class)
input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model
input_meta_data["mem_len"] = FLAGS.mem_len
input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
strategy.num_replicas_in_sync)
input_meta_data["n_layer"] = FLAGS.n_layer
input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
input_meta_data["n_class"] = FLAGS.n_class
print("DEBUG: ", str(input_meta_data))
def logits_init_fn():
return tf.zeros(
shape=(input_meta_data["batch_size_per_core"],
input_meta_data["n_class"]),
dtype=tf.float32)
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=get_metric_fn,
logits_init_fn=logits_init_fn,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
# Initialize TPU System.
from official.utils.misc import tpu_lib  # assumed path; provides tpu_initialize used below
from official.nlp import xlnet_config
from official.nlp import xlnet_modeling as modeling
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
flags.DEFINE_integer(
"mask_alpha", default=6, help="How many tokens to form a group.")
flags.DEFINE_integer(
"mask_beta", default=1, help="How many tokens to mask within each group.")
flags.DEFINE_integer(
"num_predict",
default=None,
help="Number of tokens to predict in partial prediction.")
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
FLAGS = flags.FLAGS
def get_pretrainxlnet_model(model_config, run_config):
model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
return model
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
def main(unused_argv):
del unused_argv
use_remote_tpu = False
num_hosts = 1
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
topology = FLAGS.tpu_topology.split("x")
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts = total_num_core // FLAGS.num_core_per_host
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d",
num_hosts)
train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict, FLAGS.bi_data,
FLAGS.uncased, num_hosts)
total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
steps_per_loop = FLAGS.iterations
optimizer, learning_rate_fn = optimization.create_optimizer(
init_lr=FLAGS.learning_rate,
num_train_steps=total_training_steps,
num_warmup_steps=FLAGS.warmup_steps,
min_lr_ratio=FLAGS.min_lr_ratio,
adam_epsilon=FLAGS.adam_epsilon,
weight_decay_rate=FLAGS.weight_decay_rate)
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model
input_meta_data["mem_len"] = FLAGS.mem_len
input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
strategy.num_replicas_in_sync)
input_meta_data["n_layer"] = FLAGS.n_layer
input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
model_fn = functools.partial(get_pretrainxlnet_model, model_config,
run_config)
def logits_init_fn():
return tf.zeros(
shape=(FLAGS.num_predict, input_meta_data["batch_size_per_core"],
FLAGS.d_model),
dtype=tf.float32)
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=None,
metric_fn=None,
logits_init_fn=logits_init_fn,
train_input_fn=train_input_fn,
test_input_fn=None,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=FLAGS.save_steps)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import json
import os
import pickle
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=unused-import
# Initialize TPU System.
from official.utils.misc import tpu_lib  # assumed path; provides tpu_initialize used below
from official.nlp import xlnet_config
from official.nlp import xlnet_modeling as modeling
from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization
from official.nlp.xlnet import squad_utils
from official.nlp.xlnet import training_utils
flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.")
flags.DEFINE_integer("query_len", default=64, help="Max query length.")
flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
flags.DEFINE_string(
"predict_dir", default=None, help="Path to write predictions.")
flags.DEFINE_string(
"predict_file", default=None, help="Path to json file of test set.")
flags.DEFINE_integer(
"n_best_size", default=5, help="n best size for predictions.")
flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
FLAGS = flags.FLAGS
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tok_start_to_orig_index,
tok_end_to_orig_index,
token_is_max_context,
input_ids,
input_mask,
p_mask,
segment_ids,
paragraph_len,
cls_index,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tok_start_to_orig_index = tok_start_to_orig_index
self.tok_end_to_orig_index = tok_end_to_orig_index
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.p_mask = p_mask
self.segment_ids = segment_ids
self.paragraph_len = paragraph_len
self.cls_index = cls_index
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
# pylint: disable=unused-argument
def run_evaluation(strategy,
test_input_fn,
eval_steps,
input_meta_data,
model,
step,
eval_summary_writer=None):
"""Run evaluation for SQUAD task.
Args:
strategy: distribution strategy.
test_input_fn: input function for evaluation data.
eval_steps: total number of evaluation steps.
input_meta_data: input meta data.
model: keras model object.
step: current training step.
eval_summary_writer: summary writer used to record evaluation metrics.
"""
def _test_step_fn(inputs):
"""Replicated validation step."""
inputs["mems"] = None
res = model(inputs, training=False)
return res, inputs["unique_ids"]
@tf.function
def _run_evaluation(test_iterator):
"""Runs validation steps."""
res, unique_ids = strategy.experimental_run_v2(
_test_step_fn, args=(next(test_iterator),))
return res, unique_ids
# pylint: disable=protected-access
test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
# pylint: enable=protected-access
cur_results = []
eval_examples = squad_utils.read_squad_examples(
input_meta_data["predict_file"], is_training=False)
with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
orig_data = json.load(f)["data"]
for _ in range(eval_steps):
results, unique_ids = _run_evaluation(test_iterator)
unique_ids = strategy.experimental_local_results(unique_ids)
for result_key in results:
results[result_key] = (
strategy.experimental_local_results(results[result_key]))
for core_i in range(strategy.num_replicas_in_sync):
bsz = int(input_meta_data["test_batch_size"] /
strategy.num_replicas_in_sync)
for j in range(bsz):
result = {}
for result_key in results:
result[result_key] = results[result_key][core_i].numpy()[j]
result["unique_ids"] = unique_ids[core_i].numpy()[j]
        # We appended a fake example to the dev set so that the data size is
        # divisible by test_batch_size. Ignore this fake example during
        # evaluation.
if result["unique_ids"] == 1000012047:
continue
unique_id = int(result["unique_ids"])
start_top_log_probs = ([
float(x) for x in result["start_top_log_probs"].flat
])
start_top_index = [int(x) for x in result["start_top_index"].flat]
end_top_log_probs = ([
float(x) for x in result["end_top_log_probs"].flat
])
end_top_index = [int(x) for x in result["end_top_index"].flat]
cls_logits = float(result["cls_logits"].flat[0])
cur_results.append(
squad_utils.RawResult(
unique_id=unique_id,
start_top_log_probs=start_top_log_probs,
start_top_index=start_top_index,
end_top_log_probs=end_top_log_probs,
end_top_index=end_top_index,
cls_logits=cls_logits))
if len(cur_results) % 1000 == 0:
logging.info("Processing example: %d", len(cur_results))
output_prediction_file = os.path.join(input_meta_data["predict_dir"],
"predictions.json")
output_nbest_file = os.path.join(input_meta_data["predict_dir"],
"nbest_predictions.json")
output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
"null_odds.json")
ret = squad_utils.write_predictions(
eval_examples, input_meta_data["eval_features"], cur_results,
input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
output_prediction_file, output_nbest_file, output_null_log_odds_file,
orig_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"])
# Log current result
log_str = "Result | "
for key, val in ret.items():
log_str += "{} {} | ".format(key, val)
logging.info(log_str)
if eval_summary_writer:
with eval_summary_writer.as_default():
tf.summary.scalar("best_f1", ret["best_f1"], step=step)
tf.summary.scalar("best_exact", ret["best_exact"], step=step)
eval_summary_writer.flush()
def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
model = modeling.QAXLNetModel(
model_config,
run_config,
start_n_top=start_n_top,
end_n_top=end_n_top,
name="model")
return model
def main(unused_argv):
del unused_argv
use_remote_tpu = False
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
train_input_fn = functools.partial(data_utils.get_squad_input_data,
FLAGS.train_batch_size, FLAGS.seq_len,
FLAGS.query_len, strategy, True,
FLAGS.train_tfrecord_path)
test_input_fn = functools.partial(data_utils.get_squad_input_data,
FLAGS.test_batch_size, FLAGS.seq_len,
FLAGS.query_len, strategy, False,
FLAGS.test_tfrecord_path)
total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
steps_per_loop = FLAGS.iterations
eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
optimizer, learning_rate_fn = optimization.create_optimizer(
FLAGS.learning_rate,
total_training_steps,
FLAGS.warmup_steps,
adam_epsilon=FLAGS.adam_epsilon)
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
input_meta_data = {}
input_meta_data["start_n_top"] = FLAGS.start_n_top
input_meta_data["end_n_top"] = FLAGS.end_n_top
input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
input_meta_data["predict_dir"] = FLAGS.predict_dir
input_meta_data["predict_file"] = FLAGS.predict_file
input_meta_data["n_best_size"] = FLAGS.n_best_size
input_meta_data["max_answer_length"] = FLAGS.max_answer_length
input_meta_data["test_feature_path"] = FLAGS.test_feature_path
input_meta_data["test_batch_size"] = FLAGS.test_batch_size
input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
strategy.num_replicas_in_sync)
input_meta_data["mem_len"] = FLAGS.mem_len
model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
FLAGS.start_n_top, FLAGS.end_n_top)
def logits_init_fn():
return tf.zeros(
shape=(input_meta_data["batch_size_per_core"]), dtype=tf.float32)
logging.info("start reading pickle file...")
with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
eval_features = pickle.load(f)
logging.info("finishing reading pickle file...")
input_meta_data["eval_features"] = eval_features
eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
eval_steps, input_meta_data)
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=None,
logits_init_fn=logits_init_fn,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# coding=utf-8
"""Utilities used in SQUAD task."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import gc
import json
import math
import re
import string
from absl import logging
import numpy as np
import six
import tensorflow as tf
from official.nlp.xlnet import data_utils
from official.nlp.xlnet import preprocess_utils
SPIECE_UNDERLINE = u"▁"
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tok_start_to_orig_index,
tok_end_to_orig_index,
token_is_max_context,
input_ids,
input_mask,
p_mask,
segment_ids,
paragraph_len,
cls_index,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tok_start_to_orig_index = tok_start_to_orig_index
self.tok_end_to_orig_index = tok_end_to_orig_index
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.p_mask = p_mask
self.segment_ids = segment_ids
self.paragraph_len = paragraph_len
self.cls_index = cls_index
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def make_qid_to_has_ans(dataset):
  """Maps each question id to whether it has a gold answer."""
  qid_to_has_ans = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid_to_has_ans[qa["id"]] = bool(qa["answers"])
return qid_to_has_ans
def get_raw_scores(dataset, preds):
"""Gets exact scores and f1 scores."""
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid = qa["id"]
gold_answers = [
a["text"] for a in qa["answers"] if normalize_answer(a["text"])
]
if not gold_answers:
          # For unanswerable questions, the only correct answer is the empty string.
gold_answers = [""]
if qid not in preds:
print("Missing prediction for %s" % qid)
continue
a_pred = preds[qid]
# Take max over all gold answers
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def compute_exact(a_gold, a_pred):
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
def compute_f1(a_gold, a_pred):
"""Computes f1 score."""
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
# pylint: disable=g-explicit-length-test
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
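# Worked example (illustration only): for gold "The quick brown fox" and
# prediction "quick fox", normalize_answer drops the article and lowercases,
# giving gold tokens [quick, brown, fox] and predicted tokens [quick, fox];
# num_same = 2, precision = 2/2 = 1.0, recall = 2/3, and
# f1 = 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.8.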
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
"""Finds best threshold."""
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for qid in qid_list:
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
has_ans_score, has_ans_cnt = 0, 0
for qid in qid_list:
if not qid_to_has_ans[qid]:
continue
has_ans_cnt += 1
if qid not in scores:
continue
has_ans_score += scores[qid]
  return (100.0 * best_score / len(scores), best_thresh,
          1.0 * has_ans_score / has_ans_cnt)
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
qid_to_has_ans):
"""Finds all best threshold."""
best_exact, exact_thresh, has_ans_exact = find_best_thresh(
preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh(preds, f1_raw, na_probs,
qid_to_has_ans)
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
main_eval["has_ans_exact"] = has_ans_exact
main_eval["has_ans_f1"] = has_ans_f1
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", [
"feature_index", "start_index", "end_index", "start_log_prob",
"end_log_prob"
])
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
RawResult = collections.namedtuple("RawResult", [
"unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs",
"end_top_index", "cls_logits"
])
def _compute_softmax(scores):
"""Computes softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
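# Example (illustration only): _compute_softmax([1.0, 2.0]) subtracts the max
# score for numerical stability and returns roughly [0.269, 0.731], since
# exp(-1) / (exp(-1) + exp(0)) is about 0.269.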
class SquadExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
paragraph_text,
orig_answer_text=None,
start_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.paragraph_text = paragraph_text
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (preprocess_utils.printable_text(self.qas_id))
s += ", question_text: %s" % (
preprocess_utils.printable_text(self.question_text))
s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
    if self.is_impossible:
      s += ", is_impossible: %r" % (self.is_impossible)
return s
def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, orig_data,
start_n_top, end_n_top):
"""Writes final predictions to the json file and log-odds of null if needed."""
logging.info("Writing predictions to: %s", (output_prediction_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
cur_null_score = result.cls_logits
# if we could have irrelevant answers, get the min score of irrelevant
score_null = min(score_null, cur_null_score)
for i in range(start_n_top):
for j in range(end_n_top):
start_log_prob = result.start_top_log_probs[i]
start_index = result.start_top_index[i]
j_index = i * end_n_top + j
end_log_prob = result.end_top_log_probs[j_index]
end_index = result.end_top_index[j_index]
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= feature.paragraph_len - 1:
continue
if end_index >= feature.paragraph_len - 1:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_log_prob=start_log_prob,
end_log_prob=end_log_prob))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_log_prob + x.end_log_prob),
reverse=True)
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
tok_start_to_orig_index = feature.tok_start_to_orig_index
tok_end_to_orig_index = feature.tok_end_to_orig_index
start_orig_pos = tok_start_to_orig_index[pred.start_index]
end_orig_pos = tok_end_to_orig_index[pred.end_index]
paragraph_text = example.paragraph_text
final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_log_prob=pred.start_log_prob,
end_log_prob=pred.end_log_prob))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_log_prob + entry.end_log_prob)
if not best_non_null_entry:
best_non_null_entry = entry
probs = _compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_log_prob"] = entry.start_log_prob
output["end_log_prob"] = entry.end_log_prob
nbest_json.append(output)
assert len(nbest_json) >= 1
assert best_non_null_entry is not None
score_diff = score_null
scores_diff_json[example.qas_id] = score_diff
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
qid_to_has_ans = make_qid_to_has_ans(orig_data)
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
out_eval = {}
find_all_best_thresh(out_eval, all_predictions, exact_raw, f1_raw,
scores_diff_json, qid_to_has_ans)
return out_eval
def read_squad_examples(input_file, is_training):
"""Reads a SQuAD json file into a list of SquadExample."""
with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"]
examples = []
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
start_position = None
orig_answer_text = None
is_impossible = False
if is_training:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer.")
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
start_position = answer["answer_start"]
else:
start_position = -1
orig_answer_text = ""
example = SquadExample(
qas_id=qas_id,
question_text=question_text,
paragraph_text=paragraph_text,
orig_answer_text=orig_answer_text,
start_position=start_position,
is_impossible=is_impossible)
examples.append(example)
return examples
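# Usage sketch (the file path is hypothetical): reads SQuAD v2.0-style json,
# where each qa entry carries the `is_impossible` flag consumed above.
#   train_examples = read_squad_examples("train-v2.0.json", is_training=True)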
# pylint: disable=invalid-name
def _convert_index(index, pos, M=None, is_start=True):
"""Converts index."""
if index[pos] is not None:
return index[pos]
N = len(index)
rear = pos
while rear < N - 1 and index[rear] is None:
rear += 1
front = pos
while front > 0 and index[front] is None:
front -= 1
assert index[front] is not None or index[rear] is not None
if index[front] is None:
if index[rear] >= 1:
if is_start:
return 0
else:
return index[rear] - 1
return index[rear]
if index[rear] is None:
if M is not None and index[front] < M - 1:
if is_start:
return index[front] + 1
else:
return M - 1
return index[front]
if is_start:
if index[rear] > index[front] + 1:
return index[front] + 1
else:
return index[rear]
else:
if index[rear] > index[front] + 1:
return index[rear] - 1
else:
return index[front]
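# Worked example for `_convert_index` (illustrative): when `index[pos]` is
# None, the function interpolates from the nearest mapped neighbors.
#   >>> index = [0, None, None, 5]
#   >>> _convert_index(index, 1, is_start=True)   # front=0, rear=5 -> 0 + 1
#   1
#   >>> _convert_index(index, 2, is_start=False)  # front=0, rear=5 -> 5 - 1
#   4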
def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
max_query_length, is_training, output_fn,
uncased):
"""Loads a data file into a list of `InputBatch`s."""
cnt_pos, cnt_neg = 0, 0
unique_id = 1000000000
max_N, max_M = 1024, 1024
f = np.zeros((max_N, max_M), dtype=np.float32)
for (example_index, example) in enumerate(examples):
# pylint: disable=logging-format-interpolation
if example_index % 100 == 0:
logging.info("Converting {}/{} pos {} neg {}".format(
example_index, len(examples), cnt_pos, cnt_neg))
query_tokens = preprocess_utils.encode_ids(
sp_model,
preprocess_utils.preprocess_text(example.question_text, lower=uncased))
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
paragraph_text = example.paragraph_text
para_tokens = preprocess_utils.encode_pieces(
sp_model,
preprocess_utils.preprocess_text(example.paragraph_text, lower=uncased))
chartok_to_tok_index = []
tok_start_to_chartok_index = []
tok_end_to_chartok_index = []
char_cnt = 0
for i, token in enumerate(para_tokens):
chartok_to_tok_index.extend([i] * len(token))
tok_start_to_chartok_index.append(char_cnt)
char_cnt += len(token)
tok_end_to_chartok_index.append(char_cnt - 1)
tok_cat_text = "".join(para_tokens).replace(SPIECE_UNDERLINE, " ")
N, M = len(paragraph_text), len(tok_cat_text)
if N > max_N or M > max_M:
max_N = max(N, max_N)
max_M = max(M, max_M)
f = np.zeros((max_N, max_M), dtype=np.float32)
gc.collect()
g = {}
# pylint: disable=cell-var-from-loop
def _lcs_match(max_dist):
"""LCS match."""
f.fill(0)
g.clear()
      ### longest common subsequence (LCS)
      # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
      for i in range(N):
        # note(zhiliny):
        # Unlike standard LCS, the search is restricted to a band of width
        # 2 * max_dist around the diagonal, because the mismatch between
        # sentence pieces and the original text is expected to be small.
for j in range(i - max_dist, i + max_dist):
if j >= M or j < 0:
continue
if i > 0:
g[(i, j)] = 0
f[i, j] = f[i - 1, j]
if j > 0 and f[i, j - 1] > f[i, j]:
g[(i, j)] = 1
f[i, j] = f[i, j - 1]
f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
if (preprocess_utils.preprocess_text(
paragraph_text[i], lower=uncased,
remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
g[(i, j)] = 2
f[i, j] = f_prev + 1
max_dist = abs(N - M) + 5
for _ in range(2):
_lcs_match(max_dist)
if f[N - 1, M - 1] > 0.8 * N:
break
max_dist *= 2
orig_to_chartok_index = [None] * N
chartok_to_orig_index = [None] * M
i, j = N - 1, M - 1
while i >= 0 and j >= 0:
if (i, j) not in g:
break
if g[(i, j)] == 2:
orig_to_chartok_index[i] = j
chartok_to_orig_index[j] = i
i, j = i - 1, j - 1
elif g[(i, j)] == 1:
j = j - 1
else:
i = i - 1
    if all(
        v is None for v in orig_to_chartok_index) or f[N - 1, M - 1] < 0.8 * N:
      logging.warning("LCS mismatch detected for example %s; skipping.",
                      example.qas_id)
      continue
tok_start_to_orig_index = []
tok_end_to_orig_index = []
for i in range(len(para_tokens)):
start_chartok_pos = tok_start_to_chartok_index[i]
end_chartok_pos = tok_end_to_chartok_index[i]
start_orig_pos = _convert_index(
chartok_to_orig_index, start_chartok_pos, N, is_start=True)
end_orig_pos = _convert_index(
chartok_to_orig_index, end_chartok_pos, N, is_start=False)
tok_start_to_orig_index.append(start_orig_pos)
tok_end_to_orig_index.append(end_orig_pos)
if not is_training:
tok_start_position = tok_end_position = None
if is_training and example.is_impossible:
tok_start_position = -1
tok_end_position = -1
if is_training and not example.is_impossible:
start_position = example.start_position
end_position = start_position + len(example.orig_answer_text) - 1
start_chartok_pos = _convert_index(
orig_to_chartok_index, start_position, is_start=True)
tok_start_position = chartok_to_tok_index[start_chartok_pos]
end_chartok_pos = _convert_index(
orig_to_chartok_index, end_position, is_start=False)
tok_end_position = chartok_to_tok_index[end_chartok_pos]
assert tok_start_position <= tok_end_position
def _piece_to_id(x):
if six.PY2 and isinstance(x, unicode):
x = x.encode("utf-8")
return sp_model.PieceToId(x)
all_doc_tokens = list(map(_piece_to_id, para_tokens))
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# We can have documents that are longer than the maximum sequence length.
# To deal with this we do a sliding window approach, where we take chunks
    # of up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc:
length = max_tokens_for_doc
doc_spans.append(_DocSpan(start=start_offset, length=length))
if start_offset + length == len(all_doc_tokens):
break
start_offset += min(length, doc_stride)
for (doc_span_index, doc_span) in enumerate(doc_spans):
tokens = []
token_is_max_context = {}
segment_ids = []
p_mask = []
cur_tok_start_to_orig_index = []
cur_tok_end_to_orig_index = []
for i in range(doc_span.length):
split_token_index = doc_span.start + i
cur_tok_start_to_orig_index.append(
tok_start_to_orig_index[split_token_index])
cur_tok_end_to_orig_index.append(
tok_end_to_orig_index[split_token_index])
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(SEG_ID_P)
p_mask.append(0)
paragraph_len = len(tokens)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_P)
p_mask.append(1)
# note(zhiliny): we put P before Q
# because during pretraining, B is always shorter than A
for token in query_tokens:
tokens.append(token)
segment_ids.append(SEG_ID_Q)
p_mask.append(1)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_Q)
p_mask.append(1)
cls_index = len(segment_ids)
tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS)
p_mask.append(0)
input_ids = tokens
# The mask has 0 for real tokens and 1 for padding tokens. Only real
# tokens are attended to.
input_mask = [0] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(1)
segment_ids.append(SEG_ID_PAD)
p_mask.append(1)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(p_mask) == max_seq_length
span_is_impossible = example.is_impossible
start_position = None
end_position = None
if is_training and not span_is_impossible:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1
out_of_span = False
if not (tok_start_position >= doc_start and
tok_end_position <= doc_end):
out_of_span = True
if out_of_span:
# continue
start_position = 0
end_position = 0
span_is_impossible = True
else:
# note: we put P before Q, so doc_offset should be zero.
# doc_offset = len(query_tokens) + 2
doc_offset = 0
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and span_is_impossible:
start_position = cls_index
end_position = cls_index
if example_index < 20:
logging.info("*** Example ***")
logging.info("unique_id: %s", unique_id)
logging.info("example_index: %s", example_index)
logging.info("doc_span_index: %s", doc_span_index)
logging.info("tok_start_to_orig_index: %s",
" ".join([str(x) for x in cur_tok_start_to_orig_index]))
logging.info("tok_end_to_orig_index: %s",
" ".join([str(x) for x in cur_tok_end_to_orig_index]))
logging.info(
"token_is_max_context: %s", " ".join([
"%d:%s" % (x, y)
for (x, y) in six.iteritems(token_is_max_context)
]))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
if is_training and span_is_impossible:
logging.info("impossible example span")
if is_training and not span_is_impossible:
pieces = [
sp_model.IdToPiece(token)
for token in tokens[start_position:(end_position + 1)]
]
answer_text = sp_model.DecodePieces(pieces)
logging.info("start_position: %d", start_position)
logging.info("end_position: %d", end_position)
logging.info("answer: %s",
preprocess_utils.printable_text(answer_text))
      # With multi-processing, `example_index` is the index within the
      # current process, so we set example_index=None to prevent it from
      # being reused downstream. The current code does not use the
      # example_index of training data.
if is_training:
feat_example_index = None
else:
feat_example_index = example_index
feature = InputFeatures(
unique_id=unique_id,
example_index=feat_example_index,
doc_span_index=doc_span_index,
tok_start_to_orig_index=cur_tok_start_to_orig_index,
tok_end_to_orig_index=cur_tok_end_to_orig_index,
token_is_max_context=token_is_max_context,
input_ids=input_ids,
input_mask=input_mask,
p_mask=p_mask,
segment_ids=segment_ids,
paragraph_len=paragraph_len,
cls_index=cls_index,
start_position=start_position,
end_position=end_position,
is_impossible=span_is_impossible)
# Run callback
output_fn(feature)
unique_id += 1
if span_is_impossible:
cnt_neg += 1
else:
cnt_pos += 1
logging.info("Total number of instances: %d = pos %d + neg %d",
cnt_pos + cnt_neg, cnt_pos, cnt_neg)
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the "max context" doc span for the token."""
# Because of the sliding window approach taken to scoring documents, a single
# token can appear in multiple documents. E.g.
# Doc: the man went to the store and bought a gallon of milk
# Span A: the man went to the
# Span B: to the store and bought
# Span C: and bought a gallon of
# ...
#
# Now the word "bought" will have two scores from spans B and C. We only
# want to consider the score with "maximum context", which we define as
# the *minimum* of its left and right context (the *sum* of left and
# right context will always be the same, of course).
#
# In the example the maximum context for "bought" would be span C since
# it has 1 left context and 3 right context, while span B has 4 left context
# and 0 right context.
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span.start + doc_span.length - 1
if position < doc_span.start:
continue
if position > end:
continue
num_left_context = position - doc_span.start
num_right_context = end - position
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
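# Worked example (illustrative), with a namedtuple standing in for _DocSpan:
#   >>> Span = collections.namedtuple("DocSpan", ["start", "length"])
#   >>> spans = [Span(start=0, length=4), Span(start=2, length=4)]
#   >>> # Position 3 has min(left, right) context 0 in span 0 but 1 in
#   >>> # span 1, so span 1 provides the "max context".
#   >>> _check_is_max_context(spans, 0, 3), _check_is_max_context(spans, 1, 3)
#   (False, True)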
class FeatureWriter(object):
"""Writes InputFeature to TF example file."""
def __init__(self, filename, is_training):
self.filename = filename
self.is_training = is_training
self.num_features = 0
self._writer = tf.io.TFRecordWriter(filename)
def process_feature(self, feature):
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
self.num_features += 1
def create_int_feature(values):
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
features = collections.OrderedDict()
features["unique_ids"] = create_int_feature([feature.unique_id])
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_float_feature(feature.input_mask)
features["p_mask"] = create_float_feature(feature.p_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["cls_index"] = create_int_feature([feature.cls_index])
if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position])
impossible = 0
if feature.is_impossible:
impossible = 1
features["is_impossible"] = create_float_feature([impossible])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
self._writer.write(tf_example.SerializeToString())
def close(self):
self._writer.close()
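# Usage sketch (illustrative; `examples` and `sp_model` are assumed to come
# from the surrounding pipeline, and the output path is hypothetical):
#   writer = FeatureWriter(filename="/tmp/train.tf_record", is_training=True)
#   convert_examples_to_features(examples, sp_model, max_seq_length=512,
#                                doc_stride=128, max_query_length=64,
#                                is_training=True,
#                                output_fn=writer.process_feature,
#                                uncased=False)
#   writer.close()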
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
import re
from absl import logging
# pytype: disable=attribute-error
# pylint: disable=g-bare-generic,unused-import
import tensorflow as tf
# Initialize TPU System.
from official.nlp.xlnet import data_utils
from official.nlp import xlnet_modeling as modeling
from typing import Any, Callable, Dict, Text, Optional
_MIN_SUMMARY_STEPS = 10
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info("Saving model as TF checkpoint: %s", saved_path)
return
def _float_metric_value(metric):
"""Gets the value of a float-value keras metric."""
return metric.result().numpy().astype(float)
def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
"""Calculates steps to run on device."""
  if steps_per_loop <= 0:
    raise ValueError("steps_per_loop should be a positive integer.")
if steps_per_loop == 1:
return steps_per_loop
remainder_in_epoch = current_step % steps_per_epoch
if remainder_in_epoch != 0:
return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
else:
return steps_per_loop
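# Worked example (illustrative): with steps_per_epoch=100 and
# steps_per_loop=10, the loop is shortened near an epoch boundary so that
# checkpointing and evaluation land exactly on the boundary.
#   >>> _steps_to_run(current_step=95, steps_per_epoch=100, steps_per_loop=10)
#   5
#   >>> _steps_to_run(current_step=100, steps_per_epoch=100, steps_per_loop=10)
#   10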
def train(
strategy: tf.distribute.Strategy,
model_fn: Callable,
input_meta_data: Dict,
logits_init_fn: Callable[[], tf.Tensor],
train_input_fn: Callable,
total_training_steps: int,
steps_per_epoch: int,
steps_per_loop: int,
optimizer: tf.keras.optimizers.Optimizer,
learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
eval_fn: Optional[Callable[[tf.keras.Model, int, tf.summary.SummaryWriter],
Any]] = None,
metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
test_input_fn: Optional[Callable] = None,
init_checkpoint: Optional[Text] = None,
model_dir: Optional[Text] = None,
save_steps: Optional[int] = None):
"""Runs customized training.
Args:
    strategy: Distribution strategy on which to run the low-level training
      loop.
    model_fn: A function that returns a keras.Model.
    input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
      `n_layer`, `batch_size_per_core` and `d_model`.
    logits_init_fn: A function that creates a dummy logits tensor.
    train_input_fn: A function that returns a tf.data.Dataset used for
      training.
    total_training_steps: Number of steps to train in total.
    steps_per_epoch: Number of steps to run per epoch. At the end of each
      epoch, a model checkpoint is saved and evaluation is conducted if an
      evaluation dataset is provided.
    steps_per_loop: Number of steps per graph-mode loop. To reduce
      communication in eager context, training logs are printed every
      steps_per_loop steps.
    optimizer: The optimizer for the model.
    learning_rate_fn: The learning rate schedule.
    eval_fn: An optional evaluation callback that takes a keras.Model, the
      current step, and an evaluation summary writer.
    metric_fn: An optional function that returns a Keras Metric object used to
      record evaluation results, either on the evaluation dataset or on the
      training dataset after every epoch.
    test_input_fn: A function that returns an evaluation dataset. If None,
      evaluation is skipped.
    init_checkpoint: Optional checkpoint to load into the `sub_model` returned
      by `model_fn`.
    model_dir: The directory for the model (checkpoints, summaries).
    save_steps: The checkpoint-saving frequency: a checkpoint is saved every
      save_steps steps.
Returns:
Last training step logits if training happens, otherwise returns None.
Raises:
TypeError: if model directory is not specified.
"""
required_arguments = [
logits_init_fn, train_input_fn, total_training_steps, steps_per_epoch,
steps_per_loop, optimizer, learning_rate_fn
]
  if any(arg is None for arg in required_arguments):
raise ValueError(
"`logits_init_fn`, `train_input_fn`, `total_training_steps`, "
"`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
"`learning_rate_fn` are required parameters.")
if not model_dir:
raise TypeError("Model directory must be specified.")
# pylint: disable=protected-access
train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
# pylint: enable=protected-access
train_summary_writer = None
eval_summary_writer = None
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.mkdir(model_dir)
if test_input_fn:
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, "summaries/eval"))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summary when the stats are collected sufficiently over
# enough steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, "summaries/train"))
with strategy.scope():
model = model_fn()
if init_checkpoint:
logging.info("restore from %s", init_checkpoint)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
model.optimizer = optimizer
if not hasattr(model, "optimizer"):
raise ValueError("User should set optimizer attribute to model.")
train_loss_metric = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
train_metric = None
if metric_fn:
train_metric = metric_fn()
def _replicated_step(inputs, mem=None):
"""Replicated training step."""
inputs["mems"] = mem
with tf.GradientTape() as tape:
mem, logits = model(inputs, training=True)
loss = model.losses
train_loss_metric.update_state(loss)
if train_metric:
train_metric.update_state(inputs["label_ids"], logits)
scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)
# Collects training variables.
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
if input_meta_data["lr_layer_decay_rate"] != 1.0:
n_layer = 0
for i in range(len(clipped)):
m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name)
if not m:
continue
n_layer = max(n_layer, int(m.group(1)) + 1)
for i in range(len(clipped)):
for l in range(n_layer):
if "model/transformer/layer_{}/".format(l) in tvars[i].name:
abs_rate = input_meta_data["lr_layer_decay_rate"]**(
n_layer - 1 - l)
clipped[i] *= abs_rate
logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
abs_rate, l, tvars[i].name))
break
optimizer.apply_gradients(zip(clipped, tvars))
if input_meta_data["mem_len"] > 0:
return mem, logits
else:
return logits
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop.
Args:
iterator: the distributed iterator of training datasets.
      steps: a tf.int32 tensor specifying the number of steps to run inside
        the host training loop.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
Returns:
logits: logits computed.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError("steps should be an Tensor. Python object may cause "
"retracing.")
def cache_fn():
"""Initializes memory tensor used in XLNet pretraining."""
mems = []
if input_meta_data["mem_len"] > 0:
for _ in range(input_meta_data["n_layer"]):
zeros = tf.zeros([
input_meta_data["mem_len"],
input_meta_data["batch_size_per_core"],
input_meta_data["d_model"]
],
dtype=tf.float32)
mems.append(zeros)
return mems
logits = strategy.experimental_run_v2(logits_init_fn)
if input_meta_data["mem_len"] > 0:
mem = strategy.experimental_run_v2(cache_fn)
for _ in tf.range(steps):
mem, logits = strategy.experimental_run_v2(
_replicated_step, args=(
next(iterator),
mem,
))
else:
for _ in tf.range(steps):
logits = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return logits
logging.info("Start training...")
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info("Checkpoint file %s found and restoring from checkpoint",
latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info("Loading from checkpoint file completed")
current_step = optimizer.iterations.numpy()
checkpoint_name = "xlnet_step_{step}.ckpt"
logits = None
while current_step < total_training_steps:
train_loss_metric.reset_states()
if train_metric:
train_metric.reset_states()
steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
logits = train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
current_step += steps
train_loss = _float_metric_value(train_loss_metric)
log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
current_step, total_training_steps, learning_rate_fn(current_step),
train_loss)
if train_metric:
log_stream += " / %s = %f" % (train_metric.name,
_float_metric_value(train_metric))
logging.info(log_stream)
if train_summary_writer:
with train_summary_writer.as_default():
tf.summary.scalar(
"learning_rate",
learning_rate_fn(current_step),
step=current_step)
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
if train_metric:
tf.summary.scalar(
train_metric.name,
_float_metric_value(train_metric),
step=current_step)
train_summary_writer.flush()
if model_dir:
if (save_steps is None) or (save_steps and
current_step % save_steps == 0):
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_input_fn and current_step % steps_per_epoch == 0:
logging.info("Running evaluation after step: %s.", current_step)
eval_fn(model, current_step, eval_summary_writer)
if model_dir:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_input_fn:
logging.info("Running final evaluation after training is complete.")
eval_fn(model, current_step, eval_summary_writer)
return logits
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used in XLNet model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import json
import os
import tensorflow as tf
def create_run_config(is_training, is_finetune, flags):
"""Helper function for creating RunConfig."""
kwargs = dict(
is_training=is_training,
use_tpu=flags.use_tpu,
use_bfloat16=flags.use_bfloat16,
dropout=flags.dropout,
dropout_att=flags.dropout_att,
init_method=flags.init_method,
init_range=flags.init_range,
init_std=flags.init_std,
clamp_len=flags.clamp_len)
if not is_finetune:
kwargs.update(dict(
mem_len=flags.mem_len,
reuse_len=flags.reuse_len,
bi_data=flags.bi_data,
clamp_len=flags.clamp_len,
same_length=flags.same_length))
return RunConfig(**kwargs)
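# Usage sketch (illustrative): `flags` is expected to be an absl FLAGS-like
# object carrying the attributes read above. Finetuning skips the memory- and
# data-pipeline-related fields:
#   run_config = create_run_config(is_training=True, is_finetune=True,
#                                  flags=FLAGS)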
class XLNetConfig(object):
"""Configs for XLNet model.
XLNetConfig contains hyperparameters that are specific to a model checkpoint;
i.e., these hyperparameters should be the same between
pretraining and finetuning.
The following hyperparameters are defined:
n_layer: int, the number of layers.
d_model: int, the hidden size.
n_head: int, the number of attention heads.
d_head: int, the dimension size of each attention head.
d_inner: int, the hidden size in feed-forward layers.
ff_activation: str, "relu" or "gelu".
untie_r: bool, whether to untie the biases in attention.
n_token: int, the vocab size.
"""
def __init__(self, FLAGS=None, json_path=None, args_dict=None):
"""Constructing an XLNetConfig.
One of FLAGS or json_path should be provided.
Args:
FLAGS: An FLAGS instance.
json_path: A path to a json config file.
args_dict: A dict for args.
"""
assert FLAGS is not None or json_path is not None or args_dict is not None
self.keys = ['n_layer', 'd_model', 'n_head', 'd_head', 'd_inner',
'ff_activation', 'untie_r', 'n_token']
if FLAGS is not None:
self.init_from_flags(FLAGS)
if json_path is not None:
self.init_from_json(json_path)
if args_dict is not None:
self.init_from_dict(args_dict)
def init_from_dict(self, args_dict):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
for key in self.keys:
setattr(self, key, args_dict[key])
def init_from_flags(self, flags):
for key in self.keys:
setattr(self, key, getattr(flags, key))
def init_from_json(self, json_path):
    with tf.io.gfile.GFile(json_path) as f:
json_data = json.load(f)
self.init_from_dict(json_data)
def to_json(self, json_path):
"""Save XLNetConfig to a json file."""
json_data = {}
for key in self.keys:
json_data[key] = getattr(self, key)
json_dir = os.path.dirname(json_path)
    if not tf.io.gfile.exists(json_dir):
      tf.io.gfile.makedirs(json_dir)
    with tf.io.gfile.GFile(json_path, 'w') as f:
json.dump(json_data, f, indent=4, sort_keys=True)
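# Round-trip sketch (illustrative; the path is hypothetical): a config saved
# with to_json() can be restored by passing json_path to the constructor.
#   config.to_json("/tmp/xlnet_config.json")
#   restored = XLNetConfig(json_path="/tmp/xlnet_config.json")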
class RunConfig(object):
"""Class of RunConfig.
RunConfig contains hyperparameters that could be different
between pretraining and finetuning.
These hyperparameters can also be changed from run to run.
We store them separately from XLNetConfig for flexibility.
"""
def __init__(self,
is_training,
use_tpu,
use_bfloat16,
dropout,
dropout_att,
init_method='normal',
init_range=0.1,
init_std=0.02,
mem_len=None,
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
"""Initializes RunConfig.
Args:
is_training: bool, whether in training mode.
use_tpu: bool, whether TPUs are used.
use_bfloat16: bool, use bfloat16 instead of float32.
dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities.
      init_method: str, the initialization scheme, either "normal" or
        "uniform".
      init_range: float, initialize the parameters with a uniform distribution
        in [-init_range, init_range]. Only effective when
        init_method="uniform".
      init_std: float, initialize the parameters with a normal distribution
        with mean 0 and stddev init_std. Only effective when
        init_method="normal".
mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be cached
        and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length
for each token.
"""
self.init_method = init_method
self.init_range = init_range
self.init_std = init_std
self.is_training = is_training
self.dropout = dropout
self.dropout_att = dropout_att
self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras layers of XLNet model in TF 2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy
import numpy as np
import tensorflow as tf
def gelu(x):
"""Gaussian Error Linear Unit.
  This is a smoother version of the ReLU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
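# Spot checks for the tanh approximation above (illustrative, values rounded):
#   gelu(0.0) == 0.0, gelu(1.0) ~= 0.8412, gelu(-1.0) ~= -0.1588, and
#   gelu(x) -> x for large positive x.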
def rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
x_size = tf.shape(x)
x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
return x
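# Illustrative effect of `rel_shift` (a sketch; batch/head dims suppressed):
# for qlen=2 queries scored against a length-4 relative-position axis,
#   x = [[0, 1, 2, 3],        rel_shift(x, klen=2) = [[2, 3],
#        [4, 5, 6, 7]]                                [5, 6]]
# i.e. out[i, j] = x[i, j + qlen - i]: each row is shifted so its columns
# line up with absolute key positions before truncating to klen.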
def _get_initializer(flags):
"""Get variable intializer."""
if flags.init_method == 'uniform':
initializer = tf.keras.initializers.RandomUniform(
minval=-flags.init_range, maxval=flags.init_range)
elif flags.init_method == 'normal':
initializer = tf.keras.initializers.RandomNormal(stddev=flags.init_std)
else:
raise ValueError('Initializer {} not supported'.format(flags.init_method))
return initializer
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
"""Creates attention mask when single-side context allowed only."""
attn_mask = tf.ones([qlen, qlen], dtype=dtype)
  mask_u = tf.linalg.band_part(attn_mask, 0, -1)
  mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if same_length:
    mask_l = tf.linalg.band_part(attn_mask, -1, 0)
ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
return ret
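# Illustrative output (1 = masked) for _create_mask(qlen=3, mlen=2) with
# same_length=False: the first mlen columns are zero (memory is fully
# visible) and the trailing qlen x qlen block masks strictly-future tokens.
#   [[0, 0, 0, 1, 1],
#    [0, 0, 0, 0, 1],
#    [0, 0, 0, 0, 0]]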
def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
"""cache hidden states into memory."""
if mem_len is None or mem_len == 0:
return None
else:
if reuse_len is not None and reuse_len > 0:
curr_out = curr_out[:reuse_len]
if prev_mem is None:
new_mem = curr_out[-mem_len:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
return tf.keras.backend.stop_gradient(new_mem)
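# Illustrative behavior (time-major steps shown; batch/feature dims omitted):
# with mem_len=3, prev_mem=[a, b, c] and curr_out=[d, e], the new memory is
# the last 3 steps of the concatenation, i.e. [c, d, e], with gradients
# stopped so the cache is treated as a constant during backprop.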
def embedding_lookup(lookup_table, x, use_tpu=True):
"""Looks up words embeddings for input id tensor."""
if use_tpu:
n_token = tf.shape(lookup_table)[0]
one_hot_idx = tf.one_hot(x, n_token)
if one_hot_idx.shape.ndims == 2:
return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
else:
return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
else:
return tf.nn.embedding_lookup(lookup_table, x)
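# Equivalence sketch (illustrative): on TPU the gather is expressed as a
# one-hot matmul, which matches the CPU/GPU embedding_lookup path.
#   table = tf.random.normal([5, 8])
#   ids = tf.constant([0, 3, 4])
#   tpu_out = tf.einsum('nd,in->id', table, tf.one_hot(ids, 5))
#   ref_out = tf.nn.embedding_lookup(table, ids)  # equal up to float error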
def is_special_none_tensor(tensor):
"""Checks if a tensor is a special None Tensor."""
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
def unpack_inputs(inputs):
"""Unpacks a tuple of `inputs` tensors to a tuple.
Args:
inputs: A list of tensors.
Returns:
A tuple of tensors. If any input is a special constant tensor, replace it
with None.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if is_special_none_tensor(x):
outputs.append(None)
else:
outputs.append(x)
x = tuple(outputs)
# To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
# from triggering.
if len(x) == 1:
return x[0]
return tuple(outputs)
def pack_inputs(inputs):
"""Packs a list of `inputs` tensors to a tuple.
Args:
inputs: A list of tensors.
Returns:
A tuple of tensors. If any input is None, replace it with a special constant
tensor.
"""
inputs = tf.nest.flatten(inputs)
outputs = []
for x in inputs:
if x is None:
outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
else:
outputs.append(x)
return tuple(outputs)
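# Round-trip sketch (illustrative): None values survive pack/unpack because
# they are encoded as scalar int32 zeros, which `is_special_none_tensor`
# recognizes.
#   packed = pack_inputs([tf.ones([2, 2]), None])
#   x, y = unpack_inputs(packed)  # y is None again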
class PositionalEmbedding(tf.keras.layers.Layer):
"""Generates relative positional embeddings used in Transformer-XL and XLNet."""
def __init__(self, dim, **kwargs):
super(PositionalEmbedding, self).__init__(**kwargs)
self.dim = dim
def build(self, unused_input_shapes):
"""Constructs inversed frequency vector for positional embedding layer."""
self.inv_freq = 1.0 / (10000.0 ** (tf.range(0, self.dim, 2.0) / self.dim))
super(PositionalEmbedding, self).build(unused_input_shapes)
def __call__(self, pos_seq, batch_size):
return super(PositionalEmbedding, self).__call__((
pos_seq,
batch_size,
))
def call(self, inputs):
"""Implements call() for the layer."""
pos_seq, batch_size = inputs
sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
pos_emb = pos_emb[:, None, :]
if batch_size is not None:
pos_emb = tf.tile(pos_emb, [1, batch_size, 1])
return pos_emb
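# Shape sketch (illustrative): for a pos_seq of length L and dim d, the
# output is [L, 1, d] (tiled to [L, batch_size, d] when batch_size is given),
# with sin components in the first d/2 features and cos in the last d/2.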
class RelativeAttention(tf.keras.layers.Layer):
"""Core calculations for relative attention."""
def __init__(self, dropout_att, scale):
super(RelativeAttention, self).__init__()
self.scale = scale
self.dropout_att = dropout_att
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.attention_probs_dropout = tf.keras.layers.Dropout(
rate=self.dropout_att)
super(RelativeAttention, self).build(unused_input_shapes)
def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias, attn_mask):
inputs = pack_inputs([
q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask
])
return super(RelativeAttention, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask) = unpack_inputs(inputs)
# content based attention score
ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
# position based attention score
bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
bd = rel_shift(bd, klen=tf.shape(ac)[1])
# segment-based attention score
if seg_mat is None:
ef = 0
else:
ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
# merges attention scores and performs masking
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
attn_score = attn_score - 1e30 * attn_mask
# attention probability
attn_prob = tf.nn.softmax(attn_score, 1)
attn_prob = self.attention_probs_dropout(attn_prob)
# attention output
attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
return attn_vec
class PositionwiseFF(tf.keras.layers.Layer):
"""Positionwise feed-forward layer."""
def __init__(self, d_model, d_inner, dropout,
kernel_initializer, activation_type, **kwargs):
super(PositionwiseFF, self).__init__(**kwargs)
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.activation_type = activation_type
self.kernel_initializer = kernel_initializer
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
if self.activation_type == 'relu':
activation = tf.nn.relu
elif self.activation_type == 'gelu':
activation = gelu
else:
raise (ValueError('Unsupported activation type {}'.format(
self.activation_type)))
self.inner_projection_layer = (
tf.keras.layers.Dense(
units=self.d_inner,
activation=activation,
kernel_initializer=self.kernel_initializer,
name='layer_1'))
self.output_projection_layer = (
tf.keras.layers.Dense(
units=self.d_model,
kernel_initializer=self.kernel_initializer,
name='layer_2'))
self.inner_dropout = tf.keras.layers.Dropout(rate=self.dropout,
name='drop_1')
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout,
name='drop_2')
self.output_layer_norm = (
tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12))
super(PositionwiseFF, self).build(unused_input_shapes)
def call(self, inp):
"""Implements call() for the layer."""
output = self.inner_projection_layer(inp)
output = self.inner_dropout(output)
output = self.output_projection_layer(output)
output = self.output_dropout(output)
output = self.output_layer_norm(output + inp)
return output
class EmbeddingLookup(tf.keras.layers.Layer):
"""Looks up words embeddings for id tensor."""
def __init__(self,
n_token, d_embed, initializer,
use_one_hot=False, **kwargs):
super(EmbeddingLookup, self).__init__(**kwargs)
self.n_token = n_token
self.d_embed = d_embed
self.initializer = initializer
self.use_one_hot = use_one_hot
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.lookup_table = self.add_weight(
'lookup_table',
shape=[self.n_token, self.d_embed],
initializer=self.initializer,
dtype=self.dtype)
super(EmbeddingLookup, self).build(unused_input_shapes)
def call(self, inputs):
x = inputs
if self.use_one_hot:
one_hot_idx = tf.one_hot(x, self.n_token, dtype=self.dtype)
if one_hot_idx.shape.ndims == 2:
return tf.einsum('in,nd->id',
one_hot_idx,
self.lookup_table), self.lookup_table
else:
return tf.einsum('ibn,nd->ibd',
one_hot_idx,
self.lookup_table), self.lookup_table
else:
return tf.nn.embedding_lookup(self.lookup_table, x), self.lookup_table
class TwoStreamRelativeAttention(tf.keras.layers.Layer):
"""Two-stream attention layer with relative positional encoding."""
def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
kernel_initializer, **kwargs):
super(TwoStreamRelativeAttention, self).__init__(**kwargs)
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.dropout = dropout
self.dropout_att = dropout_att
self.initializer = kernel_initializer
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.scale = 1.0 / (self.d_head ** 0.5)
self.attention_projection_layer = tf.keras.layers.Dense(
units=self.d_model, use_bias=False,
kernel_initializer=self.initializer,
name='o')
self.attention_probs_dropout = tf.keras.layers.Dropout(
rate=self.dropout_att)
self.attention_out_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12)
self.kh_projection_layer = (
self.add_weight(
'k/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.vh_projection_layer = (
self.add_weight(
'v/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.kr_projection_layer = (
self.add_weight(
'r/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.qh_projection_layer = (
self.add_weight(
'q/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.h_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.g_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.proj_o = (
self.add_weight(
'o/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
def __call__(self, h, g, r, r_w_bias, r_r_bias,
seg_mat, r_s_bias, seg_embed, attn_mask_h, attn_mask_g,
mems, target_mapping):
inputs = pack_inputs([
h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping
])
return super(TwoStreamRelativeAttention, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], 0)
else:
cat = h
# content heads
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
# positional heads
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
# core attention ops
attn_vec_h = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
seg_embed, seg_mat, r_w_bias, r_r_bias,
r_s_bias, attn_mask_h)
output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
output_h = self.attention_dropout(output_h)
output_h = self.output_layer_norm(output_h + h)
##### g-stream
# query-stream query head
q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
# core attention ops
if target_mapping is not None:
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.g_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask_g)
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
else:
attn_vec_g = self.g_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask_g)
# post processing
output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
output_g = self.attention_dropout(output_g)
output_g = self.output_layer_norm(output_g + g)
return output_h, output_g
class RelativeMultiheadAttention(tf.keras.layers.Layer):
"""Multi-head attention with relative embedding."""
def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
kernel_initializer, **kwargs):
super(RelativeMultiheadAttention, self).__init__(**kwargs)
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.dropout = dropout
self.dropout_att = dropout_att
self.initializer = kernel_initializer
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.scale = 1.0 / (self.d_head ** 0.5)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12)
self.kh_projection_layer = self.add_weight(
'k/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.vh_projection_layer = self.add_weight(
'v/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.kr_projection_layer = self.add_weight(
'r/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.qh_projection_layer = self.add_weight(
'q/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.h_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.proj_o = self.add_weight(
'o/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
super(RelativeMultiheadAttention, self).build(unused_input_shapes)
def __call__(self, h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
attn_mask, mems):
inputs = pack_inputs([
h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask, mems
])
return super(RelativeMultiheadAttention, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask,
mems) = unpack_inputs(inputs)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], 0)
else:
cat = h
# content heads
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
# positional heads
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
# core attention ops
attn_vec = self.h_attention_layer(
q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask)
# post processing
output = tf.einsum('ibnd,hnd->ibh', attn_vec, self.proj_o)
output = self.attention_dropout(output)
output = self.output_layer_norm(output + h)
return output
class TransformerXLModel(tf.keras.layers.Layer):
"""Defines a Transformer-XL computation graph with additional support for XLNet."""
def __init__(self,
n_token,
n_layer,
d_model,
n_head,
d_head,
d_inner,
dropout,
dropout_att,
attn_type,
bi_data,
is_training,
initializer,
mem_len=None,
same_length=False,
clamp_len=-1,
untie_r=False,
use_tpu=True,
reuse_len=None,
ff_activation='relu',
use_bfloat16=False,
**kwargs):
"""Initializes TransformerXLModel.
Args:
n_token: int, the number of tokens in vocabulary.
n_layer: int, the number of layers.
d_model: int, the hidden size.
n_head: int, the number of attention heads.
d_head: int, the dimension size of each attention head.
d_inner: int, the hidden size in feed-forward layers.
dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities.
attn_type: str, "uni" or "bi".
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
is_training: bool, whether in training mode.
initializer: A tf initializer.
mem_len: int, the number of tokens to cache.
same_length: bool, whether to use the same attention length for each
token.
clamp_len: int, clamp all relative distances larger than clamp_len. -1
means no clamping.
untie_r: bool, whether to untie the biases in attention.
use_tpu: bool, whether TPUs are used.
      reuse_len: int, the number of tokens in the current batch to be cached
        and reused in the future.
ff_activation: str, "relu" or "gelu".
use_bfloat16: bool, use bfloat16 instead of float32.
**kwargs: Other parameters.
"""
super(TransformerXLModel, self).__init__(**kwargs)
self.n_token = n_token
self.initializer = initializer
self.attn_type = attn_type
self.n_layer = n_layer
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.d_inner = d_inner
self.ff_activation = ff_activation
self.untie_r = untie_r
self.use_bfloat16 = use_bfloat16
self.use_tpu = use_tpu
self.dropout = dropout
self.dropout_att = dropout_att
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.tf_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
self.embedding_lookup = EmbeddingLookup(n_token=self.n_token,
d_embed=self.d_model,
initializer=self.initializer,
use_one_hot=self.use_tpu,
dtype=self.tf_float,
name='word_embedding')
self.h_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.g_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
if self.untie_r:
self.r_w_bias = (
self.add_weight(
'r_w_bias',
shape=[self.n_layer, self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
self.r_r_bias = (
self.add_weight(
'r_r_bias',
shape=[self.n_layer, self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
self.r_s_bias = (
self.add_weight(
'r_s_bias',
shape=[self.n_layer, self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
else:
self.r_w_bias = (
self.add_weight(
'r_w_bias',
shape=[self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
self.r_r_bias = (
self.add_weight(
'r_r_bias',
shape=[self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
self.r_s_bias = (
self.add_weight(
'r_s_bias', [self.n_head, self.d_head],
dtype=self.tf_float,
initializer=self.initializer))
self.seg_embed = self.add_weight(
'seg_embed', [self.n_layer, 2, self.n_head, self.d_head],
dtype=self.tf_float, initializer=self.initializer)
self.mask_emb = self.add_weight('mask_emb/mask_emb',
shape=[1, 1, self.d_model],
dtype=self.tf_float)
self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.fwd_position_embedding = PositionalEmbedding(self.d_model)
self.bwd_position_embedding = PositionalEmbedding(self.d_model)
self.two_stream_layers = []
self.rel_multihead_layers = []
self.g_positionwise_ffn_layers = []
self.h_positionwise_ffn_layers = []
for i in range(self.n_layer):
self.two_stream_layers.append(
TwoStreamRelativeAttention(
d_model=self.d_model,
dropout=self.dropout,
n_head=self.n_head,
d_head=self.d_head,
dropout_att=self.dropout_att,
kernel_initializer=self.initializer,
name='layer_%d/rel_attn' % (i)))
self.rel_multihead_layers.append(
RelativeMultiheadAttention(
d_model=self.d_model,
dropout=self.dropout,
n_head=self.n_head,
d_head=self.d_head,
dropout_att=self.dropout_att,
kernel_initializer=self.initializer,
name='layer_%d/rel_attn' % (i)))
self.g_positionwise_ffn_layers.append(
PositionwiseFF(
d_model=self.d_model,
d_inner=self.d_inner,
dropout=self.dropout,
kernel_initializer=self.initializer,
activation_type=self.ff_activation, name='layer_%d/ff'%(i))
)
self.h_positionwise_ffn_layers.append(
PositionwiseFF(
d_model=self.d_model,
d_inner=self.d_inner,
dropout=self.dropout,
kernel_initializer=self.initializer,
activation_type=self.ff_activation, name='layer_%d/ff'%(i))
)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
super(TransformerXLModel, self).build(unused_input_shapes)
def __call__(self,
inp_k,
seg_id=None,
input_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
inp_q=None):
# Uses dict to feed inputs into call() in order to keep mems as a python
# list.
inputs = {'inp_k': inp_k, 'seg_id': seg_id, 'input_mask': input_mask,
'mems': mems, 'perm_mask': perm_mask,
'target_mapping': target_mapping, 'inp_q': inp_q}
return super(TransformerXLModel, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
inp_k = inputs['inp_k']
seg_id = inputs['seg_id']
input_mask = inputs['input_mask']
mems = inputs['mems']
perm_mask = inputs['perm_mask']
target_mapping = inputs['target_mapping']
inp_q = inputs['inp_q']
new_mems = []
bsz = tf.shape(inp_k)[1]
qlen = inp_k.shape.as_list()[0]
mlen = mems[0].shape.as_list()[0] if mems is not None else 0
klen = mlen + qlen
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
attn_mask = _create_mask(qlen, mlen, self.tf_float, self.same_length)
# pylint: enable=protected-access
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
if data_mask is not None:
# all mems can be attended to
mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
dtype=self.tf_float)
data_mask = tf.concat([mems_mask, data_mask], 1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
else:
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = tf.cast(attn_mask > 0, dtype=self.tf_float)
if attn_mask is not None:
non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=self.tf_float),
non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
dtype=self.tf_float)
else:
non_tgt_mask = None
word_emb_k, _ = self.embedding_lookup(inp_k)
if inp_q is not None:
if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb,
[tf.shape(target_mapping)[0], bsz, 1])
else:
inp_q_ext = inp_q[:, :, None]
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_h = self.h_dropout(word_emb_k)
if inp_q is not None:
output_g = self.g_dropout(word_emb_q)
##### Segment embedding
if seg_id is not None:
# Convert `seg_id` to one-hot `seg_mat`
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, seg_id], 0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(
tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=self.tf_float)
else:
seg_mat = None
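    # Following Transformer-XL/XLNet, segments are encoded relatively:
    # `seg_mat` only records whether two positions come from the same segment,
    # not which segment each belongs to.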
dtype = self.tf_float
freq_seq = tf.range(0, self.d_model, 2.0)
if dtype is not None and dtype != tf.float32:
      freq_seq = tf.cast(freq_seq, dtype=dtype)
if self.attn_type == 'bi':
beg, end = klen, -qlen
elif self.attn_type == 'uni':
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
if self.bi_data:
fwd_pos_seq = tf.range(beg, end, -1.0)
bwd_pos_seq = tf.range(-beg, -end, 1.0)
if dtype is not None and dtype != tf.float32:
fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
self.clamp_len)
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len,
self.clamp_len)
if bsz is not None:
fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz//2)
bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz//2)
else:
fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
else:
fwd_pos_seq = tf.range(beg, end, -1.0)
if dtype is not None and dtype != tf.float32:
fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
        fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
                                       self.clamp_len)
pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
pos_emb = self.emb_dropout(pos_emb)
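    # For 'bi' attention, tf.range(klen, -qlen, -1) enumerates relative
    # distances from klen down to -(qlen - 1), covering every offset i - j a
    # query position can have to a key position; with bi_data a mirrored
    # backward sequence is also built for the reversed input stream.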
if mems is None:
mems = [None] * self.n_layer
for i in range(self.n_layer):
# cache new mems
      new_mems.append(
          _cache_mem(output_h, mems[i], self.mem_len, self.reuse_len))
# segment bias
if seg_id is None:
r_s_bias_i = None
seg_embed_i = None
else:
r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
seg_embed_i = self.seg_embed[i]
if inp_q is not None:
two_stream_layer = self.two_stream_layers[i]
g_ffn_layer = self.g_positionwise_ffn_layers[i]
h_ffn_layer = self.h_positionwise_ffn_layers[i]
output_h, output_g = two_stream_layer(
h=output_h,
g=output_g,
r=pos_emb,
r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
seg_mat=seg_mat,
r_s_bias=r_s_bias_i,
seg_embed=seg_embed_i,
attn_mask_h=non_tgt_mask,
attn_mask_g=attn_mask,
mems=mems[i],
target_mapping=target_mapping)
        output_g = g_ffn_layer(output_g)
        output_h = h_ffn_layer(output_h)
else:
rel_multihead_layer = self.rel_multihead_layers[i]
h_ffn_layer = self.h_positionwise_ffn_layers[i]
output_h = rel_multihead_layer(
h=output_h,
r=pos_emb,
r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
seg_mat=seg_mat,
r_s_bias=r_s_bias_i,
seg_embed=seg_embed_i,
attn_mask=non_tgt_mask,
mems=mems[i])
output_h = h_ffn_layer(output_h)
if inp_q is not None:
output = self.output_dropout(output_g)
else:
output = self.output_dropout(output_h)
return output, new_mems, None
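# A minimal usage sketch for TransformerXLModel (illustrative only: the sizes
# are made up and several constructor arguments without defaults are elided;
# see the instantiations in the model classes below for the full set):
#
#   model = TransformerXLModel(
#       n_token=32000, n_layer=2, d_model=16, n_head=2, d_head=8, d_inner=32,
#       dropout=0.1, dropout_att=0.1, attn_type='bi', ...)
#   # `inp_k` is time-major [qlen, bsz]; mems and the permutation inputs are
#   # optional and default to None.
#   output, new_mems, _ = model(inp_k=tf.zeros([8, 2], dtype=tf.int32))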
class PretrainingXLNetModel(tf.keras.Model):
"""XLNet keras model combined with pretraining LM loss layer.
See the original paper: https://arxiv.org/pdf/1906.08237.pdf
"""
def __init__(self, xlnet_config, run_config, **kwargs):
super(PretrainingXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self.transformerxl_model = TransformerXLModel(
n_token=self.xlnet_config.n_token,
initializer=self.initializer,
attn_type='bi',
n_layer=self.xlnet_config.n_layer,
d_model=self.xlnet_config.d_model,
n_head=self.xlnet_config.n_head,
d_head=self.xlnet_config.d_head,
d_inner=self.xlnet_config.d_inner,
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
mem_len=self.run_config.mem_len,
reuse_len=self.run_config.reuse_len,
bi_data=self.run_config.bi_data,
clamp_len=self.run_config.clamp_len,
same_length=self.run_config.same_length,
name='transformer')
self.lmloss_layer = LMLossLayer(n_token=self.xlnet_config.n_token,
d_model=self.xlnet_config.d_model,
initializer=self.initializer,
use_bfloat16=self.run_config.use_bfloat16,
tie_weight=True,
bi_data=self.run_config.bi_data,
use_tpu=self.run_config.use_tpu,
name='lm_loss')
def call(self, features):
"""Implements call() for the layer."""
input_ids = tf.transpose(features['input_k'], [1, 0])
inp_q = tf.transpose(features['input_q'], [1, 0])
seg_ids = tf.transpose(features['seg_id'], [1, 0])
input_mask = None
perm_mask = tf.transpose(features['perm_mask'], [1, 2, 0])
target_mapping = tf.transpose(features['target_mapping'], [1, 2, 0])
# target for LM loss
target = tf.transpose(features['target'], [1, 0])
# target mask for LM loss
tgt_mask = tf.transpose(features['target_mask'], [1, 0])
mems = features['mems']
    self.transformerxl_output, self.new_mems, self.lookup_table = (
        self.transformerxl_model(
            inp_k=input_ids,
            seg_id=seg_ids,
            input_mask=input_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            inp_q=inp_q))
lm_loss = self.lmloss_layer(
hidden=self.transformerxl_output,
target=target,
lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
target_mask=tgt_mask)
self.add_loss(lm_loss)
return self.new_mems, self.transformerxl_output
class ClassificationXLNetModel(tf.keras.Model):
"""XLNet keras model combined with classification loss layer.
See the original paper: https://arxiv.org/pdf/1906.08237.pdf
"""
def __init__(self, xlnet_config, run_config, n_class, **kwargs):
super(ClassificationXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self.transformerxl_model = TransformerXLModel(
n_token=self.xlnet_config.n_token,
initializer=self.initializer,
attn_type='bi',
n_layer=self.xlnet_config.n_layer,
d_model=self.xlnet_config.d_model,
n_head=self.xlnet_config.n_head,
d_head=self.xlnet_config.d_head,
d_inner=self.xlnet_config.d_inner,
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
mem_len=self.run_config.mem_len,
reuse_len=self.run_config.reuse_len,
bi_data=self.run_config.bi_data,
clamp_len=self.run_config.clamp_len,
same_length=self.run_config.same_length,
name='transformer')
self.summarization_layer = Summarization(
d_model=self.xlnet_config.d_model,
n_head=self.xlnet_config.n_head,
d_head=self.xlnet_config.d_head,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
initializer=self.initializer,
use_proj=True,
summary_type='last',
name='sequence_summary')
self.cl_loss_layer = ClassificationLossLayer(
n_class=n_class, initializer=self.initializer, name='classification')
def call(self, features):
"""Implements call() for the layer."""
bsz_per_core = tf.shape(features['input_ids'])[0]
input_ids = tf.transpose(features['input_ids'], [1, 0])
seg_ids = tf.transpose(features['segment_ids'], [1, 0])
input_mask = tf.transpose(features['input_mask'], [1, 0])
label = tf.reshape(features['label_ids'], [bsz_per_core])
mems = features['mems']
self.transformerxl_output, self.new_mems, self.lookup_table = (
self.transformerxl_model(
inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
self.summary = self.summarization_layer(self.transformerxl_output)
per_example_loss, logits = self.cl_loss_layer(
hidden=self.summary, labels=label)
self.add_loss(tf.keras.backend.mean(per_example_loss))
return self.new_mems, logits
class LMLossLayer(tf.keras.layers.Layer):
"""Layer computing cross entropy loss for language modeling."""
def __init__(self, n_token, d_model, initializer, use_bfloat16,
tie_weight=False, bi_data=True, use_tpu=False, **kwargs):
"""Constructs LMLoss layer.
Args:
n_token: Number of tokens in vocabulary.
d_model: The dimension of model hidden state.
initializer: Initializer used for parameters.
use_bfloat16: Whether to use bfloat16.
tie_weight: Whether to share weights between embedding lookup layer and
next-token prediction layer.
bi_data: Whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
use_tpu: bool, whether to use TPU.
**kwargs: Other parameters.
"""
super(LMLossLayer, self).__init__(**kwargs)
self.n_token = n_token
self.d_model = d_model
self.initializer = initializer
self.tie_weight = tie_weight
self.bi_data = bi_data
self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
if not self.tie_weight:
self.softmax_w = self.add_weight('weight',
shape=[self.n_token, self.d_model],
initializer=self.initializer)
self.softmax_b = self.add_weight('bias', shape=[self.n_token],
initializer=tf.zeros_initializer())
super(LMLossLayer, self).build(unused_input_shapes)
def __call__(self, hidden, target, lookup_table, target_mask):
inputs = pack_inputs([hidden, target, lookup_table, target_mask])
return super(LMLossLayer, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
if self.tie_weight:
logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
else:
logits = tf.einsum('ibd,nd->ibn', hidden, self.softmax_w) + self.softmax_b
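    # Either way `logits` is [num_predict (or seq_len), bsz, n_token]: the
    # einsum contracts the d_model axis of `hidden` against the
    # [n_token, d_model] embedding (tied) or softmax weight (untied) matrix.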
if self.use_tpu:
one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
logits=logits)
if self.use_bfloat16:
tgt_mask = tf.cast(tgt_mask, tf.float32)
loss = tf.cast(loss, tf.float32)
total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
return total_loss
class Summarization(tf.keras.layers.Layer):
"""The layer to pool the output from XLNet model into a vector."""
def __init__(self,
d_model,
n_head,
d_head,
dropout,
dropout_att,
initializer,
use_proj=True,
summary_type='last',
**kwargs):
"""Constructs Summarization layer.
Args:
d_model: int, the dimension of model hidden state.
n_head: int, the number of attention heads.
d_head: int, the dimension size of each attention head.
dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities.
initializer: Initializer used for parameters.
use_proj: bool, whether to use projection layer for summarization.
summary_type: Method used to summarize a sequence into a compact vector.
**kwargs: Other parameters.
"""
super(Summarization, self).__init__(**kwargs)
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.initializer = initializer
self.dropout = dropout
self.dropout_att = dropout_att
self.use_proj = use_proj
self.summary_type = summary_type
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
if self.use_proj:
self.proj_layer = tf.keras.layers.Dense(
units=self.d_model,
kernel_initializer=self.initializer,
activation=tf.nn.tanh,
name='summary')
self.dropout_layer = tf.keras.layers.Dropout(rate=self.dropout)
super(Summarization, self).build(unused_input_shapes)
  def call(self, inputs):
    """Implements call() for the layer."""
    # Only `summary_type='last'` is handled here: the summary is the hidden
    # state at the last time-major position.
    summary = inputs[-1]
    if self.use_proj:
      summary = self.proj_layer(summary)
    summary = self.dropout_layer(summary)
    return summary
class ClassificationLossLayer(tf.keras.layers.Layer):
"""Layer computing cross entropy loss for classification task."""
def __init__(self, n_class, initializer, **kwargs):
"""Constructs Summarization layer.
Args:
n_class: Number of tokens in vocabulary.
initializer: Initializer used for parameters.
**kwargs: Other parameters.
"""
super(ClassificationLossLayer, self).__init__(**kwargs)
self.n_class = n_class
self.initializer = initializer
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.proj_layer = tf.keras.layers.Dense(
units=self.n_class, kernel_initializer=self.initializer, name='logit')
super(ClassificationLossLayer, self).build(unused_input_shapes)
def __call__(self, hidden, labels):
inputs = pack_inputs([hidden, labels])
return super(ClassificationLossLayer, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(hidden, labels) = unpack_inputs(inputs)
logits = self.proj_layer(hidden)
one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype) # pytype: disable=attribute-error
loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
return loss, logits
class QAXLNetModel(tf.keras.Model):
"""XLNet keras model combined with question answering loss layer.
See the original paper: https://arxiv.org/pdf/1906.08237.pdf
"""
def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
**kwargs):
super(QAXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
self.xlnet_config = copy.deepcopy(xlnet_config)
self.transformerxl_model = TransformerXLModel(
n_token=self.xlnet_config.n_token,
initializer=self.initializer,
attn_type='bi',
n_layer=self.xlnet_config.n_layer,
d_model=self.xlnet_config.d_model,
n_head=self.xlnet_config.n_head,
d_head=self.xlnet_config.d_head,
d_inner=self.xlnet_config.d_inner,
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
mem_len=self.run_config.mem_len,
reuse_len=self.run_config.reuse_len,
bi_data=self.run_config.bi_data,
clamp_len=self.run_config.clamp_len,
same_length=self.run_config.same_length,
name='transformer')
self.qa_loss_layer = QALossLayer(
d_model=self.xlnet_config.d_model,
start_n_top=start_n_top,
end_n_top=end_n_top,
initializer=self.initializer,
dropout=self.run_config.dropout)
def call(self, features, training=False):
"""Implements call() for the layer."""
input_ids = tf.transpose(features['input_ids'], [1, 0])
seg_ids = tf.transpose(features['segment_ids'], [1, 0])
input_mask = tf.transpose(features['input_mask'], [1, 0])
cls_index = tf.reshape(features['cls_index'], [-1])
p_mask = features['p_mask']
self.transformerxl_output, self.new_mems, self.lookup_table = (
self.transformerxl_model(
inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask))
if training:
loss, logits = self.qa_loss_layer(
hidden=self.transformerxl_output,
p_mask=p_mask,
cls_index=cls_index,
start_positions=features['start_positions'],
end_positions=features['end_positions'],
is_impossible=features['is_impossible'])
self.add_loss(loss)
return self.new_mems, logits
else:
results = self.qa_loss_layer(
hidden=self.transformerxl_output, p_mask=p_mask, cls_index=cls_index)
return results
class QALossLayer(tf.keras.layers.Layer):
"""Layer computing position and regression loss for question answering task.
"""
def __init__(self, d_model, start_n_top, end_n_top, initializer, dropout,
**kwargs):
"""Constructs Summarization layer.
Args:
d_model: Int, the hidden size.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
initializer: Initializer used for parameters.
dropout: float, dropout rate.
**kwargs: Other parameters.
"""
super(QALossLayer, self).__init__(**kwargs)
self.d_model = d_model
self.start_n_top = start_n_top
self.end_n_top = end_n_top
self.initializer = initializer
self.dropout = dropout
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.start_logits_proj_layer = tf.keras.layers.Dense(
units=1, kernel_initializer=self.initializer, name='start_logits/dense')
self.end_logits_proj_layer0 = tf.keras.layers.Dense(
units=self.d_model,
kernel_initializer=self.initializer,
activation=tf.nn.tanh,
name='end_logits/dense_0')
self.end_logits_proj_layer1 = tf.keras.layers.Dense(
units=1, kernel_initializer=self.initializer, name='end_logits/dense_1')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='end_logits/LayerNorm')
self.answer_class_proj_layer0 = tf.keras.layers.Dense(
units=self.d_model,
kernel_initializer=self.initializer,
activation=tf.nn.tanh,
name='answer_class/dense_0')
self.answer_class_proj_layer1 = tf.keras.layers.Dense(
units=1,
kernel_initializer=self.initializer,
use_bias=False,
name='answer_class/dense_1')
self.ans_feature_dropout = tf.keras.layers.Dropout(rate=self.dropout)
super(QALossLayer, self).build(unused_input_shapes)
def __call__(self, hidden, p_mask, cls_index, **kwargs):
return super(QALossLayer, self).__call__(
(hidden, p_mask, cls_index, kwargs))
def call(self, inputs, training=False):
"""Implements call() for the layer."""
hidden, p_mask, cls_index, kwargs = inputs
return_dict = {}
seq_len = tf.shape(hidden)[0]
start_logits = self.start_logits_proj_layer(hidden)
start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
if training:
start_positions = kwargs['start_positions']
end_positions = kwargs['end_positions']
is_impossible = kwargs['is_impossible']
start_positions = tf.reshape(start_positions, [-1])
start_index = tf.one_hot(
start_positions, depth=seq_len, axis=-1, dtype=tf.float32)
start_features = tf.einsum('lbh,bl->bh', hidden, start_index)
start_features = tf.tile(start_features[None], [seq_len, 1, 1])
end_logits = self.end_logits_proj_layer0(
tf.concat([hidden, start_features], axis=-1))
end_logits = self.end_logits_layer_norm(end_logits)
end_logits = self.end_logits_proj_layer1(end_logits)
end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
else:
# during inference, compute the end logits based on beam search
start_top_log_probs, start_top_index = tf.nn.top_k(
start_log_probs, k=self.start_n_top)
start_index = tf.one_hot(
start_top_index, depth=seq_len, axis=-1, dtype=tf.float32)
start_features = tf.einsum('lbh,bkl->bkh', hidden, start_index)
end_input = tf.tile(hidden[:, :, None], [1, 1, self.start_n_top, 1])
start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1])
end_input = tf.concat([end_input, start_features], axis=-1)
end_logits = self.end_logits_proj_layer0(end_input)
end_logits = tf.reshape(end_logits, [seq_len, -1, self.d_model])
end_logits = self.end_logits_layer_norm(end_logits)
end_logits = tf.reshape(end_logits,
[seq_len, -1, self.start_n_top, self.d_model])
end_logits = self.end_logits_proj_layer1(end_logits)
end_logits = tf.reshape(end_logits, [seq_len, -1, self.start_n_top])
end_logits = tf.transpose(end_logits, [1, 2, 0])
end_logits_masked = end_logits * (
1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
end_top_log_probs, end_top_index = tf.nn.top_k(
end_log_probs, k=self.end_n_top)
end_top_log_probs = tf.reshape(end_top_log_probs,
[-1, self.start_n_top * self.end_n_top])
end_top_index = tf.reshape(end_top_index,
[-1, self.start_n_top * self.end_n_top])
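      # After the reshape, each row pairs every start candidate with its end
      # candidates: both tensors are [bsz, start_n_top * end_n_top].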
if training:
return_dict['start_log_probs'] = start_log_probs
return_dict['end_log_probs'] = end_log_probs
else:
return_dict['start_top_log_probs'] = start_top_log_probs
return_dict['start_top_index'] = start_top_index
return_dict['end_top_log_probs'] = end_top_log_probs
return_dict['end_top_index'] = end_top_index
# an additional layer to predict answerability
# get the representation of CLS
cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
cls_feature = tf.einsum('lbh,bl->bh', hidden, cls_index)
# get the representation of START
start_p = tf.nn.softmax(start_logits_masked, axis=-1, name='softmax_start')
start_feature = tf.einsum('lbh,bl->bh', hidden, start_p)
ans_feature = tf.concat([start_feature, cls_feature], -1)
ans_feature = self.answer_class_proj_layer0(ans_feature)
ans_feature = self.ans_feature_dropout(ans_feature)
cls_logits = self.answer_class_proj_layer1(ans_feature)
cls_logits = tf.squeeze(cls_logits, -1)
return_dict['cls_logits'] = cls_logits
if not training:
return return_dict
def compute_loss(log_probs, positions):
one_hot_positions = tf.one_hot(positions, depth=seq_len, dtype=tf.float32)
loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
loss = tf.reduce_mean(loss)
return loss
start_loss = compute_loss(start_log_probs, start_positions)
end_loss = compute_loss(end_log_probs, end_positions)
total_loss = (start_loss + end_loss) * 0.5
is_impossible = tf.reshape(is_impossible, [-1])
regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=is_impossible, logits=cls_logits)
regression_loss = tf.reduce_mean(regression_loss)
total_loss += regression_loss * 0.5
return total_loss, cls_logits
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import numpy as np
import tensorflow as tf
from official.nlp import xlnet_modeling
class PositionalEmbeddingLayerTest(tf.test.TestCase):
def test_positional_embedding(self):
"""A low-dimensional example is tested.
With len(pos_seq)=2 and d_model=4:
pos_seq = [[1.], [0.]]
inv_freq = [1., 0.01]
pos_seq x inv_freq = [[1, 0.01], [0., 0.]]
pos_emb = [[sin(1.), sin(0.01), cos(1.), cos(0.01)],
[sin(0.), sin(0.), cos(0.), cos(0.)]]
= [[0.84147096, 0.00999983, 0.54030228, 0.99994999],
[0., 0., 1., 1.]]
"""
target = np.array([[[0.84147096, 0.00999983, 0.54030228, 0.99994999]],
[[0., 0., 1., 1.]]])
d_model = 4
pos_seq = tf.range(1, -1, -1.0) # [1., 0.]
pos_emb_layer = xlnet_modeling.PositionalEmbedding(d_model)
pos_emb = pos_emb_layer(
pos_seq=pos_seq, batch_size=None).numpy().astype(float)
logging.info(pos_emb)
self.assertAllClose(pos_emb, target)
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()