Commit 6cd426d9 authored by Jing Li's avatar Jing Li Committed by A. Unique TensorFlower
Browse files

Support online masking for XLNet

PiperOrigin-RevId: 275408074
parent b0581d0a
...@@ -19,12 +19,15 @@ from __future__ import division ...@@ -19,12 +19,15 @@ from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import collections
import json import json
import os import os
from absl import logging from absl import logging
import numpy as np
import tensorflow as tf import tensorflow as tf
special_symbols = { special_symbols = {
"<unk>": 0, "<unk>": 0,
"<s>": 1, "<s>": 1,
...@@ -49,6 +52,11 @@ SEG_ID_CLS = 2 ...@@ -49,6 +52,11 @@ SEG_ID_CLS = 2
SEG_ID_PAD = 3 SEG_ID_PAD = 3
# Configuration for online (on-the-fly) target masking during pretraining.
#   sample_strategy: one of "single_token", "whole_word", "token_span",
#     "word_span" — selects which sampling function `_online_sample_masks`
#     dispatches to.
#   max_num_tokens / min_num_tokens: span-length bounds used by "token_span".
#   max_num_words / min_num_words: span-length bounds used by "word_span".
OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
    "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
    "min_num_words"])
def file_based_input_fn_builder(input_file, name_to_features, batch_size, def file_based_input_fn_builder(input_file, name_to_features, batch_size,
is_training): is_training):
"""Creates an `input_fn` closure.""" """Creates an `input_fn` closure."""
...@@ -249,11 +257,191 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training, ...@@ -249,11 +257,191 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
return _dataset_fn if use_dataset_fn else _dataset_fn() return _dataset_fn if use_dataset_fn else _dataset_fn()
def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
  """Turn beg and end indices into actual mask."""
  # Functional tokens ([SEP], [CLS]) must never become prediction targets.
  is_non_func = tf.logical_and(
      tf.not_equal(inputs, SEP_ID),
      tf.not_equal(inputs, CLS_ID))
  # Position ids, with functional positions mapped to -1 so that no
  # [beg, end) span can ever cover them.
  position_ids = tf.where(
      is_non_func,
      tf.range(tgt_len, dtype=tf.int64),
      tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
  # in_span[i, j] is True iff position j falls inside span i.
  in_span = tf.logical_and(
      position_ids[None, :] >= beg_indices[:, None],
      position_ids[None, :] < end_indices[:, None])
  candidates = tf.cast(in_span, tf.float32)
  # Running count of candidate tokens, accumulated span by span.
  running_count = tf.reshape(
      tf.cumsum(tf.reshape(candidates, [-1])), [-1, tgt_len])
  # Keep spans, in order, only while the `num_predict` budget is not exceeded.
  within_budget = tf.cast(running_count <= num_predict, tf.float32)
  target_mask = tf.reduce_sum(candidates * within_budget, axis=0)
  is_masked = tf.cast(target_mask, tf.bool)
  return is_masked, target_mask
def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
                    max_num_words, boundary):
  """Sample whole word spans as prediction targets.

  Args:
    inputs: int64 Tensor in shape [tgt_len], input ids.
    tgt_len: int, sequence length.
    num_predict: int, number of spans to (over-)sample.
    min_num_words: int, minimum words per sampled span.
    max_num_words: int, maximum words per sampled span.
    boundary: int64 Tensor of word-boundary token indices; consecutive
      entries delimit words.

  Returns:
    (is_masked, target_mask) as produced by `_idx_pair_to_mask`.
  """
  # Note: 1.2 is the token-to-word ratio
  mask_alpha = tgt_len / num_predict / 1.2
  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)

  # Sample span lengths from a zipf distribution
  span_len_seq = np.arange(min_num_words, max_num_words + 1)
  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
  probs /= np.sum(probs)
  logits = tf.constant(np.log(probs), dtype=tf.float32)

  # Sample `num_predict` words here: note that this is over sampling
  span_lens = tf.random.categorical(
      logits=logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_words

  # Sample the ratio [0.0, 1.0) of left context lengths
  span_lens_float = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
  left_ctx_len = round_to_int(left_ctx_len)
  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

  # Lay the spans out left to right; these are indices in *word* space.
  beg_indices = (tf.cumsum(left_ctx_len) +
                 tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Remove out of range indices
  max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
  valid_idx_mask = end_indices < max_boundary_index
  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

  # Map word indices to token indices via the boundary table.
  beg_indices = tf.gather(boundary, beg_indices)
  end_indices = tf.gather(boundary, end_indices)

  # Shuffle valid indices
  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, order)
  end_indices = tf.gather(end_indices, order)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
                     max_num_tokens):
  """Sample token spans as prediction targets."""
  # Average budget (span plus surrounding context) allotted per sampled span.
  mask_alpha = tgt_len / num_predict
  to_int64 = lambda t: tf.cast(tf.round(t), tf.int64)

  # Span lengths follow a zipf-like distribution: p(len) ~ 1 / (len + 1).
  length_range = np.arange(min_num_tokens, max_num_tokens + 1)
  zipf_probs = np.array([1.0 / (l + 1) for l in length_range])
  zipf_probs /= np.sum(zipf_probs)
  zipf_logits = tf.constant(np.log(zipf_probs), dtype=tf.float32)
  span_lens = tf.random.categorical(
      logits=zipf_logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_tokens

  # Randomly split each span's context budget into left and right parts.
  span_lens_f = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = to_int64(left_ratio * span_lens_f * (mask_alpha - 1))
  # Offset from one span's left context start to the next span's.
  right_offset = to_int64(span_lens_f * mask_alpha) - left_ctx_len

  # Lay the spans out left to right with their contexts in between.
  beg_indices = (tf.cumsum(left_ctx_len) +
                 tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Drop any span that runs past the end of the sequence.
  keep = end_indices < tgt_len
  beg_indices = tf.boolean_mask(beg_indices, keep)
  end_indices = tf.boolean_mask(end_indices, keep)

  # Shuffle the surviving spans.
  num_spans = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  perm = tf.random.shuffle(tf.range(num_spans, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, perm)
  end_indices = tf.gather(end_indices, perm)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
  """Sample whole words as prediction targets."""
  # Consecutive boundary entries delimit one word: [boundary[i], boundary[i+1]).
  word_spans = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
  # Pick up to `num_predict` words uniformly at random.
  chosen = tf.random.shuffle(word_spans)[:num_predict]
  return _idx_pair_to_mask(chosen[:, 0], chosen[:, 1], inputs, tgt_len,
                           num_predict)
def _single_token_mask(inputs, tgt_len, num_predict):
  """Sample individual tokens as prediction targets.

  Args:
    inputs: int64 Tensor in shape [tgt_len], input ids.
    tgt_len: int, sequence length.
    num_predict: int, number of positions to sample.

  Returns:
    is_masked: bool Tensor in shape [tgt_len], True at sampled positions.
    target_mask: float32 Tensor in shape [tgt_len], 1.0 at sampled positions.
  """
  all_indices = tf.range(tgt_len, dtype=tf.int64)
  # Never select functional tokens ([SEP], [CLS]) as prediction targets.
  non_func_mask = tf.logical_and(
      tf.not_equal(inputs, SEP_ID),
      tf.not_equal(inputs, CLS_ID))
  non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
  masked_pos = tf.random.shuffle(non_func_indices)
  # `tf.contrib` was removed in TF 2.x and `tf.sparse_to_dense` is deprecated;
  # use the TF2-native `tf.sort` + `tf.scatter_nd` equivalents instead.
  masked_pos = tf.sort(masked_pos[:num_predict])
  target_mask = tf.scatter_nd(
      indices=masked_pos[:, None],
      updates=tf.ones(tf.shape(masked_pos), dtype=tf.float32),
      shape=[tgt_len])
  is_masked = tf.cast(target_mask, tf.bool)
  return is_masked, target_mask
def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
                         boundary=None):
  """Sample target positions to predict.

  Dispatches to the sampling function selected by
  `online_masking_config.sample_strategy`.

  Args:
    inputs: int64 Tensor in shape [tgt_len], input ids.
    tgt_len: int, sequence length.
    num_predict: int, number of targets to sample.
    online_masking_config: an `OnlineMaskingConfig` instance.
    boundary: optional int64 Tensor of word-boundary indices; required by the
      "whole_word" and "word_span" strategies.

  Returns:
    (is_masked, target_mask) from the selected sampling function.

  Raises:
    NotImplementedError: if the sample strategy is not recognized.
  """
  sample_strategy = online_masking_config.sample_strategy
  logging.info("Online sample with strategy: `%s`.", sample_strategy)
  if sample_strategy == "single_token":
    return _single_token_mask(inputs, tgt_len, num_predict)
  elif sample_strategy == "whole_word":
    assert boundary is not None, "whole word sampling requires `boundary`"
    return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
  elif sample_strategy == "token_span":
    return _token_span_mask(inputs, tgt_len, num_predict,
                            online_masking_config.min_num_tokens,
                            online_masking_config.max_num_tokens)
  elif sample_strategy == "word_span":
    assert boundary is not None, "word span sampling requires `boundary`"
    return _word_span_mask(inputs, tgt_len, num_predict,
                           online_masking_config.min_num_words,
                           online_masking_config.max_num_words,
                           boundary)
  else:
    # Name the offending strategy instead of raising a bare error.
    raise NotImplementedError(
        "Unknown sample strategy: `%s`." % sample_strategy)
def create_pretrain_dataset(file_names, def create_pretrain_dataset(file_names,
bsz_per_core, bsz_per_core,
seq_len, seq_len,
reuse_len, reuse_len,
perm_size, perm_size,
leak_ratio,
online_masking_config,
num_predict=None, num_predict=None,
input_pipeline_context=None): input_pipeline_context=None):
"""Creates pretrain dataset.""" """Creates pretrain dataset."""
...@@ -263,46 +451,67 @@ def create_pretrain_dataset(file_names, ...@@ -263,46 +451,67 @@ def create_pretrain_dataset(file_names,
record_spec = { record_spec = {
"input": tf.io.FixedLenFeature([seq_len], tf.int64), "input": tf.io.FixedLenFeature([seq_len], tf.int64),
"target": tf.io.FixedLenFeature([seq_len], tf.int64),
"seg_id": tf.io.FixedLenFeature([seq_len], tf.int64), "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
"label": tf.io.FixedLenFeature([1], tf.int64), "label": tf.io.FixedLenFeature([1], tf.int64),
"is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
} }
if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
logging.info("Add `boundary` spec for %s",
online_masking_config.sample_strategy)
record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
# retrieve serialized example # retrieve serialized example
example = tf.io.parse_single_example( example = tf.io.parse_single_example(
serialized=record, features=record_spec) serialized=record, features=record_spec)
inputs = example.pop("input") inputs = example.pop("input")
target = example.pop("target") if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
is_masked = tf.cast(example.pop("is_masked"), tf.bool) boundary = tf.sparse.to_dense(example.pop("boundary"))
else:
non_reuse_len = seq_len - reuse_len boundary = None
# perm_size should not be larger than reuse_len or non_reuse_len otherwise is_masked, _ = _online_sample_masks(
# there will be data leaks. inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
assert perm_size <= reuse_len and perm_size <= non_reuse_len
if reuse_len > 0:
# Creates permutation mask and target mask for the first reuse_len tokens. ##### Use memory
# The tokens in this part are reused from the last sequence. # permutate the reuse and non-reuse parts separately
perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm( non_reuse_len = seq_len - reuse_len
inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len], assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
perm_size, reuse_len)
# Creates permutation mask and target mask for the first reuse_len tokens.
# Creates permutation mask and target mask for the rest of tokens in # The tokens in this part are reused from the last sequence.
# current example, which are concatentation of two new segments. perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm( inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:], leak_ratio)
perm_size, non_reuse_len)
# Creates permutation mask and target mask for the rest of tokens in
perm_mask_0 = tf.concat( # current example, which are concatentation of two new segments.
[perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1) perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
axis=1) leak_ratio)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target = tf.concat([target_0, target_1], axis=0) perm_mask_0 = tf.concat(
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0) [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
input_k = tf.concat([input_k_0, input_k_1], axis=0) perm_mask_1 = tf.concat(
input_q = tf.concat([input_q_0, input_q_1], axis=0) [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
input_k = tf.concat([input_k_0, input_k_1], axis=0)
input_q = tf.concat([input_q_0, input_q_1], axis=0)
else:
##### Do not use memory
assert seq_len % perm_size == 0
# permutate the entire sequence together
perm_mask, target_mask, input_k, input_q = _local_perm(
inputs, is_masked, perm_size, seq_len, leak_ratio)
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
# Directly use raw inputs as the target
target = inputs
if num_predict is not None: if num_predict is not None:
indices = tf.range(seq_len, dtype=tf.int64) indices = tf.range(seq_len, dtype=tf.int64)
...@@ -327,21 +536,15 @@ def create_pretrain_dataset(file_names, ...@@ -327,21 +536,15 @@ def create_pretrain_dataset(file_names,
example["target"] = tf.reshape(target, [num_predict]) example["target"] = tf.reshape(target, [num_predict])
##### target mask ##### target mask
target_mask = tf.concat([ target_mask = tf.concat(
tf.ones([actual_num_predict], dtype=tf.float32), [tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32) tf.zeros([pad_len], dtype=tf.float32)],
], axis=0)
axis=0)
example["target_mask"] = tf.reshape(target_mask, [num_predict]) example["target_mask"] = tf.reshape(target_mask, [num_predict])
else: else:
example["target"] = tf.reshape(target, [seq_len]) example["target"] = tf.reshape(target, [seq_len])
example["target_mask"] = tf.reshape(target_mask, [seq_len]) example["target_mask"] = tf.reshape(target_mask, [seq_len])
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
for key in list(example.keys()): for key in list(example.keys()):
val = example[key] val = example[key]
if tf.keras.backend.is_sparse(val): if tf.keras.backend.is_sparse(val):
...@@ -360,42 +563,29 @@ def create_pretrain_dataset(file_names, ...@@ -360,42 +563,29 @@ def create_pretrain_dataset(file_names,
parser=parser, parser=parser,
file_paths=file_names, file_paths=file_names,
bsz_per_core=bsz_per_core, bsz_per_core=bsz_per_core,
sequential=reuse_len > 0,
input_pipeline_context=input_pipeline_context) input_pipeline_context=input_pipeline_context)
return dataset return dataset
def format_filename(prefix, def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
bsz_per_host, uncased=False):
seq_len,
bi_data,
suffix,
mask_alpha=5,
mask_beta=1,
reuse_len=None,
uncased=False,
fixed_num_predict=None):
"""Generates input file name pattern.""" """Generates input file name pattern."""
if reuse_len is None: if reuse_len is not None and reuse_len > 0:
reuse_len_str = "" reuse_str = "reuse-{}.".format(reuse_len)
bsz_str = "hostbsz-{}.".format(bsz_per_host)
else: else:
reuse_len_str = "reuse-{}.".format(reuse_len) reuse_str = ""
bsz_str = ""
if not uncased: if not uncased:
uncased_str = "" case_str = ""
else:
uncased_str = "uncased."
if bi_data:
bi_data_str = "bi"
else: else:
bi_data_str = "uni" case_str = "uncased."
if fixed_num_predict is not None:
fnp_str = "fnp-{}.".format(fixed_num_predict)
else:
fnp_str = ""
file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format( file_name = "{}.seq-{}.{}{}{}{}".format(
prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str, prefix, seq_len, reuse_str, bsz_str, case_str, suffix)
mask_alpha, mask_beta, fnp_str, suffix)
return file_name return file_name
...@@ -406,11 +596,10 @@ def get_pretrain_input_data(batch_size, ...@@ -406,11 +596,10 @@ def get_pretrain_input_data(batch_size,
file_path, file_path,
reuse_len, reuse_len,
perm_size, perm_size,
mask_alpha, leak_ratio,
mask_beta,
num_predict, num_predict,
bi_data,
uncased, uncased,
online_masking_config,
num_hosts=1): num_hosts=1):
"""Returns input dataset from input file string.""" """Returns input dataset from input file string."""
...@@ -419,17 +608,22 @@ def get_pretrain_input_data(batch_size, ...@@ -419,17 +608,22 @@ def get_pretrain_input_data(batch_size,
# than passing dataset instance itself. # than passing dataset instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy) use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
split = "train" split = "train"
bsz_per_host = int(batch_size / num_hosts)
record_glob_base = format_filename( record_glob_base = format_filename(
prefix="record_info-{}-*".format(split), prefix="meta.{}.pass-*".format(split),
bsz_per_host=int(batch_size / num_hosts), suffix="json*",
bsz_per_host=bsz_per_host,
seq_len=seq_len, seq_len=seq_len,
bi_data=bi_data,
suffix="json",
mask_alpha=mask_alpha,
mask_beta=mask_beta,
reuse_len=reuse_len, reuse_len=reuse_len,
uncased=uncased, uncased=uncased)
fixed_num_predict=num_predict)
def _get_num_batch(info):
if "num_batch" in info:
return info["num_batch"]
elif "num_example" in info:
return info["num_example"] / bsz_per_host
else:
raise ValueError("Do not have sample info.")
if use_dataset_fn: if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0: if batch_size % strategy.num_replicas_in_sync != 0:
...@@ -460,7 +654,7 @@ def get_pretrain_input_data(batch_size, ...@@ -460,7 +654,7 @@ def get_pretrain_input_data(batch_size,
for record_info_path in record_paths: for record_info_path in record_paths:
with tf.io.gfile.GFile(record_info_path, "r") as fp: with tf.io.gfile.GFile(record_info_path, "r") as fp:
info = json.load(fp) info = json.load(fp)
cur_record_info["num_batch"] += info["num_batch"] cur_record_info["num_batch"] += int(_get_num_batch(info))
cur_record_info["filenames"] += info["filenames"] cur_record_info["filenames"] += info["filenames"]
# overwrite directory for `cur_record_info` # overwrite directory for `cur_record_info`
...@@ -494,6 +688,8 @@ def get_pretrain_input_data(batch_size, ...@@ -494,6 +688,8 @@ def get_pretrain_input_data(batch_size,
seq_len=seq_len, seq_len=seq_len,
reuse_len=reuse_len, reuse_len=reuse_len,
perm_size=perm_size, perm_size=perm_size,
leak_ratio=leak_ratio,
online_masking_config=online_masking_config,
num_predict=num_predict, num_predict=num_predict,
input_pipeline_context=ctx) input_pipeline_context=ctx)
return train_dataset return train_dataset
...@@ -504,6 +700,7 @@ def get_pretrain_input_data(batch_size, ...@@ -504,6 +700,7 @@ def get_pretrain_input_data(batch_size,
def parse_files_to_dataset(parser, def parse_files_to_dataset(parser,
file_paths, file_paths,
bsz_per_core, bsz_per_core,
sequential,
input_pipeline_context=None): input_pipeline_context=None):
"""Creates the dataset given file paths.""" """Creates the dataset given file paths."""
...@@ -519,7 +716,26 @@ def parse_files_to_dataset(parser, ...@@ -519,7 +716,26 @@ def parse_files_to_dataset(parser,
if len(file_paths) > 1: if len(file_paths) > 1:
dataset = dataset.shuffle(len(file_paths)) dataset = dataset.shuffle(len(file_paths))
dataset = tf.data.TFRecordDataset(dataset) if sequential:
# Note: cannot perform sample-level shuffle here because this will violate
# the consecutive requirement of data stream.
dataset = tf.data.TFRecordDataset(dataset)
else:
# `cycle_length` is the number of parallel files that get read.
cycle_length = min(8, len(file_paths))
logging.info("Interleave %d files", cycle_length)
# `sloppy` mode means that the interleaving is not exact. This adds
# even more randomness to the training pipeline.
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=True,
cycle_length=cycle_length))
buffer_size = 2048
logging.info("Perform sample-level shuffle with size %d", buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
# (zihang): since we are doing online preprocessing, the parsed result of # (zihang): since we are doing online preprocessing, the parsed result of
# the same input at each time will be different. Thus, cache processed data # the same input at each time will be different. Thus, cache processed data
# is not helpful. It will use a lot of memory and lead to contrainer OOM. # is not helpful. It will use a lot of memory and lead to contrainer OOM.
...@@ -531,19 +747,19 @@ def parse_files_to_dataset(parser, ...@@ -531,19 +747,19 @@ def parse_files_to_dataset(parser,
return dataset return dataset
def _local_perm(inputs, targets, is_masked, perm_size, seq_len): def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
"""Samples a permutation of the factorization order. """Samples a permutation of the factorization order.
Creates perm_mask and target_mask accordingly. Creates perm_mask and target_mask accordingly.
Args: Args:
inputs: int64 Tensor in shape [seq_len], input ids. inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected for is_masked: bool Tensor in shape [seq_len]. True means being selected for
partial prediction. partial prediction.
perm_size: the length of longest permutation. Could be set to be reuse_len. perm_size: the length of longest permutation. Could be set to be reuse_len.
Should not be larger than reuse_len or there will be data leaks. Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length. seq_len: int, sequence length.
leak_ratio: float, percent of masked tokens that are leaked.
Returns: Returns:
perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1. perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
...@@ -555,9 +771,6 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len): ...@@ -555,9 +771,6 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
means the ith token (in original order) can attend to the jth token means the ith token (in original order) can attend to the jth token
(in original order). Note that non-masked tokens can be attended by all (in original order). Note that non-masked tokens can be attended by all
other tokens, which is different from the description in original paper. other tokens, which is different from the description in original paper.
new_targets: int64 Tensor in shape [seq_len], target token ids to be
predicted in XLNet.
In XLNet, target doesn't need to be shifted one position.
target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
target_mask[i] == 1, target_mask[i] == 1,
the ith token needs to be predicted and mask will be used as input. This the ith token needs to be predicted and mask will be used as input. This
...@@ -575,44 +788,40 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len): ...@@ -575,44 +788,40 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
index = tf.random.shuffle(index) index = tf.random.shuffle(index)
index = tf.reshape(tf.transpose(index), [-1]) index = tf.reshape(tf.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens # non-functional tokens
non_func_tokens = tf.logical_not( non_func_tokens = tf.logical_not(tf.logical_or(
tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID))) tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens) masked_tokens = tf.logical_and(is_masked, non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens) non_masked_or_func_tokens = tf.logical_not(masked_tokens)
# Set the permutation indices of non-masked (& non-funcional) tokens to the smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
# smallest index (-1):
# (1) they can be seen by all other positions # Similar to BERT, randomly leak some masked tokens
# (2) they cannot see masked positions, so there won"t be information leak if leak_ratio > 0:
smallest_index = -tf.ones([seq_len], dtype=tf.int64) leak_tokens = tf.logical_and(
rev_index = tf.where(non_mask_tokens, smallest_index, index) masked_tokens,
tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
# Create `target_mask`: non-funcional and masked tokens can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
# 1: use mask as input and have loss else:
# 0: use token (or [SEP], [CLS]) as input and do not have loss can_attend_self = non_masked_or_func_tokens
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens) to_index = tf.where(can_attend_self, smallest_index, index)
target_mask = tf.cast(target_tokens, tf.float32) from_index = tf.where(can_attend_self, to_index + 1, to_index)
# Create `perm_mask` # For masked tokens, can attend if i > j
# `target_tokens` cannot see themselves # For context tokens, always can attend each other
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1) can_attend = from_index[:, None] > to_index[None, :]
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens) # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
# 0: can attend if i > j or j is non-masked perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens) # Only masked tokens are included in the loss
perm_mask = tf.cast(perm_mask, tf.float32) target_mask = tf.cast(masked_tokens, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
# construct inputs_k # construct inputs_k
inputs_k = inputs inputs_k = inputs
# construct inputs_q # construct inputs_q
inputs_q = target_mask inputs_q = masked_tokens
return perm_mask, new_targets, target_mask, inputs_k, inputs_q return perm_mask, target_mask, inputs_k, inputs_q
...@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization ...@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
flags.DEFINE_integer(
"mask_alpha", default=6, help="How many tokens to form a group.")
flags.DEFINE_integer(
"mask_beta", default=1, help="How many tokens to mask within each group.")
flags.DEFINE_integer( flags.DEFINE_integer(
"num_predict", "num_predict",
default=None, default=None,
help="Number of tokens to predict in partial prediction.") help="Number of tokens to predict in partial prediction.")
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
# FLAGS for pretrain input preprocessing
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
flags.DEFINE_float("leak_ratio", default=0.1,
help="Percent of masked tokens that are leaked.")
flags.DEFINE_enum("sample_strategy", default="token_span",
enum_values=["single_token", "whole_word", "token_span",
"word_span"],
help="Stragey used to sample prediction targets.")
flags.DEFINE_integer("max_num_tokens", default=5,
help="Maximum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("min_num_tokens", default=1,
help="Minimum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("max_num_words", default=5,
help="Maximum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
flags.DEFINE_integer("min_num_words", default=1,
help="Minimum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -74,11 +91,18 @@ def main(unused_argv): ...@@ -74,11 +91,18 @@ def main(unused_argv):
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d", num_hosts) logging.info("***** Number of hosts used : %d", num_hosts)
online_masking_config = data_utils.OnlineMaskingConfig(
sample_strategy=FLAGS.sample_strategy,
max_num_tokens=FLAGS.max_num_tokens,
min_num_tokens=FLAGS.min_num_tokens,
max_num_words=FLAGS.max_num_words,
min_num_words=FLAGS.min_num_words)
train_input_fn = functools.partial( train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len, data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict, FLAGS.bi_data, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
FLAGS.uncased, num_hosts) num_hosts)
total_training_steps = FLAGS.train_steps total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size) steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment