Commit 6d6a78a2 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Create XLNet pretrain data loader.

PiperOrigin-RevId: 342283301
parent 42f8e96e
......@@ -16,7 +16,10 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional
from absl import logging
import dataclasses
import numpy as np
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
......@@ -125,3 +128,504 @@ class BertPretrainDataLoader(data_loader.DataLoader):
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
"""Data config for XLNet pretraining task.
Attributes:
input_path: See base class.
global_batch_size: See base calss.
is_training: See base class.
seq_length: The length of each sequence.
max_predictions_per_seq: The number of predictions per sequence.
reuse_length: The number of tokens in a previous segment to reuse. This
should be the same value used during pretrain data creation.
sample_strategy: The strategy used to sample factorization permutations.
Possible values: 'fixed', 'single_token', 'whole_word', 'token_span',
'word_span'.
min_num_tokens: The minimum number of tokens to sample in a span.
This is used when `sample_strategy` is 'token_span'.
max_num_tokens: The maximum number of tokens to sample in a span.
This is used when `sample_strategy` is 'token_span'.
min_num_words: The minimum number of words to sample in a span.
This is used when `sample_strategy` is 'word_span'.
max_num_words: The maximum number of words to sample in a span.
This is used when `sample_strategy` is 'word_span'.
permutation_size: The length of the longest permutation. This can be set
to `reuse_length`. This should NOT be greater than `reuse_length`,
otherwise this may introduce data leaks.
leak_ratio: The percentage of masked tokens that are leaked.
segment_sep_id: The ID of the SEP token used when preprocessing
the dataset.
segment_cls_id: The ID of the CLS token used when preprocessing
the dataset.
"""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
reuse_length: int = 256
sample_strategy: str = 'word_span'
min_num_tokens: int = 1
max_num_tokens: int = 5
min_num_words: int = 1
max_num_words: int = 5
permutation_size: int = 256
leak_ratio: float = 0.1
segment_sep_id: int = 4
segment_cls_id: int = 3
@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
"""A class to load dataset for xlnet pretraining task."""
def __init__(self, params: XLNetPretrainDataConfig):
"""Inits `XLNetPretrainDataLoader` class.
Args:
params: A `XLNetPretrainDataConfig` object.
"""
self._params = params
self._seq_length = params.seq_length
self._max_predictions_per_seq = params.max_predictions_per_seq
self._reuse_length = params.reuse_length
self._num_replicas_in_sync = None
self._permutation_size = params.permutation_size
self._sep_id = params.segment_sep_id
self._cls_id = params.segment_cls_id
self._sample_strategy = params.sample_strategy
self._leak_ratio = params.leak_ratio
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_word_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_type_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'target':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'boundary_indices':
tf.io.VarLenFeature(tf.int64),
'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {}
inputs = record['input_word_ids']
x['input_type_ids'] = record['input_type_ids']
if self._sample_strategy == 'fixed':
input_mask = record['input_mask']
else:
input_mask = None
if self._sample_strategy in ['whole_word', 'word_span']:
boundary = tf.sparse.to_dense(record['boundary_indices'])
else:
boundary = None
input_mask = self._online_sample_mask(
inputs=inputs,
input_mask=input_mask,
boundary=boundary)
if self._reuse_length > 0:
if self._permutation_size > self._reuse_length:
logging.warning(
'`permutation_size` is greater than `reuse_length` (%d > %d).'
'This may introduce data leakage.',
self._permutation_size, self._reuse_length)
# Enable the memory mechanism.
# Permute the reuse and non-reuse segments separately.
non_reuse_len = self._seq_length - self._reuse_length
if not (self._reuse_length % self._permutation_size == 0
and non_reuse_len % self._permutation_size == 0):
raise ValueError('`reuse_length` and `seq_length` should both be '
'a multiple of `permutation_size`.')
# Creates permutation mask and target mask for the first reuse_len tokens.
# The tokens in this part are reused from the last sequence.
perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
inputs=inputs[:self._reuse_length],
input_mask=input_mask[:self._reuse_length])
# Creates permutation mask and target mask for the rest of tokens in
# current example, which are concatentation of two new segments.
perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
inputs[self._reuse_length:], input_mask[self._reuse_length:])
perm_mask_0 = tf.concat(
[perm_mask_0,
tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)],
axis=1)
perm_mask_1 = tf.concat(
[tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
perm_mask_1], axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
tokens = tf.concat([tokens_0, tokens_1], axis=0)
masked_tokens = tf.concat([masked_0, masked_1], axis=0)
else:
# Disable the memory mechanism.
if self._seq_length % self._permutation_size != 0:
raise ValueError('`seq_length` should be a multiple of '
'`permutation_size`.')
# Permute the entire sequence together
perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
inputs=inputs, input_mask=input_mask)
x['permutation_mask'] = tf.reshape(
perm_mask, [self._seq_length, self._seq_length])
x['input_word_ids'] = tokens
x['masked_tokens'] = masked_tokens
target = tokens
if self._max_predictions_per_seq is not None:
indices = tf.range(self._seq_length, dtype=tf.int32)
bool_target_mask = tf.cast(target_mask, tf.bool)
indices = tf.boolean_mask(indices, bool_target_mask)
# account for extra padding due to CLS/SEP.
actual_num_predict = tf.shape(indices)[0]
pad_len = self._max_predictions_per_seq - actual_num_predict
target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
paddings = tf.zeros([pad_len, self._seq_length],
dtype=target_mapping.dtype)
target_mapping = tf.concat([target_mapping, paddings], axis=0)
x['target_mapping'] = tf.reshape(
target_mapping, [self._max_predictions_per_seq, self._seq_length])
target = tf.boolean_mask(target, bool_target_mask)
paddings = tf.zeros([pad_len], dtype=target.dtype)
target = tf.concat([target, paddings], axis=0)
x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
target_mask = tf.concat([
tf.ones([actual_num_predict], dtype=tf.int32),
tf.zeros([pad_len], dtype=tf.int32)
], axis=0)
x['target_mask'] = tf.reshape(target_mask,
[self._max_predictions_per_seq])
else:
x['target'] = tf.reshape(target, [self._seq_length])
x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
return x
def _index_pair_to_mask(self,
begin_indices: tf.Tensor,
end_indices: tf.Tensor,
inputs: tf.Tensor) -> tf.Tensor:
"""Converts beginning and end indices into an actual mask."""
non_func_mask = tf.logical_and(
tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
all_indices = tf.where(
non_func_mask,
tf.range(self._seq_length, dtype=tf.int32),
tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
candidate_matrix = tf.cast(
tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
all_indices[None, :] < end_indices[:, None]), tf.float32)
cumsum_matrix = tf.reshape(
tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
tf.float32)
target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
return tf.cast(target_mask, tf.bool)
def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
"""Samples individual tokens as prediction targets."""
all_indices = tf.range(self._seq_length, dtype=tf.int32)
non_func_mask = tf.logical_and(
tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
masked_pos = tf.random.shuffle(non_func_indices)
masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
sparse_indices = tf.stack(
[tf.zeros_like(masked_pos), masked_pos], axis=-1)
sparse_indices = tf.cast(sparse_indices, tf.int64)
sparse_indices = tf.sparse.SparseTensor(
sparse_indices,
values=tf.ones_like(masked_pos),
dense_shape=(1, self._seq_length))
target_mask = tf.sparse.to_dense(
sp_input=sparse_indices,
default_value=0)
return tf.squeeze(tf.cast(target_mask, tf.bool))
def _whole_word_mask(self,
inputs: tf.Tensor,
boundary: tf.Tensor) -> tf.Tensor:
"""Samples whole words as prediction targets."""
pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
cand_pair_indices = tf.random.shuffle(
pair_indices)[:self._max_predictions_per_seq]
begin_indices = cand_pair_indices[:, 0]
end_indices = cand_pair_indices[:, 1]
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
"""Samples token spans as prediction targets."""
min_num_tokens = self._params.min_num_tokens
max_num_tokens = self._params.max_num_tokens
mask_alpha = self._seq_length / self._max_predictions_per_seq
round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=self._max_predictions_per_seq,
dtype=tf.int32,
)[0] + min_num_tokens
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(
shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
# Compute the offset from left start to the right end
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
# Get the actual begin and end indices
begin_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = begin_indices + span_lens
# Remove out of range indices
valid_idx_mask = end_indices < self._seq_length
begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
begin_indices = tf.gather(begin_indices, order)
end_indices = tf.gather(end_indices, order)
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _word_span_mask(self,
inputs: tf.Tensor,
boundary: tf.Tensor):
"""Sample whole word spans as prediction targets."""
min_num_words = self._params.min_num_words
max_num_words = self._params.max_num_words
# Note: 1.2 is the token-to-word ratio
mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_words, max_num_words + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
# Sample `num_predict` words here: note that this is over sampling
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=self._max_predictions_per_seq,
dtype=tf.int32,
)[0] + min_num_words
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(
shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
begin_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = begin_indices + span_lens
# Remove out of range indices
max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
valid_idx_mask = end_indices < max_boundary_index
begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
begin_indices = tf.gather(boundary, begin_indices)
end_indices = tf.gather(boundary, end_indices)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
begin_indices = tf.gather(begin_indices, order)
end_indices = tf.gather(end_indices, order)
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _online_sample_mask(self,
inputs: tf.Tensor,
input_mask: tf.Tensor,
boundary: tf.Tensor) -> tf.Tensor:
"""Samples target positions for predictions.
Descriptions of each strategy:
- 'fixed': Returns the input mask that was computed during pretrain data
creation. The value for `max_predictions_per_seq` must match the value
used during dataset creation.
- 'single_token': Samples individual tokens as prediction targets.
- 'token_span': Samples spans of tokens as prediction targets.
- 'whole_word': Samples individual words as prediction targets.
- 'word_span': Samples spans of words as prediction targets.
Args:
inputs: The input tokens.
input_mask: The `bool` Tensor of the same shape as `inputs`. This is the
input mask calculated when creating pretraining the pretraining dataset.
If `sample_strategy` is not 'fixed', this is not used.
boundary: The `int` Tensor of indices indicating whole word boundaries.
This is used in 'whole_word' and 'word_span'
Returns:
The sampled `bool` input mask.
Raises:
`ValueError`: if `max_predictions_per_seq` is not set
and the sample strategy is not 'fixed', or if boundary is not provided
for 'whole_word' and 'word_span' sample strategies.
"""
if (self._sample_strategy != 'fixed' and
self._max_predictions_per_seq is None):
raise ValueError(
'`max_predictions_per_seq` must be set if using '
'sample strategy {}.'.format(self._sample_strategy))
if boundary is None and 'word' in self._sample_strategy:
raise ValueError('`boundary` must be provided for {} strategy'.format(
self._sample_strategy))
if self._sample_strategy == 'fixed':
# Uses the computed input masks from preprocessing.
# Note: This should have `max_predictions_per_seq` number of tokens set
# to 1.
return tf.cast(input_mask, tf.bool)
elif self._sample_strategy == 'single_token':
return self._single_token_mask(inputs)
elif self._sample_strategy == 'token_span':
return self._token_span_mask(inputs)
elif self._sample_strategy == 'whole_word':
return self._whole_word_mask(inputs, boundary)
elif self._sample_strategy == 'word_span':
return self._word_span_mask(inputs, boundary)
else:
raise NotImplementedError('Invalid sample strategy.')
def _get_factorization(self,
inputs: tf.Tensor,
input_mask: tf.Tensor):
"""Samples a permutation of the factorization order.
Args:
inputs: the input tokens.
input_mask: the `bool` Tensor of the same shape as `inputs`.
If `True`, then this means select for partial prediction.
Returns:
perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
token (in original order) cannot attend to the jth attention token.
target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
If target_mask[i] == 1, then the i-th token needs to be predicted and
the mask will be used as input. This token will be included in the loss.
If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
input. This token will not be included in the loss.
tokens: int32 Tensor of shape [seq_length].
masked_tokens: int32 Tensor of shape [seq_length].
"""
factorization_length = tf.shape(inputs)[0]
# Generate permutation indices
index = tf.range(factorization_length, dtype=tf.int32)
index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
index = tf.random.shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])
input_mask = tf.cast(input_mask, tf.bool)
# non-functional tokens
non_func_tokens = tf.logical_not(
tf.logical_or(
tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
masked_tokens = tf.logical_and(input_mask, non_func_tokens)
non_masked_or_func_tokens = tf.logical_not(masked_tokens)
smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
# Similar to BERT, randomly leak some masked tokens
if self._leak_ratio > 0:
leak_tokens = tf.logical_and(
masked_tokens,
tf.random.uniform([factorization_length],
maxval=1.0) < self._leak_ratio)
can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
else:
can_attend_self = non_masked_or_func_tokens
to_index = tf.where(can_attend_self, smallest_index, index)
from_index = tf.where(can_attend_self, to_index + 1, to_index)
# For masked tokens, can attend if i > j
# For context tokens, always can attend each other
can_attend = from_index[:, None] > to_index[None, :]
perm_mask = tf.cast(can_attend, tf.int32)
# Only masked tokens are included in the loss
target_mask = tf.cast(masked_tokens, tf.int32)
return perm_mask, target_mask, inputs, masked_tokens
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
if input_context:
self._num_replicas_in_sync = input_context.num_replicas_in_sync
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
......@@ -24,19 +24,21 @@ import tensorflow as tf
from official.nlp.data import pretrain_dataloader
def _create_fake_dataset(output_path,
seq_length,
max_predictions_per_seq,
use_position_id,
use_next_sentence_label,
use_v2_feature_names=False):
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def _create_fake_bert_dataset(
output_path,
seq_length,
max_predictions_per_seq,
use_position_id,
use_next_sentence_label,
use_v2_feature_names=False):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
......@@ -70,6 +72,34 @@ def _create_fake_dataset(output_path,
writer.close()
def _create_fake_xlnet_dataset(
output_path, seq_length, max_predictions_per_seq):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
for _ in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
num_boundary_indices = np.random.randint(1, seq_length)
if max_predictions_per_seq is not None:
input_mask = np.zeros_like(input_ids)
input_mask[:max_predictions_per_seq] = 1
np.random.shuffle(input_mask)
else:
input_mask = np.ones_like(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
features["boundary_indices"] = create_int_feature(
sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
features["target"] = create_int_feature(input_ids + 1)
features["label"] = create_int_feature([1])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(itertools.product(
......@@ -80,7 +110,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
max_predictions_per_seq = 20
_create_fake_dataset(
_create_fake_bert_dataset(
train_data_path,
seq_length,
max_predictions_per_seq,
......@@ -114,7 +144,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
max_predictions_per_seq = 20
_create_fake_dataset(
_create_fake_bert_dataset(
train_data_path,
seq_length,
max_predictions_per_seq,
......@@ -141,5 +171,74 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn("masked_lm_weights", features)
class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(itertools.product(
("fixed", "single_token", "whole_word", "token_span"),
(0, 64),
(20, None),
))
def test_load_data(
self, sample_strategy, reuse_length, max_predictions_per_seq):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
batch_size = 5
_create_fake_xlnet_dataset(
train_data_path, seq_length, max_predictions_per_seq)
data_config = pretrain_dataloader.XLNetPretrainDataConfig(
input_path=train_data_path,
max_predictions_per_seq=max_predictions_per_seq,
seq_length=seq_length,
global_batch_size=batch_size,
is_training=True,
reuse_length=reuse_length,
sample_strategy=sample_strategy,
min_num_tokens=1,
max_num_tokens=2,
permutation_size=seq_length // 2,
leak_ratio=0.1)
if (max_predictions_per_seq is None and sample_strategy != "fixed"):
with self.assertRaisesWithRegexpMatch(
ValueError, "`max_predictions_per_seq` must be set"):
dataset = pretrain_dataloader.XLNetPretrainDataLoader(
data_config).load()
features = next(iter(dataset))
else:
dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
features = next(iter(dataset))
self.assertIn("input_word_ids", features)
self.assertIn("input_type_ids", features)
self.assertIn("permutation_mask", features)
self.assertIn("masked_tokens", features)
self.assertIn("target", features)
self.assertIn("target_mask", features)
self.assertAllClose(features["input_word_ids"].shape,
(batch_size, seq_length))
self.assertAllClose(features["input_type_ids"].shape,
(batch_size, seq_length))
self.assertAllClose(features["permutation_mask"].shape,
(batch_size, seq_length, seq_length))
self.assertAllClose(features["masked_tokens"].shape,
(batch_size, seq_length,))
if max_predictions_per_seq is not None:
self.assertIn("target_mapping", features)
self.assertAllClose(features["target_mapping"].shape,
(batch_size, max_predictions_per_seq, seq_length))
self.assertAllClose(features["target_mask"].shape,
(batch_size, max_predictions_per_seq))
self.assertAllClose(features["target"].shape,
(batch_size, max_predictions_per_seq))
else:
self.assertAllClose(features["target_mask"].shape,
(batch_size, seq_length))
self.assertAllClose(features["target"].shape,
(batch_size, seq_length))
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment