Commit 6cd426d9 authored by Jing Li's avatar Jing Li Committed by A. Unique TensorFlower
Browse files

Support online masking for XLNet

PiperOrigin-RevId: 275408074
parent b0581d0a
......@@ -19,12 +19,15 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import json
import os
from absl import logging
import numpy as np
import tensorflow as tf
special_symbols = {
"<unk>": 0,
"<s>": 1,
......@@ -49,6 +52,11 @@ SEG_ID_CLS = 2
SEG_ID_PAD = 3
OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
"sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
"min_num_words"])
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
is_training):
"""Creates an `input_fn` closure."""
......@@ -249,11 +257,191 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
return _dataset_fn if use_dataset_fn else _dataset_fn()
def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
"""Turn beg and end indices into actual mask."""
non_func_mask = tf.logical_and(
tf.not_equal(inputs, SEP_ID),
tf.not_equal(inputs, CLS_ID))
all_indices = tf.where(
non_func_mask,
tf.range(tgt_len, dtype=tf.int64),
tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
candidate_matrix = tf.cast(
tf.logical_and(
all_indices[None, :] >= beg_indices[:, None],
all_indices[None, :] < end_indices[:, None]),
tf.float32)
cumsum_matrix = tf.reshape(
tf.cumsum(tf.reshape(candidate_matrix, [-1])),
[-1, tgt_len])
masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
is_masked = tf.cast(target_mask, tf.bool)
return is_masked, target_mask
def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
max_num_words, boundary):
"""Sample whole word spans as prediction targets."""
# Note: 1.2 is the token-to-word ratio
mask_alpha = tgt_len / num_predict / 1.2
round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_words, max_num_words + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
# Sample `num_predict` words here: note that this is over sampling
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=num_predict,
dtype=tf.int64,
)[0] + min_num_words
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
beg_indices = (tf.cumsum(left_ctx_len) +
tf.cumsum(right_offset, exclusive=True))
end_indices = beg_indices + span_lens
# Remove out of range indices
max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
valid_idx_mask = end_indices < max_boundary_index
beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
beg_indices = tf.gather(boundary, beg_indices)
end_indices = tf.gather(boundary, end_indices)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
beg_indices = tf.gather(beg_indices, order)
end_indices = tf.gather(end_indices, order)
return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
num_predict)
def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
max_num_tokens):
"""Sample token spans as prediction targets."""
mask_alpha = tgt_len / num_predict
round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=num_predict,
dtype=tf.int64,
)[0] + min_num_tokens
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
# Compute the offset from left start to the right end
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
# Get the actual begin and end indices
beg_indices = (tf.cumsum(left_ctx_len) +
tf.cumsum(right_offset, exclusive=True))
end_indices = beg_indices + span_lens
# Remove out of range indices
valid_idx_mask = end_indices < tgt_len
beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
beg_indices = tf.gather(beg_indices, order)
end_indices = tf.gather(end_indices, order)
return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
num_predict)
def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
"""Sample whole words as prediction targets."""
pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
beg_indices = cand_pair_indices[:, 0]
end_indices = cand_pair_indices[:, 1]
return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
num_predict)
def _single_token_mask(inputs, tgt_len, num_predict):
"""Sample individual tokens as prediction targets."""
all_indices = tf.range(tgt_len, dtype=tf.int64)
non_func_mask = tf.logical_and(
tf.not_equal(inputs, SEP_ID),
tf.not_equal(inputs, CLS_ID))
non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
masked_pos = tf.random.shuffle(non_func_indices)
masked_pos = tf.contrib.framework.sort(masked_pos[:num_predict])
target_mask = tf.sparse_to_dense(
sparse_indices=masked_pos,
output_shape=[tgt_len],
sparse_values=1.0,
default_value=0.0)
is_masked = tf.cast(target_mask, tf.bool)
return is_masked, target_mask
def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
boundary=None):
"""Sample target positions to predict."""
logging.info("Online sample with strategy: `%s`.",
online_masking_config.sample_strategy)
if online_masking_config.sample_strategy == "single_token":
return _single_token_mask(inputs, tgt_len, num_predict)
elif online_masking_config.sample_strategy == "whole_word":
assert boundary is not None, "whole word sampling requires `boundary`"
return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
elif online_masking_config.sample_strategy == "token_span":
return _token_span_mask(inputs, tgt_len, num_predict,
online_masking_config.min_num_tokens,
online_masking_config.max_num_tokens)
elif online_masking_config.sample_strategy == "word_span":
assert boundary is not None, "word span sampling requires `boundary`"
return _word_span_mask(inputs, tgt_len, num_predict,
online_masking_config.min_num_words,
online_masking_config.max_num_words,
boundary)
else:
raise NotImplementedError
def create_pretrain_dataset(file_names,
bsz_per_core,
seq_len,
reuse_len,
perm_size,
leak_ratio,
online_masking_config,
num_predict=None,
input_pipeline_context=None):
"""Creates pretrain dataset."""
......@@ -263,46 +451,67 @@ def create_pretrain_dataset(file_names,
record_spec = {
"input": tf.io.FixedLenFeature([seq_len], tf.int64),
"target": tf.io.FixedLenFeature([seq_len], tf.int64),
"seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
"label": tf.io.FixedLenFeature([1], tf.int64),
"is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
}
if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
logging.info("Add `boundary` spec for %s",
online_masking_config.sample_strategy)
record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
# retrieve serialized example
example = tf.io.parse_single_example(
serialized=record, features=record_spec)
inputs = example.pop("input")
target = example.pop("target")
is_masked = tf.cast(example.pop("is_masked"), tf.bool)
non_reuse_len = seq_len - reuse_len
# perm_size should not be larger than reuse_len or non_reuse_len otherwise
# there will be data leaks.
assert perm_size <= reuse_len and perm_size <= non_reuse_len
# Creates permutation mask and target mask for the first reuse_len tokens.
# The tokens in this part are reused from the last sequence.
perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
perm_size, reuse_len)
# Creates permutation mask and target mask for the rest of tokens in
# current example, which are concatentation of two new segments.
perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
perm_size, non_reuse_len)
perm_mask_0 = tf.concat(
[perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target = tf.concat([target_0, target_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
input_k = tf.concat([input_k_0, input_k_1], axis=0)
input_q = tf.concat([input_q_0, input_q_1], axis=0)
if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
boundary = tf.sparse.to_dense(example.pop("boundary"))
else:
boundary = None
is_masked, _ = _online_sample_masks(
inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
if reuse_len > 0:
##### Use memory
# permutate the reuse and non-reuse parts separately
non_reuse_len = seq_len - reuse_len
assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
# Creates permutation mask and target mask for the first reuse_len tokens.
# The tokens in this part are reused from the last sequence.
perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
leak_ratio)
# Creates permutation mask and target mask for the rest of tokens in
# current example, which are concatentation of two new segments.
perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
leak_ratio)
perm_mask_0 = tf.concat(
[perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
perm_mask_1 = tf.concat(
[tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
input_k = tf.concat([input_k_0, input_k_1], axis=0)
input_q = tf.concat([input_q_0, input_q_1], axis=0)
else:
##### Do not use memory
assert seq_len % perm_size == 0
# permutate the entire sequence together
perm_mask, target_mask, input_k, input_q = _local_perm(
inputs, is_masked, perm_size, seq_len, leak_ratio)
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
# Directly use raw inputs as the target
target = inputs
if num_predict is not None:
indices = tf.range(seq_len, dtype=tf.int64)
......@@ -327,21 +536,15 @@ def create_pretrain_dataset(file_names,
example["target"] = tf.reshape(target, [num_predict])
##### target mask
target_mask = tf.concat([
tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)
],
axis=0)
target_mask = tf.concat(
[tf.ones([actual_num_predict], dtype=tf.float32),
tf.zeros([pad_len], dtype=tf.float32)],
axis=0)
example["target_mask"] = tf.reshape(target_mask, [num_predict])
else:
example["target"] = tf.reshape(target, [seq_len])
example["target_mask"] = tf.reshape(target_mask, [seq_len])
# reshape back to fixed shape
example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
example["input_k"] = tf.reshape(input_k, [seq_len])
example["input_q"] = tf.reshape(input_q, [seq_len])
for key in list(example.keys()):
val = example[key]
if tf.keras.backend.is_sparse(val):
......@@ -360,42 +563,29 @@ def create_pretrain_dataset(file_names,
parser=parser,
file_paths=file_names,
bsz_per_core=bsz_per_core,
sequential=reuse_len > 0,
input_pipeline_context=input_pipeline_context)
return dataset
def format_filename(prefix,
bsz_per_host,
seq_len,
bi_data,
suffix,
mask_alpha=5,
mask_beta=1,
reuse_len=None,
uncased=False,
fixed_num_predict=None):
def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
uncased=False):
"""Generates input file name pattern."""
if reuse_len is None:
reuse_len_str = ""
if reuse_len is not None and reuse_len > 0:
reuse_str = "reuse-{}.".format(reuse_len)
bsz_str = "hostbsz-{}.".format(bsz_per_host)
else:
reuse_len_str = "reuse-{}.".format(reuse_len)
reuse_str = ""
bsz_str = ""
if not uncased:
uncased_str = ""
else:
uncased_str = "uncased."
if bi_data:
bi_data_str = "bi"
case_str = ""
else:
bi_data_str = "uni"
if fixed_num_predict is not None:
fnp_str = "fnp-{}.".format(fixed_num_predict)
else:
fnp_str = ""
case_str = "uncased."
file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
mask_alpha, mask_beta, fnp_str, suffix)
file_name = "{}.seq-{}.{}{}{}{}".format(
prefix, seq_len, reuse_str, bsz_str, case_str, suffix)
return file_name
......@@ -406,11 +596,10 @@ def get_pretrain_input_data(batch_size,
file_path,
reuse_len,
perm_size,
mask_alpha,
mask_beta,
leak_ratio,
num_predict,
bi_data,
uncased,
online_masking_config,
num_hosts=1):
"""Returns input dataset from input file string."""
......@@ -419,17 +608,22 @@ def get_pretrain_input_data(batch_size,
# than passing dataset instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
split = "train"
bsz_per_host = int(batch_size / num_hosts)
record_glob_base = format_filename(
prefix="record_info-{}-*".format(split),
bsz_per_host=int(batch_size / num_hosts),
prefix="meta.{}.pass-*".format(split),
suffix="json*",
bsz_per_host=bsz_per_host,
seq_len=seq_len,
bi_data=bi_data,
suffix="json",
mask_alpha=mask_alpha,
mask_beta=mask_beta,
reuse_len=reuse_len,
uncased=uncased,
fixed_num_predict=num_predict)
uncased=uncased)
def _get_num_batch(info):
if "num_batch" in info:
return info["num_batch"]
elif "num_example" in info:
return info["num_example"] / bsz_per_host
else:
raise ValueError("Do not have sample info.")
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
......@@ -460,7 +654,7 @@ def get_pretrain_input_data(batch_size,
for record_info_path in record_paths:
with tf.io.gfile.GFile(record_info_path, "r") as fp:
info = json.load(fp)
cur_record_info["num_batch"] += info["num_batch"]
cur_record_info["num_batch"] += int(_get_num_batch(info))
cur_record_info["filenames"] += info["filenames"]
# overwrite directory for `cur_record_info`
......@@ -494,6 +688,8 @@ def get_pretrain_input_data(batch_size,
seq_len=seq_len,
reuse_len=reuse_len,
perm_size=perm_size,
leak_ratio=leak_ratio,
online_masking_config=online_masking_config,
num_predict=num_predict,
input_pipeline_context=ctx)
return train_dataset
......@@ -504,6 +700,7 @@ def get_pretrain_input_data(batch_size,
def parse_files_to_dataset(parser,
file_paths,
bsz_per_core,
sequential,
input_pipeline_context=None):
"""Creates the dataset given file paths."""
......@@ -519,7 +716,26 @@ def parse_files_to_dataset(parser,
if len(file_paths) > 1:
dataset = dataset.shuffle(len(file_paths))
dataset = tf.data.TFRecordDataset(dataset)
if sequential:
# Note: cannot perform sample-level shuffle here because this will violate
# the consecutive requirement of data stream.
dataset = tf.data.TFRecordDataset(dataset)
else:
# `cycle_length` is the number of parallel files that get read.
cycle_length = min(8, len(file_paths))
logging.info("Interleave %d files", cycle_length)
# `sloppy` mode means that the interleaving is not exact. This adds
# even more randomness to the training pipeline.
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=True,
cycle_length=cycle_length))
buffer_size = 2048
logging.info("Perform sample-level shuffle with size %d", buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
# (zihang): since we are doing online preprocessing, the parsed result of
# the same input at each time will be different. Thus, cache processed data
# is not helpful. It will use a lot of memory and lead to contrainer OOM.
......@@ -531,19 +747,19 @@ def parse_files_to_dataset(parser,
return dataset
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
"""Samples a permutation of the factorization order.
Creates perm_mask and target_mask accordingly.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected for
partial prediction.
perm_size: the length of longest permutation. Could be set to be reuse_len.
Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length.
leak_ratio: float, percent of masked tokens that are leaked.
Returns:
perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
......@@ -555,9 +771,6 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
means the ith token (in original order) can attend to the jth token
(in original order). Note that non-masked tokens can be attended by all
other tokens, which is different from the description in original paper.
new_targets: int64 Tensor in shape [seq_len], target token ids to be
predicted in XLNet.
In XLNet, target doesn't need to be shifted one position.
target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
target_mask[i] == 1,
the ith token needs to be predicted and mask will be used as input. This
......@@ -575,44 +788,40 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
index = tf.random.shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens
non_func_tokens = tf.logical_not(
tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
# Set the permutation indices of non-masked (& non-funcional) tokens to the
# smallest index (-1):
# (1) they can be seen by all other positions
# (2) they cannot see masked positions, so there won"t be information leak
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
# Create `target_mask`: non-funcional and masked tokens
# 1: use mask as input and have loss
# 0: use token (or [SEP], [CLS]) as input and do not have loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# Create `perm_mask`
# `target_tokens` cannot see themselves
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked
perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
non_func_tokens = tf.logical_not(tf.logical_or(
tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
masked_tokens = tf.logical_and(is_masked, non_func_tokens)
non_masked_or_func_tokens = tf.logical_not(masked_tokens)
smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
# Similar to BERT, randomly leak some masked tokens
if leak_ratio > 0:
leak_tokens = tf.logical_and(
masked_tokens,
tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
else:
can_attend_self = non_masked_or_func_tokens
to_index = tf.where(can_attend_self, smallest_index, index)
from_index = tf.where(can_attend_self, to_index + 1, to_index)
# For masked tokens, can attend if i > j
# For context tokens, always can attend each other
can_attend = from_index[:, None] > to_index[None, :]
# In modeling, 1 indicates cannot attend. Hence, reverse the value here.
perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
# Only masked tokens are included in the loss
target_mask = tf.cast(masked_tokens, tf.float32)
# construct inputs_k
inputs_k = inputs
# construct inputs_q
inputs_q = target_mask
inputs_q = masked_tokens
return perm_mask, new_targets, target_mask, inputs_k, inputs_q
return perm_mask, target_mask, inputs_k, inputs_q
......@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.utils.misc import tpu_lib
flags.DEFINE_integer(
"mask_alpha", default=6, help="How many tokens to form a group.")
flags.DEFINE_integer(
"mask_beta", default=1, help="How many tokens to mask within each group.")
flags.DEFINE_integer(
"num_predict",
default=None,
help="Number of tokens to predict in partial prediction.")
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
# FLAGS for pretrain input preprocessing
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
flags.DEFINE_float("leak_ratio", default=0.1,
help="Percent of masked tokens that are leaked.")
flags.DEFINE_enum("sample_strategy", default="token_span",
enum_values=["single_token", "whole_word", "token_span",
"word_span"],
help="Stragey used to sample prediction targets.")
flags.DEFINE_integer("max_num_tokens", default=5,
help="Maximum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("min_num_tokens", default=1,
help="Minimum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("max_num_words", default=5,
help="Maximum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
flags.DEFINE_integer("min_num_words", default=1,
help="Minimum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
FLAGS = flags.FLAGS
......@@ -74,11 +91,18 @@ def main(unused_argv):
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d", num_hosts)
online_masking_config = data_utils.OnlineMaskingConfig(
sample_strategy=FLAGS.sample_strategy,
max_num_tokens=FLAGS.max_num_tokens,
min_num_tokens=FLAGS.min_num_tokens,
max_num_words=FLAGS.max_num_words,
min_num_words=FLAGS.min_num_words)
train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict, FLAGS.bi_data,
FLAGS.uncased, num_hosts)
FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
num_hosts)
total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment